diff --git a/README.md b/README.md index 456067c82..7704dbd5b 100644 --- a/README.md +++ b/README.md @@ -1,148 +1,147 @@ -# Welcome to LLAISYS +# 欢迎使用 LLAISYS

English中文

-## Introduction +## 简介 -LLAISYS (Let's Learn AI SYStem) is an educational project that aims to provide a platform for new and future AI engineers to learn how to build AI systems from scratch. LLAISYS consists of several assignments, which help students learn and build the basic modules, and projects that challenge them to add more fancy features to their systems. LLAISYS uses C++ as primary programming language for system backend, and is compiled into shared libraries exposing C language APIs. Frontend codes are written in Python which calls these APIs to provide more convenient testing and interaction with other architectures such as PyTorch. +LLAISYS(Let's Learn AI SYStem)是一个教育项目,旨在为新手和未来的AI工程师提供一个从零开始构建AI系统的学习平台。LLAISYS包含多个作业,帮助学生学习和构建基础模块;以及一些项目挑战,让他们为系统添加更多高级功能。LLAISYS使用C++作为系统后端的主要编程语言,并编译成共享库,提供C语言API。前端代码使用Python编写,调用这些API以提供更便捷的测试和与其他架构(如PyTorch)的交互。 -### Project Structure Overview +### 项目结构概览 -- `\include`: directory that contains of the header files which defines all the C APIs exposed by the shared library. (Functions declarations start with `__export`) +- `\include`:包含所有定义共享库提供的C API的头文件的目录。(函数声明以`__export`开头) -- `\src`: C++ source files. - - `\src\llaisys` contains all the direct implementation of waht are defined in the header files and follows the same directory structure as the `\include`. This is also as far as C++ codes can go. - - other directories contain the actual implementaion of different modules. +- `\src`:C++源文件。 + - `\src\llaisys`包含头文件中定义的所有直接实现,并遵循与`\include`相同的目录结构。这也是C++代码的边界。 + - 其他目录包含不同模块的实际实现。 -- `xmake.lua`: build rules for llaisys backend. `\xmake` directory contains the sub-xmake files for different devices. You may add `nvidia.lua` in the directory in the future for instance to support CUDA. +- `xmake.lua`:llaisys后端的构建规则。`\xmake`目录包含不同设备的子xmake文件。例如,将来可以在目录中添加`nvidia.lua`来支持CUDA。 -- `\python`: Python source files. - - `\python\llaisys\libllaisys` contains all the ctypes wrapper functions of llaisys APIs. It basically matches the structure of C header files. - - `\python\llaisys` contains Python warppers of the ctypes functions to make the package more Python-like. +- `\python`:Python源文件。 + - `\python\llaisys\libllaisys`包含llaisys API的所有ctypes封装函数。它基本上与C头文件的结构相匹配。 + - `\python\llaisys`包含ctypes函数的Python包装器,使包更符合Python风格。 -- `\test`: Python test files that import llaisys python package. +- `\test`:导入llaisys python包的Python测试文件。 -## Assignment #0: Getting Started +## 作业 #0:入门 -### Task-0.1 Install Prerequisites +### 任务-0.1 安装必备组件 -- Compile Tool: [Xmake](https://xmake.io/) -- C++ Compiler: MSVC (Windows) or Clang or GCC -- Python >= 3.9 (PyTorch, Transformers, etc.) -- Clang-Format-16 (Optional): for formatting C++ codes. +- 编译工具:[Xmake](https://xmake.io/) +- C++编译器:MSVC(Windows)或Clang或GCC +- Python >= 3.9(PyTorch、Transformers等) +- Clang-Format-16(可选):用于格式化C++代码。 -### Task-0.2 Fork and Build LLAISYS +### 任务-0.2 Fork并构建LLAISYS -- FORK LLAISYS Repository and Clone it to your local machine. Both Windows and Linux are supported. +- Fork LLAISYS仓库并克隆到本地机器。支持Windows和Linux。 -- Compile and Install +- 编译和安装 ```bash - # compile c++ codes + # 编译c++代码 xmake - # install llaisys shared library + # 安装llaisys共享库 xmake install - # install llaisys python package + # 安装llaisys python包 pip install ./python/ ``` -- Github Auto Tests +- Github自动测试 - LLAISYS uses Github Actions to run automated tests on every push and pull request. You can see testing results on your repo page. All tests should pass once you have finished all assignment tasks. + LLAISYS使用Github Actions在每次推送和拉取请求时运行自动化测试。你可以在仓库页面上看到测试结果。完成所有作业任务后,所有测试都应该通过。 -### Task-0.3 Run LLAISYS for the First Time +### 任务-0.3 首次运行LLAISYS -- Run cpu runtime tests +- 运行cpu运行时测试 ```bash python test/test_runtime.py --device cpu ``` - You should see the test passed. + 你应该看到测试通过。 -### Task-0.4 Download test model +### 任务-0.4 下载测试模型 -- The model we use for assignments is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B). +- 我们用于作业的模型是[DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)。 -- Run an inference test with the model using PyTorch +- 使用PyTorch运行模型推理测试 ```bash python test/test_infer.py --model [dir_path/to/model] ``` - You can see that PyTorch is able to load the model and perform inference with the sample input. You can debug into `transformers` library codes to see how what is going on behind. Right now, your code cannot do anything yet, but you are going to build a system that can achieve the same functionality in the assignments. + 你可以看到PyTorch能够加载模型并使用示例输入执行推理。你可以调试进入`transformers`库代码来深入查看并了解其内部运作原理。现在,你的代码还无法执行任何操作,但在后续的作业中,你将构建一个能够实现相同功能的系统。 -## Assignment #1: Tensor +## 作业 #1:张量 -Tensor is a data structure that represents multi-dimensional data. It is the basic building block of LLAISYS, and most AI frameworks such as PyTorch. In this assignment, you will learn how to implement a basic tensor class. +张量是表示多维数据的数据结构。它是LLAISYS和大多数AI框架(如PyTorch)的基本构建单元。在这个作业中,你将学习如何实现一个基本的张量类。 -A Tensor object has the following fields: +张量对象具有以下字段: -- `storage`: a shared pointer to a memory block that stores the tensor's data. It can be shared by multiple tensors. Check storage class for more details. -- `offset`: the starting index (in bytes) of the tensor in the storage. -- `meta`: metadata that describes the tensor's shape, data type, and strides. +- `storage`:指向存储张量数据的内存块的共享指针。它可以被多个张量共享。有关更多详细信息,请查看storage类。 +- `offset`:张量在存储中的起始索引(以字节为单位)。 +- `meta`:描述张量形状、数据类型和步长的元数据。 -Implement the following functions defined in the `src/tensor/tensor.hpp`: +实现`src/tensor/tensor.hpp`中定义的以下函数: -### Task-1.1 +### 任务-1.1 ```c++ void load(const void *src); ``` -Load host (cpu) data to the tensor (can be on device). Check contructor to see how to get runtime apis of the current device context, and do a memcpy from host to device. +将主机(cpu)数据加载到张量(可以在设备上)。查看构造函数了解如何获取当前设备上下文的运行时API,并执行从主机到设备的内存复制。 -### Task-1.2 +### 任务-1.2 ```c++ bool isContiguous() const; ``` -Check shape and strides of the tensor, and tell wether it is contiguous in memory. +检查张量的形状和步长,判断它在内存中是否连续。 -### Task-1.3 +### 任务-1.3 ```c++ tensor_t view(const std::vector &shape) const; ``` -Create a new tensor which reshapes the original tensor to the given shape by splitting or merging the original dimensions. No data transfer is involved. For example change a tensor of shape (2, 3, 5) to (2, 15) by merging the last two dimensions. +创建一个新张量,通过拆分或合并原始维度将原始张量重塑为给定形状。不涉及数据传输。例如,通过合并最后两个维度,将形状为(2, 3, 5)的张量更改为(2, 15)。 -This function is not as easy as simply changing the shape of the tensor, although the test will pass. It should raise an error if new view is not compatible with the original tensor. Think about a tensor of shape (2, 3, 5) and strides (30, 10, 1). Can you still reshape it to (2, 15) without data transfer? +这个函数不是简单地改变张量的形状那么简单,尽管测试会通过。如果新视图与原始张量不兼容,它应该引发错误。想想一个形状为(2, 3, 5)、步长为(30, 10, 1)的张量。你还能在不传输数据的情况下将其重塑为(2, 15)吗? -### Task-1.4 +### 任务-1.4 ```c++ tensor_t permute(const std::vector &order) const; ``` -Create a new tensor which changes the order of the dimensions of original tensor. Transpose can be achieved by this function without moving data around. +创建一个新张量,改变原始张量维度的顺序。转置可以通过这个函数实现,而无需移动数据。 -### Task-1.5 +### 任务-1.5 ```c++ tensor_t slice(size_t dim, size_t start, size_t end) const; ``` -Create a new tensor which slices the original tensor along the given dimension, -start (inclusive) and end (exclusive) indices. +创建一个新张量,沿给定维度,start(包含)和end(不包含)索引对原始张量进行切片操作。 -### Task-1.6 +### 任务-1.6 -Run tensor tests. +运行张量测试。 ```bash python test/test_tensor.py ``` -You should see all tests passed. Commit and push your changes. You should see the auto tests for assignment #1 passed. +你应该看到所有测试都通过了。提交并推送你的更改。你应该看到作业#1的自动测试通过了。 -## Assignment #2: Operators +## 作业 #2:算子 -In this assignment, you will implement the cpu verision the following operators: +在这个作业中,你将实现以下算子的cpu版本: - argmax - embedding @@ -152,102 +151,102 @@ In this assignment, you will implement the cpu verision the following operators: - self_attention - swiglu -Read the codes in `src/ops/add/` to see how "add" operator is implemented. Make sure you understand how the operator codes are organized, compiled, linked, and exposed to Python frontend. **Your operators should at least support Float32, Float16 and BFloat16 data types**. A helper function for naive type casting is provided in `src/utils/`. All python tests are in `test/ops`, you implementation should at least pass these tests. Try running the test script for "add" operator for starting. +阅读`src/ops/add/`中的代码,了解"add"算子是如何实现的。确保你理解算子代码是如何组织、编译、链接以及暴露给Python前端的。**你的算子应该至少支持Float32、Float16和BFloat16数据类型**。`src/utils/`中提供了一个用于简单类型转换的辅助函数。所有python测试都在`test/ops`中,你的实现应该至少通过这些测试。首先尝试运行"add"算子的测试脚本。 -### Task-2.1 argmax +### 任务-2.1 Argmax ```c++ void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals); ``` -Get the max value and its index of tensor `vals`, and store them in `max_val` and `max_idx` respectively. You can assume that `vals` is a 1D tensor for now, and `max_idx` and `max_val` are both 1D tensors with a single element (, which means the dimension of `vals` is kept). +获取张量`vals`的最大值及其索引,并分别存储在`max_val`和`max_idx`中。你暂时可以假设`vals`是一个1D张量,`max_idx`和`max_val`都是包含单个元素的1D张量(这意味着保留了`vals`的维度)。 -You should be able to pass the test cases in `test/ops/argmax.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/argmax.py`中的测试用例。 -### Task-2.2 embedding +### 任务-2.2 Embedding ```c++ void embedding(tensor_t out, tensor_t index, tensor_t weight); ``` -Copy the rows in `index` (1-D) from `weight` (2-D) to `output` (2-D). `index` must be of type Int64 (the default data type for int of PyTorch). +从`weight`(2-D)中复制`index`(1-D)中的行到`output`(2-D)。`index`必须是Int64类型(PyTorch中int的默认数据类型)。 -You should be able to pass the test cases in `test/ops/embedding.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/embedding.py`中的测试用例。 -### Task-2.3 linear +### 任务-2.3 Linear ```c++ void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias); ``` -Compute the following: +计算以下内容: $$ Y = xW^T + b $$ -- `out`: output $Y$ . You can assume output is a 2D contiguous tensor and no broadcasting is involved for now. -- `input`: input $X$ . You can assume input is a 2D contiguous tensor and no broadcasting is involved for now. -- `weight`: weight $W$ . 2D contiguous tensor. Note that weight tensor is not transposed. You need to deal with this during your calculation. -- `bias` (optional): bias $b$ . 1D tensor. You need to support the situation where bias is not provided. +- `out`:输出 $Y$ 。你暂时可以假设输出是一个2D连续张量,不涉及广播。 +- `input`:输入 $X$ 。你暂时可以假设输入是一个2D连续张量,不涉及广播。 +- `weight`:权重 $W$ 。2D连续张量。注意权重张量没有转置。你需要在计算过程中处理这个问题。 +- `bias`(可选):偏置 $b$ 。1D张量。你需要支持不提供偏置的情况。 -You should be able to pass the test cases in `test/ops/linear.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/linear.py`中的测试用例。 -### Task-2.4 rms normalization +### 任务-2.4 RMS Normalization ```c++ void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps); ``` -Compute the following for each row: +为每一行计算以下内容: $$ Y_i = \frac{W_i \times X_i}{\sqrt{\frac{1}{d}(\sum_{j=1}^d X_j^2) + \epsilon}} $$ -- `out`: output $Y$ . You can assume output is a 2D contiguous tensor and no broadcasting is involved for now. -- `input`: input $X$ . You can assume input is a 2D contiguous tensor and no broadcasting is involved for now. The normalization is performed along the last dimension (a.k.a. each row of length $d$ ) of the input tensor. -- `weight`: weight $W$ . 1D tensor, same length as a row of input tensor. -- `eps`: small value $\epsilon$ to avoid division by zero. +- `out`:输出 $Y$ 。你暂时可以假设输出是一个2D连续张量,不涉及广播。 +- `input`:输入 $X$ 。你暂时可以假设输入是一个2D连续张量,不涉及广播。标准化沿输入张量的最后一个维度(即每一行,长度为 $d$ )执行。 +- `weight`:权重 $W$ 。1D张量,与输入张量的一行长度相同。 +- `eps`:小值 $\epsilon$ 以避免除以零。 -You should be able to pass the test cases in `test/ops/rms_norm.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/rms_norm.py`中的测试用例。 -### Task-2.5 rope +### 任务-2.5 旋转位置编码(RoPE) ```c++ void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta); ``` -Compute the following for each vector of input tensor `in`, corresponding to a position id in `pos_ids`: +为输入张量`in`的每个向量(这些向量与 pos_ids 中的位置 id 相对应)计算以下内容: -Let $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ be the input vector and $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ be the output vector at index $i$, where $\mathbf{a}_i, \mathbf{b}_i,\mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$ . +设 $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ 为输入向量, $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ 为索引 $i$ 处的输出向量,其中 $\mathbf{a}_i, \mathbf{b}_i,\mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$ 。 -Let $\theta$ be a fixed base (e.g. $\theta = 10000$) and $j = 0, 1, \ldots, d/2 - 1$. +设 $\theta$ 为固定基数(例如 $\theta = 10000$), $j = 0, 1, \ldots, d/2 - 1$。 -Let $p_i \in \mathbb{N}$ is the position id for token at input index i. +设 $p_i \in \mathbb{N}$ 是输入索引i处token的位置id。 -Then the angle for RoPE is $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$ +那么RoPE的角度为 $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$ -The output vector $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ is computed as follows: +输出向量 $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ 计算如下: $$a_{i,j}' = a_{i,j} \cos(\phi_{i,j}) - b_{i,j} \sin(\phi_{i,j})$$ $$b_{i,j}' = b_{i,j} \cos(\phi_{i,j}) + a_{i,j} \sin(\phi_{i,j})$$ -- `out`: the resulting **q** or **k** tensor. Shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. You can assume that the tensor is contiguous for now. -- `in`: the orignal **q** or **k** tensor. Shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. You can assume that the tensor is contiguous for now. -- `pos_ids`: the position id (index in the whole context) for each token in the input sequence. Shape should be [seqlen,], dtype should be int64. -- `theta`: the base value for the frequency vector. +- `out`:结果**q**或**k**张量。形状应该是 [seqlen, nhead, d] 或 [seqlen, nkvhead, d]。你暂时可以假设张量是连续的。 +- `in`:原始**q**或**k**张量。形状应该是 [seqlen, nhead, d] 或 [seqlen, nkvhead, d]。你暂时可以假设张量是连续的。 +- `pos_ids`:输入序列中每个token的位置id(整个上下文中的索引)。形状应该是 [seqlen,],dtype应该是int64。 +- `theta`:频率向量的基值。 -You should be able to pass the test cases in `test/ops/rope.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/rope.py`中的测试用例。 -### Task-2.6 self-attention +### 任务-2.6 自注意力(self-attention) ```c++ void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale); ``` -Compute the self-attention for query tensor `q`, key tensor `k`, and value tensor `v`. You should concat kvcache tensors, if needed, before doing this calculation. +为查询张量`q`、键张量`k`和值张量`v`计算自注意力。如果需要,你应该在进行此计算之前连接kvcache张量。 $$ A = Q K^\top * scale \\ @@ -257,108 +256,109 @@ $$ Y = \mathrm{causalsoftmax}(A) \cdot V \\ $$ -- `attn_val`: the resulting attention value tensor. Shape should be [seqlen, nhead, dv]. You can assume that the tensor is contiguous for now. -- `q`: the query tensor. Shape should be [seqlen, nhead, d]. You can assume that the tensor is contiguous for now. -- `k`: the key tensor. Shape should be [total_len, nkvhead, d]. You can assume that the tensor is contiguous for now. -- `v`: the value tensor. Shape should be [total_len, nkvhead, dv]. You can assume that the tensor is contiguous for now. -- `scale`: a scaling factor. It is set to $\frac{1}{\sqrt{d}}$ in most cases. +- `attn_val`:结果注意力值张量。形状应该是[seqlen, nhead, dv]。你暂时可以假设张量是连续的。 +- `q`:查询张量。形状应该是 [seqlen, nhead, d]。你暂时可以假设张量是连续的。 +- `k`:键张量。形状应该是 [total_len, nkvhead, d]。你暂时可以假设张量是连续的。 +- `v`:值张量。形状应该是 [total_len, nkvhead, dv]。你暂时可以假设张量是连续的。 +- `scale`:缩放因子。在大多数情况下取值为 $\frac{1}{\sqrt{d}}$ 。 -You should be able to pass the test cases in `test/ops/self_attention.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/self_attention.py`中的测试用例。 -### Task-2.7 swiglu +### 任务-2.7 SwiGLU ```c++ void swiglu(tensor_t out, tensor_t gate, tensor_t up); ``` -This is an element-wise function that computes the following: +这是一个逐元素函数,计算以下内容: $$ out_{i} = up_{i} \circ \frac { gate_{i}}{1 + e^{-gate_{i}}} $$ -`out`, `up` and `gate` are 2D contiguous tensors with the same shape [seqlen, intermediate_size]. +`out`、`up`和`gate`是具有相同形状 [seqlen, intermediate_size] 的2D连续张量。 -You should be able to pass the test cases in `test/ops/swiglu.py` after you finish the implementation. +完成实现后,你应该能够通过`test/ops/swiglu.py`中的测试用例。 -### Task-2.8 +### 任务-2.8 -Run operator tests. +运行算子测试。 ```bash python test/test_ops.py ``` -You should see all tests passed. Commit and push your changes. You should see the auto tests for assignment #2 passed. +你应该看到所有测试都通过了。提交并推送你的更改。你应该看到作业#2的自动测试通过了。 -### Task-2.9 (Optional) rearrange +### 任务-2.9(可选)rearrange -This is a bonus task. You may or may not need it for model inference. +这是一个奖励任务。你在模型推理中可能需要也可能不需要它。 ```c++ void rearrange(tensor_t out, tensor_t in); ``` -This operator is used to copy data from a tensor to another tensor with the same shape but different strides. With this, you can easily implement `contiguous` functionality for tensors. +此算子用于将数据从一个张量复制到另一个具有相同形状但不同步长的张量。有了这个,你可以轻松地为张量实现`contiguous`功能。 -## Assignment #3: Large Language Model Inference +## 作业 #3:大语言模型推理 -Finally, it is the time for you to achieve text generation with LLAISYS. +终于,是时候用LLAISYS实现文本生成了。 -- In `test/test_infer.py`, your implementation should be able to generate the same texts as PyTorch, using argmax sampling. The model we use for this assignment is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B). +- 在`test/test_infer.py`中,你的实现应该能够使用argmax采样生成与PyTorch相同的文本。我们用于此作业的模型是[DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)。 -- The python wrapper of your implementation is in `python/llaisys/models/qwen2.py`. You are NOT allowed to implement your model infer logic here using any python based frameworks, such as PyTorch. Instead, you need to implement the model with C/C++ in LLAISYS backend. The script loads each tensor in the safetensors file, and you will need to load data from them into your model backend. +- 你的实现的python包装器在`python/llaisys/models/qwen2.py`中。你不允许在这里使用任何基于python的框架(如PyTorch)实现你的模型推理逻辑。相反,你需要在LLAISYS后端用C/C++实现模型。脚本加载safetensors文件中的每个张量,你需要从它们加载数据到你的模型后端。 -- In `include/llaisys/models/qwen2.h`, a prototype is defined for you. Feel free to modify the codes as you want, but you should at least provide basic APIs for model creation, destruction, data loading, and infer. Implement your C APIs in `src/llaisys/` and organize your C++ codes as other modules in `src/`. Remember to define the compiling procedures in `xmake.lua`. +- 在`include/llaisys/models/qwen2.h`中,为你定义了一个原型。你可以随意修改代码,但你应该至少提供模型创建、销毁、数据加载和推理的基本API。在`src/llaisys/`中实现你的C API,并像`src/`中的其他模块一样组织你的C++代码。记得在`xmake.lua`中定义编译过程。 -- In `python/llaisys/libllaisys/`, define the ctypes wrapper functions for your C APIs. Implement `python/llaisys/models/qwen2.py` with your wrapper functions. +- 在`python/llaisys/libllaisys/`中,为你的C API定义ctypes包装函数。使用你的包装函数实现`python/llaisys/models/qwen2.py`。 -- You need to implement KV Cache, or your model will be too slow. +- 你需要实现 KV-Cache 功能,否则模型推理速度会过慢。 -- Debug until your model works. Take advantage of tensor's `debug` function which prints the tensor data. It allows you to compare the data of any tensor during the model inference with PyTorch. +- 调试直到你的模型工作。利用张量的`debug`函数打印张量数据。它允许你在模型推理期间将任何张量的数据与PyTorch进行比较。 -After you finish the implementation, you can run the following command to test your model: +完成实现后,你可以运行以下命令来测试你的模型: ```bash python test/test_infer.py --model [dir_path/to/model] --test ``` -Commit and push your changes. You should see the auto tests for assignment #3 passed. +提交并推送你的更改。你应该看到作业#3的自动测试通过了。 +## 只有完成作业后,才能开始做项目。 -## You can proceed to the projects only after you finish the assignments. +## 项目#1:优化 LLAISYS 的 CPU 推理 -## Project #1: Optimize LLAISYS for CPU -You probably have already noticed that your model inference is very slow compared to PyTorch. This is mostly because your operators are not optimized. Run your operater test scripts with "--profile" flag to see how your operators perform. You would probably see that `linear` operation is much slower than PyTorch. This operator is mainly a matrix multiplication, and is the most time consuming operation in transformer-based models. +你可能已经注意到,你的模型推理速度相比 PyTorch 非常慢。这主要是因为你的算子没有经过优化。运行算子测试脚本时加上 ``--profile`` 参数,看看算子的性能表现。你可能会发现 ``linear`` 操作比 PyTorch 慢很多。这个算子本质上是矩阵乘法,是 Transformer 模型里最耗时的操作。 -There are several ways to optimize your operators for CPU: +以下是几种优化 CPU 算子的方法: -### SIMD instructions +### 使用 SIMD 指令 -SIMD (Single Instruction Multiple Data) instructions are instructions that can perform the same operation on multiple data elements in a single instruction. Modern CPUs have support for SIMD instructions. Look for online materials to learn about compiler intrinsics (such as AVX2, AVX-512, NEON, SVE) to vectorize your operations. +SIMD(单指令多数据)是一类可以在单条指令中对多个数据元素同时执行相同操作的指令。现代 CPU 都支持 SIMD。你可以查阅相关资料,学习编译器内建函数(如 AVX2、AVX-512、NEON、SVE)来向量化你的算子。 -### Use OpenMP for parallelism +### 使用 OpenMP 实现并行 -You can use multi-threading to parallelize your operators. OpenMP is a popular library for multi-threading in C/C++. Add OpenMP support for LLAISYS to parallelize your `linear` and other operators. +你可以用多线程来并行化算子。OpenMP 是 C/C++ 中常见的多线程库。为 LLAISYS 增加 OpenMP 支持,使得 ``linear`` 等算子能够并行执行。 -### 3rd-party Libraries +### 使用第三方库 -There are several libraries that can help you optimize your operators for CPU. Look for libraries like Eigen, OpenBLAS, MKL, etc. to optimize your linear algebra operations. Note that some libraries are supported only for certain hardware platforms. Check their documentations and use them in your codes with care. You can also try to dig out how PyTorch implement these operators and see if you can use them. +有很多库能帮你优化 CPU 上的算子,例如 Eigen、OpenBLAS、MKL 等,它们能高效处理线性代数运算。但要注意,有些库只支持特定硬件平台,需要仔细阅读文档并小心使用。你也可以参考 PyTorch 的算子实现,看是否能复用。 -Optimize your implementation with any methods you like and report your performance improvement. +用任何你喜欢的方法优化你的推理实现,并报告性能提升情况。 -## Project #2: Intigrate CUDA into LLAISYS +## 项目#2:在 LLAISYS 中集成 CUDA,适配两款CUDA或类CUDA平台(以下统称CUDA) -This project does not depend on **Project #1**. You should choose two CUDA/CUDA-ish hardware platforms from Nvidia, Iluvatar, Metax, and Moore Threads. +这个项目不依赖 ``项目#1``。需要选择 Nvidia、天数、摩尔、沐曦中的至少两款平台。 -This camp session provides computation resources from the four platforms above, access to which is granted based on applications from the official website. You can accelerate your model with CUDA on these GPU platforms. Before doing that, let's dive deeper into LLAISYS framework. +本次训练营提供了以上四种平台的算力,可以在官方进行申请算力,并用 CUDA 加速模型推理。在动手前,先深入理解 LLAISYS 框架。 -LLAISYS is actually a framework with homogeous hardware support. When using LLAISYS, each thread will create a thread-local `Context` object which manages all the device `Runtime` objects used by this thread. A `Runtime` object is a resource manager for a device, and `Context` will create (with lazy initialization) a single `Runtime` object for each device. You can set and switch between them using `setDevice` function in `Context`. Only one device will be active at a time for each thread. Check `src/core/context.hpp` for more details. +事实上,LLAISYS 是一个支持同构硬件的框架。使用时,每个线程会创建一个线程唯一的 **Context** 对象,管理该线程使用的所有设备 **Runtime**。**Runtime** 对象是设备的资源管理器,**Context** 会为每个设备(以延迟初始化的方式)创建唯一的 **Runtime**。你可以用 ``setDevice`` 在不同设备间切换,每个线程同一时间只会激活一个设备。详情见 ``src/core/context.hpp``。 -### Implement CUDA Runtime APIs -Each `Runtime` object is intialized with a set of generic functions called `Runtime APIs`. You will need to implement CUDA version of these APIS. Check `src/device/cpu/cpu_runtime_api.cpp` to see how these functions are implemented for CPU and look for CUDA APIs to use in [`CUDA Runtime documentation`](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html). +### 实现 CUDA Runtime API -You can see in `src/device/runtime_api.hpp` that `nvidia::getRuntimeAPI()` is guarded by `ENABLE_NVIDIA_API` macro. +每个 **Runtime** 对象都会初始化一组通用的 **Runtime API**。你需要实现 CUDA 版本的 API。参考 ``src/device/cpu/cpu_runtime_api.cpp`` 看 CPU 的实现方式,查阅 [`CUDA Runtime 文档`](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html) 找到对应 API。 + +在 ``src/device/runtime_api.hpp`` 中,``nvidia::getRuntimeAPI()`` 被 ``ENABLE_NVIDIA_API`` 宏保护: ```c++ #ifdef ENABLE_NVIDIA_API @@ -368,9 +368,9 @@ const LlaisysRuntimeAPI *getRuntimeAPI(); #endif ``` -This macro is defined in `xmake.lua` as a switch to enable/disable CUDA support. CUDA codes will not be compiled if the switch is off. In `xmake/` directory, create a `nvidia.lua` that configs your compiling process. (Similar to `cpu.lua` for CPU.) Search online to learn how to do it with Xmake. +该宏的定义在 ``xmake.lua`` 中,用于开关 CUDA 支持。若关闭,CUDA 代码不会被编译。你需要在 ``xmake/`` 下新建 ``nvidia.lua``,配置编译流程(参考 ``cpu.lua``)。查阅资料学习如何用 Xmake 配置。 -After you implement the CUDA Runtime APIs, config your xmake with `--nv-gpu=y` to enable CUDA support and recompile your program. Run runtime tests to see if your implementation works. +完成 CUDA Runtime API 后,用 ``--nv-gpu=y`` 打开 CUDA 支持并重新编译,运行测试: ```bash xmake f --nv-gpu=y -cv @@ -379,53 +379,54 @@ xmake install python test/test_runtime.py --device nvidia ``` -### Implement CUDA Operators -Create a `nvdia/` sub-directory in each operator source directory and implement a cuda version. Check `src/ops/add/op.cpp` to see how to include your cuda implementations. Remeber to define the compiling procedures in the xmake files. Run the operator tests with `--device nvidia` flag to test your CUDA implementation. +### 实现 CUDA 算子 + +在每个算子目录下新建 ``nvidia/`` 子目录,写 CUDA 版本实现。参考 ``src/ops/add/op.cpp`` 看如何包含 CUDA 实现。别忘了在 xmake 文件中定义编译流程。用 ``--device nvidia`` 参数运行测试。 -You can use CUDA libraries like cuBLAS, cuDNN, etc. to accelerate your operators. Check their documentations to see how to use them. You can store extra device resources in `src/device/nvidia/nvidia_resource.cu`. +你可以使用 cuBLAS、cuDNN 等 CUDA 库来加速算子,额外的设备资源可以放在 `src/device/nvidia/nvidia_resource.cu`。 -Modify your model codes to support CUDA inference. +最后,修改模型代码,支持 CUDA 推理: ```bash python test/test_infer.py --model [dir_path/to/model] --test --device nvidia ``` -## Project #3: Build an AI chatbot +## 项目#3:构建 AI 聊天机器人 -In this project you will build an AI chatbot that can do live conversations with single user with LLAISYS. +本项目中,你将用 LLAISYS 构建一个能与单用户实时对话的聊天机器人。 -### Random Sampling +### 随机采样 -So far we have been testing our model with argmax sampling. This is good enough for testing, but a chatbot should be able to generate more natural responses. Implement a random sample operator. Try to add supports for **Temperature**, **Top-K** and **Top-P**. +目前我们只用过 argmax 采样,这在测试时够用,但聊天机器人需要更自然的回复。请实现一个随机采样算子,并尽量支持 **Temperature**、**Top-K**、**Top-P**。 -### Build a Chatbot Server +### 搭建聊天服务器 -In your Python frontend, implement a server that can receive http requests from user and send responses back. You can use frameworks like FastAPI to build the server. You should follow the OpenAI chat-completion APIs. Try to support streaming responses if you can. You can assume, for now, that the server is only serving one user, and block the endpoint until the previous request is served. +在 Python 前端里,实现一个能接收 HTTP 请求并返回响应的服务器。可以用 FastAPI 等框架。接口最好遵循 OpenAI 的 chat-completion API。如果可以,尽量支持流式输出。你可以先假设只有一个用户在使用,每次请求可以阻塞直到处理完成。 +### 交互式聊天 UI -### Interactive Chat UI +实现一个 UI,能向服务器发送请求并接收回复。可以是命令行界面,也可以是 Web 界面。要能通过连续发送消息与机器人保持对话。 -Build a UI that send requests to and receive responses from the chatbot server. You can build a simple command-line interface or a fancy web interface. You should be able to keep a conversation going with the chatbot by sending messages and receiving responses consecutively. +### (可选)会话管理 -### (Optional) Chat Session Management +实际应用中,用户可以开启多个对话并在它们之间切换,还能修改历史问题让 AI 重新生成回答。扩展 UI,支持这些功能。实现一个支持前缀匹配的 KV-Cache 池,尽可能复用已有结果。 -In real-world AI applications, users are allowed to start new conversations and switch between them. Users can also edit a past question and let the AI regenerate an answer. Enhance your UI to support these features. Implement a KV-Cache pool with prefix matching to reuse past results as much as possible. +## 项目#4:多用户推理服务 +在做这个项目之前,你需要完成 ``项目#3`` 并实现流式输出。 -## Project #4: Multi-user Inference Service +### 支持多用户 -You need to finish **Project #2** and achieve streaming response first before proceeding to this project. +现实中推理服务要同时为多个用户提供服务,请求可能随时到来。你的服务端需要将请求加入请求池/队列,并用单独的循环线程/进程来处理。 -### Serving Multiple Users +### 连续批处理 -In real-world scenarios, an inference service will serve multiple users. Requests can come in at any time, and the service should be able to handle them concurrently. Your endpoint should add a new request to a request pool or queue and have a another looping process or thread to serve the requests. +为了最大化吞吐量,你需要做批处理,而不是逐一处理。由于每个请求长度不同,需要实现连续的迭代级批处理机制:每轮从池中取出若干请求组成批次(batch),执行一次批量推理,再把未完成的请求放回池中。推理时尽量用批量矩阵乘法加速。注意每个请求需要绑定不同的 KV-Cache,应实现支持前缀匹配的 KV-Cache 池来复用结果。 -### Continous Batching -To maximize the throughput of your inference service, you need to batch your requests instead of serving them one by one. Since each request can have different length, you will need a continous and iteration-level batching mechanism. For each interation you extract several requests from pool to form a batch, do one round of batch inference, and then return the unfinished requests back to the pool. Use batched matrix multiplication when possible to speed up your inference. Note that every request in the batch need to bind with a different KV-Cache. You should build a KV-Cache pool with prefix matching to reuse past results as much as possible. +## 项目#5:分布式推理 -## Project #5: Distributed Inference -Introduce Tensor Parallelism to LLAISYS. Shard your model across multiple devices and implement distributed model inference. Support NCCL in LLAISYS if your are uing Nvidia GPUs, or MPI if you are using CPUs. +在 LLAISYS 中引入张量并行。把模型分片到多个设备上,实现分布式推理。如果用 Nvidia GPU,需要支持 NCCL;如果用 CPU,需要支持 MPI。 -## Project #6: Support New Models +## 项目#6:支持新模型 -Support another model type than the one we use for homework in LLAISYS. +在 LLAISYS 中支持除作业所用模型以外的其他模型。 diff --git a/README_ZN.md b/README_ZN.md deleted file mode 100644 index 7704dbd5b..000000000 --- a/README_ZN.md +++ /dev/null @@ -1,432 +0,0 @@ -# 欢迎使用 LLAISYS - -

-English | -中文 -

- -## 简介 - -LLAISYS(Let's Learn AI SYStem)是一个教育项目,旨在为新手和未来的AI工程师提供一个从零开始构建AI系统的学习平台。LLAISYS包含多个作业,帮助学生学习和构建基础模块;以及一些项目挑战,让他们为系统添加更多高级功能。LLAISYS使用C++作为系统后端的主要编程语言,并编译成共享库,提供C语言API。前端代码使用Python编写,调用这些API以提供更便捷的测试和与其他架构(如PyTorch)的交互。 - -### 项目结构概览 - -- `\include`:包含所有定义共享库提供的C API的头文件的目录。(函数声明以`__export`开头) - -- `\src`:C++源文件。 - - `\src\llaisys`包含头文件中定义的所有直接实现,并遵循与`\include`相同的目录结构。这也是C++代码的边界。 - - 其他目录包含不同模块的实际实现。 - -- `xmake.lua`:llaisys后端的构建规则。`\xmake`目录包含不同设备的子xmake文件。例如,将来可以在目录中添加`nvidia.lua`来支持CUDA。 - -- `\python`:Python源文件。 - - `\python\llaisys\libllaisys`包含llaisys API的所有ctypes封装函数。它基本上与C头文件的结构相匹配。 - - `\python\llaisys`包含ctypes函数的Python包装器,使包更符合Python风格。 - -- `\test`:导入llaisys python包的Python测试文件。 - -## 作业 #0:入门 - -### 任务-0.1 安装必备组件 - -- 编译工具:[Xmake](https://xmake.io/) -- C++编译器:MSVC(Windows)或Clang或GCC -- Python >= 3.9(PyTorch、Transformers等) -- Clang-Format-16(可选):用于格式化C++代码。 - -### 任务-0.2 Fork并构建LLAISYS - -- Fork LLAISYS仓库并克隆到本地机器。支持Windows和Linux。 - -- 编译和安装 - - ```bash - # 编译c++代码 - xmake - # 安装llaisys共享库 - xmake install - # 安装llaisys python包 - pip install ./python/ - ``` - -- Github自动测试 - - LLAISYS使用Github Actions在每次推送和拉取请求时运行自动化测试。你可以在仓库页面上看到测试结果。完成所有作业任务后,所有测试都应该通过。 - -### 任务-0.3 首次运行LLAISYS - -- 运行cpu运行时测试 - - ```bash - python test/test_runtime.py --device cpu - ``` - - 你应该看到测试通过。 - -### 任务-0.4 下载测试模型 - -- 我们用于作业的模型是[DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)。 - -- 使用PyTorch运行模型推理测试 - - ```bash - python test/test_infer.py --model [dir_path/to/model] - ``` - - 你可以看到PyTorch能够加载模型并使用示例输入执行推理。你可以调试进入`transformers`库代码来深入查看并了解其内部运作原理。现在,你的代码还无法执行任何操作,但在后续的作业中,你将构建一个能够实现相同功能的系统。 - -## 作业 #1:张量 - -张量是表示多维数据的数据结构。它是LLAISYS和大多数AI框架(如PyTorch)的基本构建单元。在这个作业中,你将学习如何实现一个基本的张量类。 - -张量对象具有以下字段: - -- `storage`:指向存储张量数据的内存块的共享指针。它可以被多个张量共享。有关更多详细信息,请查看storage类。 -- `offset`:张量在存储中的起始索引(以字节为单位)。 -- `meta`:描述张量形状、数据类型和步长的元数据。 - -实现`src/tensor/tensor.hpp`中定义的以下函数: - -### 任务-1.1 - -```c++ -void load(const void *src); -``` - -将主机(cpu)数据加载到张量(可以在设备上)。查看构造函数了解如何获取当前设备上下文的运行时API,并执行从主机到设备的内存复制。 - -### 任务-1.2 - -```c++ -bool isContiguous() const; -``` - -检查张量的形状和步长,判断它在内存中是否连续。 - -### 任务-1.3 - -```c++ -tensor_t view(const std::vector &shape) const; -``` - -创建一个新张量,通过拆分或合并原始维度将原始张量重塑为给定形状。不涉及数据传输。例如,通过合并最后两个维度,将形状为(2, 3, 5)的张量更改为(2, 15)。 - -这个函数不是简单地改变张量的形状那么简单,尽管测试会通过。如果新视图与原始张量不兼容,它应该引发错误。想想一个形状为(2, 3, 5)、步长为(30, 10, 1)的张量。你还能在不传输数据的情况下将其重塑为(2, 15)吗? - -### 任务-1.4 - -```c++ -tensor_t permute(const std::vector &order) const; -``` - -创建一个新张量,改变原始张量维度的顺序。转置可以通过这个函数实现,而无需移动数据。 - -### 任务-1.5 - -```c++ -tensor_t slice(size_t dim, size_t start, size_t end) const; -``` - -创建一个新张量,沿给定维度,start(包含)和end(不包含)索引对原始张量进行切片操作。 - -### 任务-1.6 - -运行张量测试。 - -```bash -python test/test_tensor.py -``` - -你应该看到所有测试都通过了。提交并推送你的更改。你应该看到作业#1的自动测试通过了。 - -## 作业 #2:算子 - -在这个作业中,你将实现以下算子的cpu版本: - -- argmax -- embedding -- linear -- rms_norm -- rope -- self_attention -- swiglu - -阅读`src/ops/add/`中的代码,了解"add"算子是如何实现的。确保你理解算子代码是如何组织、编译、链接以及暴露给Python前端的。**你的算子应该至少支持Float32、Float16和BFloat16数据类型**。`src/utils/`中提供了一个用于简单类型转换的辅助函数。所有python测试都在`test/ops`中,你的实现应该至少通过这些测试。首先尝试运行"add"算子的测试脚本。 - -### 任务-2.1 Argmax - -```c++ -void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals); -``` - -获取张量`vals`的最大值及其索引,并分别存储在`max_val`和`max_idx`中。你暂时可以假设`vals`是一个1D张量,`max_idx`和`max_val`都是包含单个元素的1D张量(这意味着保留了`vals`的维度)。 - -完成实现后,你应该能够通过`test/ops/argmax.py`中的测试用例。 - -### 任务-2.2 Embedding - -```c++ -void embedding(tensor_t out, tensor_t index, tensor_t weight); -``` - -从`weight`(2-D)中复制`index`(1-D)中的行到`output`(2-D)。`index`必须是Int64类型(PyTorch中int的默认数据类型)。 - -完成实现后,你应该能够通过`test/ops/embedding.py`中的测试用例。 - -### 任务-2.3 Linear - -```c++ -void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias); -``` - -计算以下内容: - -$$ -Y = xW^T + b -$$ - -- `out`:输出 $Y$ 。你暂时可以假设输出是一个2D连续张量,不涉及广播。 -- `input`:输入 $X$ 。你暂时可以假设输入是一个2D连续张量,不涉及广播。 -- `weight`:权重 $W$ 。2D连续张量。注意权重张量没有转置。你需要在计算过程中处理这个问题。 -- `bias`(可选):偏置 $b$ 。1D张量。你需要支持不提供偏置的情况。 - -完成实现后,你应该能够通过`test/ops/linear.py`中的测试用例。 - -### 任务-2.4 RMS Normalization - -```c++ -void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps); -``` - -为每一行计算以下内容: - -$$ -Y_i = \frac{W_i \times X_i}{\sqrt{\frac{1}{d}(\sum_{j=1}^d X_j^2) + \epsilon}} -$$ - -- `out`:输出 $Y$ 。你暂时可以假设输出是一个2D连续张量,不涉及广播。 -- `input`:输入 $X$ 。你暂时可以假设输入是一个2D连续张量,不涉及广播。标准化沿输入张量的最后一个维度(即每一行,长度为 $d$ )执行。 -- `weight`:权重 $W$ 。1D张量,与输入张量的一行长度相同。 -- `eps`:小值 $\epsilon$ 以避免除以零。 - -完成实现后,你应该能够通过`test/ops/rms_norm.py`中的测试用例。 - -### 任务-2.5 旋转位置编码(RoPE) - -```c++ -void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta); -``` - -为输入张量`in`的每个向量(这些向量与 pos_ids 中的位置 id 相对应)计算以下内容: - -设 $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ 为输入向量, $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ 为索引 $i$ 处的输出向量,其中 $\mathbf{a}_i, \mathbf{b}_i,\mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$ 。 - -设 $\theta$ 为固定基数(例如 $\theta = 10000$), $j = 0, 1, \ldots, d/2 - 1$。 - -设 $p_i \in \mathbb{N}$ 是输入索引i处token的位置id。 - -那么RoPE的角度为 $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$ - -输出向量 $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ 计算如下: - -$$a_{i,j}' = a_{i,j} \cos(\phi_{i,j}) - b_{i,j} \sin(\phi_{i,j})$$ - -$$b_{i,j}' = b_{i,j} \cos(\phi_{i,j}) + a_{i,j} \sin(\phi_{i,j})$$ - -- `out`:结果**q**或**k**张量。形状应该是 [seqlen, nhead, d] 或 [seqlen, nkvhead, d]。你暂时可以假设张量是连续的。 -- `in`:原始**q**或**k**张量。形状应该是 [seqlen, nhead, d] 或 [seqlen, nkvhead, d]。你暂时可以假设张量是连续的。 -- `pos_ids`:输入序列中每个token的位置id(整个上下文中的索引)。形状应该是 [seqlen,],dtype应该是int64。 -- `theta`:频率向量的基值。 - -完成实现后,你应该能够通过`test/ops/rope.py`中的测试用例。 - -### 任务-2.6 自注意力(self-attention) - -```c++ -void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale); -``` - -为查询张量`q`、键张量`k`和值张量`v`计算自注意力。如果需要,你应该在进行此计算之前连接kvcache张量。 - -$$ -A = Q K^\top * scale \\ -$$ - -$$ -Y = \mathrm{causalsoftmax}(A) \cdot V \\ -$$ - -- `attn_val`:结果注意力值张量。形状应该是[seqlen, nhead, dv]。你暂时可以假设张量是连续的。 -- `q`:查询张量。形状应该是 [seqlen, nhead, d]。你暂时可以假设张量是连续的。 -- `k`:键张量。形状应该是 [total_len, nkvhead, d]。你暂时可以假设张量是连续的。 -- `v`:值张量。形状应该是 [total_len, nkvhead, dv]。你暂时可以假设张量是连续的。 -- `scale`:缩放因子。在大多数情况下取值为 $\frac{1}{\sqrt{d}}$ 。 - -完成实现后,你应该能够通过`test/ops/self_attention.py`中的测试用例。 - -### 任务-2.7 SwiGLU - -```c++ -void swiglu(tensor_t out, tensor_t gate, tensor_t up); -``` - -这是一个逐元素函数,计算以下内容: - -$$ -out_{i} = up_{i} \circ \frac { gate_{i}}{1 + e^{-gate_{i}}} -$$ - -`out`、`up`和`gate`是具有相同形状 [seqlen, intermediate_size] 的2D连续张量。 - -完成实现后,你应该能够通过`test/ops/swiglu.py`中的测试用例。 - -### 任务-2.8 - -运行算子测试。 - -```bash -python test/test_ops.py -``` - -你应该看到所有测试都通过了。提交并推送你的更改。你应该看到作业#2的自动测试通过了。 - -### 任务-2.9(可选)rearrange - -这是一个奖励任务。你在模型推理中可能需要也可能不需要它。 - -```c++ -void rearrange(tensor_t out, tensor_t in); -``` - -此算子用于将数据从一个张量复制到另一个具有相同形状但不同步长的张量。有了这个,你可以轻松地为张量实现`contiguous`功能。 - -## 作业 #3:大语言模型推理 - -终于,是时候用LLAISYS实现文本生成了。 - -- 在`test/test_infer.py`中,你的实现应该能够使用argmax采样生成与PyTorch相同的文本。我们用于此作业的模型是[DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)。 - -- 你的实现的python包装器在`python/llaisys/models/qwen2.py`中。你不允许在这里使用任何基于python的框架(如PyTorch)实现你的模型推理逻辑。相反,你需要在LLAISYS后端用C/C++实现模型。脚本加载safetensors文件中的每个张量,你需要从它们加载数据到你的模型后端。 - -- 在`include/llaisys/models/qwen2.h`中,为你定义了一个原型。你可以随意修改代码,但你应该至少提供模型创建、销毁、数据加载和推理的基本API。在`src/llaisys/`中实现你的C API,并像`src/`中的其他模块一样组织你的C++代码。记得在`xmake.lua`中定义编译过程。 - -- 在`python/llaisys/libllaisys/`中,为你的C API定义ctypes包装函数。使用你的包装函数实现`python/llaisys/models/qwen2.py`。 - -- 你需要实现 KV-Cache 功能,否则模型推理速度会过慢。 - -- 调试直到你的模型工作。利用张量的`debug`函数打印张量数据。它允许你在模型推理期间将任何张量的数据与PyTorch进行比较。 - -完成实现后,你可以运行以下命令来测试你的模型: - -```bash -python test/test_infer.py --model [dir_path/to/model] --test -``` - -提交并推送你的更改。你应该看到作业#3的自动测试通过了。 - -## 只有完成作业后,才能开始做项目。 - -## 项目#1:优化 LLAISYS 的 CPU 推理 - -你可能已经注意到,你的模型推理速度相比 PyTorch 非常慢。这主要是因为你的算子没有经过优化。运行算子测试脚本时加上 ``--profile`` 参数,看看算子的性能表现。你可能会发现 ``linear`` 操作比 PyTorch 慢很多。这个算子本质上是矩阵乘法,是 Transformer 模型里最耗时的操作。 - -以下是几种优化 CPU 算子的方法: - -### 使用 SIMD 指令 - -SIMD(单指令多数据)是一类可以在单条指令中对多个数据元素同时执行相同操作的指令。现代 CPU 都支持 SIMD。你可以查阅相关资料,学习编译器内建函数(如 AVX2、AVX-512、NEON、SVE)来向量化你的算子。 - -### 使用 OpenMP 实现并行 - -你可以用多线程来并行化算子。OpenMP 是 C/C++ 中常见的多线程库。为 LLAISYS 增加 OpenMP 支持,使得 ``linear`` 等算子能够并行执行。 - -### 使用第三方库 - -有很多库能帮你优化 CPU 上的算子,例如 Eigen、OpenBLAS、MKL 等,它们能高效处理线性代数运算。但要注意,有些库只支持特定硬件平台,需要仔细阅读文档并小心使用。你也可以参考 PyTorch 的算子实现,看是否能复用。 - -用任何你喜欢的方法优化你的推理实现,并报告性能提升情况。 - -## 项目#2:在 LLAISYS 中集成 CUDA,适配两款CUDA或类CUDA平台(以下统称CUDA) - -这个项目不依赖 ``项目#1``。需要选择 Nvidia、天数、摩尔、沐曦中的至少两款平台。 - -本次训练营提供了以上四种平台的算力,可以在官方进行申请算力,并用 CUDA 加速模型推理。在动手前,先深入理解 LLAISYS 框架。 - -事实上,LLAISYS 是一个支持同构硬件的框架。使用时,每个线程会创建一个线程唯一的 **Context** 对象,管理该线程使用的所有设备 **Runtime**。**Runtime** 对象是设备的资源管理器,**Context** 会为每个设备(以延迟初始化的方式)创建唯一的 **Runtime**。你可以用 ``setDevice`` 在不同设备间切换,每个线程同一时间只会激活一个设备。详情见 ``src/core/context.hpp``。 - -### 实现 CUDA Runtime API - -每个 **Runtime** 对象都会初始化一组通用的 **Runtime API**。你需要实现 CUDA 版本的 API。参考 ``src/device/cpu/cpu_runtime_api.cpp`` 看 CPU 的实现方式,查阅 [`CUDA Runtime 文档`](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html) 找到对应 API。 - -在 ``src/device/runtime_api.hpp`` 中,``nvidia::getRuntimeAPI()`` 被 ``ENABLE_NVIDIA_API`` 宏保护: - -```c++ -#ifdef ENABLE_NVIDIA_API -namespace nvidia { -const LlaisysRuntimeAPI *getRuntimeAPI(); -} -#endif -``` - -该宏的定义在 ``xmake.lua`` 中,用于开关 CUDA 支持。若关闭,CUDA 代码不会被编译。你需要在 ``xmake/`` 下新建 ``nvidia.lua``,配置编译流程(参考 ``cpu.lua``)。查阅资料学习如何用 Xmake 配置。 - -完成 CUDA Runtime API 后,用 ``--nv-gpu=y`` 打开 CUDA 支持并重新编译,运行测试: - -```bash -xmake f --nv-gpu=y -cv -xmake -xmake install -python test/test_runtime.py --device nvidia -``` - -### 实现 CUDA 算子 - -在每个算子目录下新建 ``nvidia/`` 子目录,写 CUDA 版本实现。参考 ``src/ops/add/op.cpp`` 看如何包含 CUDA 实现。别忘了在 xmake 文件中定义编译流程。用 ``--device nvidia`` 参数运行测试。 - -你可以使用 cuBLAS、cuDNN 等 CUDA 库来加速算子,额外的设备资源可以放在 `src/device/nvidia/nvidia_resource.cu`。 - -最后,修改模型代码,支持 CUDA 推理: - -```bash -python test/test_infer.py --model [dir_path/to/model] --test --device nvidia -``` - -## 项目#3:构建 AI 聊天机器人 - -本项目中,你将用 LLAISYS 构建一个能与单用户实时对话的聊天机器人。 - -### 随机采样 - -目前我们只用过 argmax 采样,这在测试时够用,但聊天机器人需要更自然的回复。请实现一个随机采样算子,并尽量支持 **Temperature**、**Top-K**、**Top-P**。 - -### 搭建聊天服务器 - -在 Python 前端里,实现一个能接收 HTTP 请求并返回响应的服务器。可以用 FastAPI 等框架。接口最好遵循 OpenAI 的 chat-completion API。如果可以,尽量支持流式输出。你可以先假设只有一个用户在使用,每次请求可以阻塞直到处理完成。 - -### 交互式聊天 UI - -实现一个 UI,能向服务器发送请求并接收回复。可以是命令行界面,也可以是 Web 界面。要能通过连续发送消息与机器人保持对话。 - -### (可选)会话管理 - -实际应用中,用户可以开启多个对话并在它们之间切换,还能修改历史问题让 AI 重新生成回答。扩展 UI,支持这些功能。实现一个支持前缀匹配的 KV-Cache 池,尽可能复用已有结果。 - -## 项目#4:多用户推理服务 - -在做这个项目之前,你需要完成 ``项目#3`` 并实现流式输出。 - -### 支持多用户 - -现实中推理服务要同时为多个用户提供服务,请求可能随时到来。你的服务端需要将请求加入请求池/队列,并用单独的循环线程/进程来处理。 - -### 连续批处理 - -为了最大化吞吐量,你需要做批处理,而不是逐一处理。由于每个请求长度不同,需要实现连续的迭代级批处理机制:每轮从池中取出若干请求组成批次(batch),执行一次批量推理,再把未完成的请求放回池中。推理时尽量用批量矩阵乘法加速。注意每个请求需要绑定不同的 KV-Cache,应实现支持前缀匹配的 KV-Cache 池来复用结果。 - -## 项目#5:分布式推理 - -在 LLAISYS 中引入张量并行。把模型分片到多个设备上,实现分布式推理。如果用 Nvidia GPU,需要支持 NCCL;如果用 CPU,需要支持 MPI。 - -## 项目#6:支持新模型 - -在 LLAISYS 中支持除作业所用模型以外的其他模型。 diff --git a/REPORT.md b/REPORT.md new file mode 100644 index 000000000..49914b741 --- /dev/null +++ b/REPORT.md @@ -0,0 +1,308 @@ +# LLAISYS 项目报告 + +## 一、已完成工作总览 + +| 模块 | 说明 | +|------|------| +| 作业 #0-#3(基础) | 张量、算子、模型推理全部完成 | +| 项目 #2:多平台 CUDA 适配 | Nvidia + 天数 Iluvatar CoreX 双平台 | +| 项目 #3:AI 聊天机器人 | 服务器 + 前端 + 流式输出 + 会话管理 + KV 复用 | +| 项目 #4:多用户推理服务 | 调度器 + 连续批处理 + 共享模型池 + KV 感知路由 | +| 项目 #5:分布式推理 | 通信层 + NCCL 后端 + 张量并行 | + +## 二、作业阶段 + +**作业 #1:张量** — 实现了张量的核心操作:`load`、`isContiguous`、`view`、`permute`、`slice`。所有测试通过。 + +**作业 #2:算子** — 实现了 9 个 CPU 算子:`add`、`argmax`、`embedding`、`linear`、`rearrange`、`rms_norm`、`rope`、`self_attention`、`swiglu`。支持 Float32/Float16/BFloat16 数据类型,全部测试通过。 + +**作业 #3:大语言模型推理** — 实现了 DeepSeek-R1-Distill-Qwen-1.5B 模型的完整推理链路:C++ Decoder 实现(Transformer 前向传播 + KV Cache)、C API 导出 + Python ctypes 封装、端到端推理输出与 PyTorch 完全一致。 + +## 三、项目阶段 + +**项目 #2:多平台 CUDA 适配** + +在 Nvidia GPU 和天数 Iluvatar CoreX GPU 两个平台上实现 CUDA 加速推理。 + +实现方案: +- Nvidia 平台:实现 CUDA Runtime API + 9 个 CUDA 算子内核,使用 nvcc 编译 +- 天数 Iluvatar CoreX 平台:采用 kernel 零复制策略,直接复用 `nvidia::` 命名空间的 CUDA 内核,使用 `clang++ -x cuda --cuda-gpu-arch=ivcore10` 编译 + +关键问题与解决: + +| 问题 | 解决方案 | +|------|----------| +| xmake 自动调用 nvcc 而非 clang++ | 使用 `on_build()` 手动控制编译 | +| xmake 注入 `-lcudadevrt` | 不注册 .cu 文件,避免 CUDA 检测 | +| 静态库符号未解析 | `--whole-archive` 强制完整包含 | +| `-lcudart` 链接顺序错误 | 统一放入 `add_shflags()` 控制顺序 | + +验证结果:Nvidia 和 Iluvatar 平台的 runtime、算子、端到端推理测试全部通过。 + +--- + +**项目 #3:AI 聊天机器人** + +实现内容: +1. 随机采样:支持 Temperature、Top-K、Top-P、Seed(C API + Python 封装) +2. 聊天服务器(`python/llaisys/server.py`):HTTP 服务,兼容 OpenAI Chat Completion API(`/v1/chat/completions`),支持流式输出(SSE)和非流式输出 +3. 前端 UI(`frontend/`):Web 界面,支持连续对话、流式显示 +4. 会话管理:多会话支持、历史消息编辑 + 分叉重新生成、前缀匹配 KV Cache 池跨会话复用 + +架构: +``` +前端 (HTML/JS) → HTTP → 服务器 (Python) → C API → C++ 推理引擎 + ↕ + KV Cache Pool +``` + +--- + +**项目 #4:多用户推理服务** + +实现内容: +1. 请求调度器(`python/llaisys/scheduler.py`):入口线程 + 调度器 + Worker 执行模式,支持多 Worker、请求队列、超时控制、会话粘性路由 + KV 感知路由 +2. 连续批处理:迭代级批处理、Packed Prefill、动态缩批、流式 + 非流式请求均走批量路径 +3. 共享模型池(`--shared-model`):多 Worker 共享一份模型,内存从 N×model_size 降到 1×model_size +4. KV 内存感知流控(`--kv-memory-threshold`):内存压力超阈值时拒绝新请求(429) + +推荐启动参数: +```bash +python -m llaisys.server --model <模型路径> \ + --workers 4 --shared-model \ + --continuous-batching --max-batch-size 8 \ + --kv-aware-routing --kv-memory-threshold 0.85 +``` + +压测结果: + +| 参数 | 成功率 | 吞吐 | 平均延迟 | +|------|--------|------|----------| +| total=20, concurrency=2, tokens=16 | 20/20 | 0.18 rps | 11.1s | +| total=12, concurrency=4, tokens=8 | 12/12 | 0.37 rps | 10.2s | + +--- + +**项目 #5:分布式推理** + +引入张量并行,将模型分片到多个 GPU 上实现分布式推理,使用 NCCL 通信。 + +实现内容: + +1. 通信层(C API + C++ + NCCL 后端): + - `include/llaisys/comm.h` → C API 头文件(函数指针表) + - `src/device/comm_api.{hpp,cpp}` → C++ dispatcher(#ifdef 条件编译) + - `src/device/nvidia/nvidia_comm.cu` → NCCL 后端(8 个操作) + - 支持操作:init、destroy、get_rank、get_size、allreduce、broadcast、send、recv + +2. 张量并行(Megatron-style):Decoder 中每层插入 2 个 AllReduce(`attn_o` 和 `mlp_down` 线性投影后、残差加之前),单 GPU 时零开销 + +3. 权重切分(`python/llaisys/tensor_parallel.py`): + + | 权重 | 切分方式 | 说明 | + |------|----------|------| + | Q/K/V/gate/up | Column split (dim 0) | 每 rank 获得 nh/tp_size 个 head | + | attn_o/down | Row split (dim 1) | 输出需 AllReduce 聚合 | + | embeddings/norms | 复制 | 所有 rank 持有完整副本 | + +4. 多进程启动器:Rank 0 生成 NCCL unique ID → 文件 IPC 广播 → 各 rank 加载切分权重 → 分布式推理 + +验证结果(8×A100-80GB 服务器): + +| 测试 | 结果 | +|------|------| +| 单卡 runtime + 算子 | ✅ 通过 | +| 通信层单元测试 | ✅ 通过 | +| 2 卡 AllReduce | ✅ 通过(SUM = 3.0) | +| 4 卡 AllReduce | ✅ 通过(SUM = 10.0) | +| 8 卡 AllReduce | ⚠️ 超时(显存被其他进程占用) | +| 张量并行推理 | ✅ 通过(2 卡,token 一致) | + +## 四、代码架构 + +``` +llaisys/ +├── include/llaisys/ # C API 头文件 +│ ├── llaisys.h # 基础类型定义 +│ ├── runtime.h # 运行时 API +│ ├── comm.h # 通信 API +│ └── models/qwen2.h # 模型 API +├── src/ +│ ├── device/ # 设备抽象层 +│ │ ├── cpu/ # CPU 实现 +│ │ ├── nvidia/ # CUDA 实现 + NCCL 通信 +│ │ └── iluvatar/ # 天数 CoreX 实现 +│ ├── ops/ # 算子(9 个,各含 cpu/nvidia 子目录) +│ ├── models/ # 模型实现(Qwen2 Decoder) +│ └── core/ # 运行时核心(Context/Runtime/Storage) +├── python/llaisys/ # Python 前端 +│ ├── server.py # 聊天服务器 +│ ├── scheduler.py # 请求调度器 +│ ├── tensor_parallel.py # 权重切分 +│ └── libllaisys/ # ctypes 绑定 +├── frontend/ # Web UI +├── scripts/ # 工具脚本(启动器、压测) +├── test/ # 测试文件 +└── xmake.lua # 构建配置 +``` + +## 五、复现流程 + +### 环境要求 + +| 依赖 | 版本要求 | 用途 | +|------|----------|------| +| Xmake | >= 2.7 | 构建工具 | +| C++ 编译器 | GCC >= 9 / Clang >= 10 / MSVC 2019+ | 编译后端 | +| Python | >= 3.9 | 前端 + 测试 | +| PyTorch | >= 2.0 | 对比验证(仅测试时需要) | +| CUDA Toolkit | >= 11.0 | GPU 推理(项目 #2 起) | +| NCCL | >= 2.10 | 分布式推理(项目 #5) | + +### 步骤 0:克隆仓库 + 下载模型 + +```bash +git clone https://github.com/KevinSusan/llaisys_tt.git +cd llaisys_tt + +# 下载测试模型 DeepSeek-R1-Distill-Qwen-1.5B(约 3GB) +pip install huggingface_hub +huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local-dir ./model +# 国内镜像:HF_ENDPOINT=https://hf-mirror.com huggingface-cli download ... +``` + +> 以下所有命令中 `./model` 替换为实际模型路径。 + +--- + +### 作业 #1-#3 验证(CPU,任意机器) + +```bash +# 编译 +xmake build + +# 安装共享库(Linux) +cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/ +# Windows: copy build\windows\x64\release\llaisys.dll python\llaisys\libllaisys\ + +# 设置 Python 路径 +export PYTHONPATH=$(pwd)/python:$PYTHONPATH + +# 作业 #1:张量测试 +python test/test_tensor.py + +# 作业 #2:CPU 算子测试 +python test/test_ops.py + +# 作业 #3:CPU 端到端推理(输出应与 PyTorch 完全一致) +python test/test_infer.py --model ./model --test +``` + +--- + +### 项目 #2 验证(Nvidia GPU) + +设备要求:Nvidia GPU + CUDA Toolkit + +```bash +# 编译(开启 Nvidia GPU 支持) +xmake f --nv-gpu=y -c +xmake build +cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/ +export PYTHONPATH=$(pwd)/python:$PYTHONPATH + +# GPU 运行时测试 +python test/test_runtime.py --device nvidia + +# GPU 算子测试(9 个算子) +python test/ops_gpu/run_all.py --device nvidia + +# GPU 端到端推理(输出应与 PyTorch 完全一致) +python test/test_infer.py --model ./model --test --device nvidia +``` + +### 项目 #2 验证(天数 Iluvatar CoreX GPU) + +设备要求:天数 Iluvatar CoreX GPU + CoreX SDK + +```bash +xmake f --iluvatar-gpu=y -c +xmake build +cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/ +export PYTHONPATH=$(pwd)/python:/usr/local/corex/lib64/python3/dist-packages:$PYTHONPATH + +python test/test_runtime.py --device iluvatar +python test/ops_gpu/run_all.py --device iluvatar +python test/test_infer.py --model ./model --test --device iluvatar +``` + +--- + +### 项目 #3 验证(聊天机器人) + +设备要求:同项目 #2(GPU 推理) + +```bash +# 启动聊天服务器 +python -m llaisys.server --model ./model --device nvidia + +# 在另一个终端测试 API(兼容 OpenAI 格式) +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":32}' + +# 或打开浏览器访问 http://localhost:8000 使用 Web UI +``` + +--- + +### 项目 #4 验证(多用户推理服务) + +设备要求:同项目 #2(GPU 推理) + +```bash +# 启动多用户服务(共享模型 + 连续批处理) +python -m llaisys.server --model ./model --device nvidia \ + --workers 2 --shared-model --continuous-batching --max-batch-size 8 + +# 运行调度器测试 +python test/test_scheduler_inmemory.py + +# 运行并发压测 +python scripts/benchmark_chat_scheduler.py \ + --url http://localhost:8000 --total 12 --concurrency 4 --max-new-tokens 8 +``` + +--- + +### 项目 #5 验证(分布式推理) + +设备要求:多张 Nvidia GPU + NCCL + +```bash +# 编译(确保 --nv-gpu=y) +xmake f --nv-gpu=y -c && xmake build +cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/ +export PYTHONPATH=$(pwd)/python:$PYTHONPATH +pip install transformers safetensors + +# 通信层单元测试(单卡) +python test/test_comm_api.py --device nvidia + +# 多卡 AllReduce 集成测试 +python test/test_allreduce.py --nranks 2 --device nvidia +python test/test_allreduce.py --nranks 4 --device nvidia + +# 张量并行推理(2 卡) +python scripts/launch_tp.py \ + --model ./model --nranks 2 --device nvidia \ + --prompt "Hello, world" --max-tokens 32 +``` + +## 六、技术亮点 + +1. **跨平台 CUDA 适配**:通过 kernel 零复制策略,天数 Iluvatar 平台无需修改任何 CUDA 内核代码,直接复用 Nvidia 实现 +2. **完整推理服务栈**:从底层 C++ 算子到 HTTP API,全链路自研,兼容 OpenAI API 格式 +3. **连续批处理**:迭代级调度 + Packed Prefill + 动态缩批,支持流式和非流式混合请求 +4. **Megatron-style 张量并行**:通信层抽象设计,支持 NCCL/IXCCL/MPI 多后端,Decoder 中仅需 2 个 AllReduce/层 +5. **KV Cache 复用体系**:前缀匹配 + 跨会话 donor 复用 + 分叉编辑 + 内存感知流控 diff --git a/Untitled b/Untitled new file mode 100644 index 000000000..fbd2fd289 --- /dev/null +++ b/Untitled @@ -0,0 +1 @@ +xmake f -m release --nv-gpu=y --vs=2022 \ No newline at end of file diff --git a/frontend/app.js b/frontend/app.js new file mode 100644 index 000000000..8f766634e --- /dev/null +++ b/frontend/app.js @@ -0,0 +1,366 @@ +let activeId = ""; +const conversations = []; + +const chat = document.getElementById("chat"); +const form = document.getElementById("chat-form"); +const promptInput = document.getElementById("prompt"); +const endpointInput = document.getElementById("endpoint"); +const maxTokensInput = document.getElementById("max-tokens"); +const samplingModeInput = document.getElementById("sampling-mode"); +const temperatureInput = document.getElementById("temperature"); +const topKInput = document.getElementById("top-k"); +const topPInput = document.getElementById("top-p"); +const seedInput = document.getElementById("seed"); +const editHint = document.getElementById("edit-hint"); +const sendButton = document.getElementById("send"); +const stopButton = document.getElementById("stop"); +const sessionList = document.getElementById("session-list"); +const newChatButton = document.getElementById("new-chat"); +let activeStreamController = null; +let pendingEdit = null; + +const createLocalId = () => { + if (crypto && crypto.randomUUID) return crypto.randomUUID(); + return `local-${Date.now()}-${Math.random().toString(16).slice(2)}`; +}; + +const dedupeAdjacentParagraphs = (text) => { + const parts = text + .split(/\n{2,}/) + .map((p) => p.trim()) + .filter(Boolean); + const deduped = []; + for (const p of parts) { + if (deduped.length > 0 && deduped[deduped.length - 1] === p) continue; + deduped.push(p); + } + return deduped.join("\n\n"); +}; + +const parseAssistantSections = (rawText) => { + const normalized = String(rawText || "").replaceAll("<|end_of_sentence|>", ""); + const openTag = ""; + const closeTag = ""; + const start = normalized.indexOf(openTag); + const closeOnly = normalized.indexOf(closeTag); + + // Tolerate outputs containing only a closing tag. + if (start < 0 && closeOnly >= 0) { + const thinking = normalized.slice(0, closeOnly).trim(); + const answer = normalized.slice(closeOnly + closeTag.length).trim(); + return { + thinking: dedupeAdjacentParagraphs(thinking.replaceAll(closeTag, "")), + answer: dedupeAdjacentParagraphs(answer.replaceAll(closeTag, "")), + }; + } + + if (start < 0) { + return { thinking: "", answer: dedupeAdjacentParagraphs(normalized.replaceAll(closeTag, "")) }; + } + const afterOpen = start + openTag.length; + const end = normalized.indexOf(closeTag, afterOpen); + if (end < 0) { + return { + thinking: dedupeAdjacentParagraphs(normalized.slice(afterOpen).replaceAll(openTag, "")), + answer: "", + }; + } + const thinking = normalized.slice(afterOpen, end).replaceAll(openTag, "").replaceAll(closeTag, ""); + const answer = normalized.slice(end + closeTag.length).replaceAll(openTag, "").replaceAll(closeTag, ""); + return { + thinking: dedupeAdjacentParagraphs(thinking), + answer: dedupeAdjacentParagraphs(answer), + }; +}; + +const renderAssistantBubble = (bubble, rawText) => { + bubble.dataset.raw = rawText; + const { thinking, answer } = parseAssistantSections(rawText); + const thinkingSection = bubble.querySelector(".assistant-thinking"); + const answerSection = bubble.querySelector(".assistant-answer"); + const normalizedThinking = thinking.replace(/\s+/g, " ").trim(); + const normalizedAnswer = answer.replace(/\s+/g, " ").trim(); + const isRedundantThinking = + normalizedThinking && + normalizedAnswer && + (normalizedThinking === normalizedAnswer || + normalizedAnswer.includes(normalizedThinking)); + + if (thinking && thinking.trim() && !isRedundantThinking) { + thinkingSection.style.display = "block"; + thinkingSection.querySelector(".assistant-thinking-content").textContent = thinking.trim(); + } else { + thinkingSection.style.display = "none"; + thinkingSection.querySelector(".assistant-thinking-content").textContent = ""; + } + answerSection.textContent = answer.trimStart(); +}; + +const clearPendingEdit = () => { + pendingEdit = null; + sendButton.textContent = "发送"; + editHint.style.display = "none"; + editHint.textContent = ""; +}; + +const setPendingEdit = (state) => { + pendingEdit = state; + sendButton.textContent = "分叉发送"; + const round = Number(state.editMessageIndex) + 1; + editHint.textContent = `正在编辑第 ${round} 轮用户消息,发送后将创建分叉会话(Esc 可取消)`; + editHint.style.display = "block"; +}; + +const appendBubble = (text, role, options = {}) => { + const div = document.createElement("div"); + div.className = `bubble ${role}`; + if (role === "assistant") { + div.innerHTML = ` + +
+ `; + renderAssistantBubble(div, text || ""); + } else { + const content = document.createElement("div"); + content.className = "user-content"; + content.textContent = text; + div.appendChild(content); + if (options.canEdit) { + const editButton = document.createElement("button"); + editButton.type = "button"; + editButton.className = "bubble-edit"; + editButton.textContent = "编辑"; + editButton.addEventListener("click", options.onEdit); + div.appendChild(editButton); + } + } + chat.appendChild(div); + chat.scrollTop = chat.scrollHeight; + return div; +}; + +const renderChat = (conversation) => { + chat.innerHTML = ""; + for (let i = 0; i < conversation.messages.length; i += 1) { + const message = conversation.messages[i]; + const canEdit = message.role === "user" && Boolean(conversation.serverId); + appendBubble(message.text, message.role, { + canEdit, + onEdit: () => { + if (!conversation.serverId) return; + setPendingEdit({ + sourceLocalId: conversation.id, + sourceServerId: conversation.serverId, + editMessageIndex: i, + }); + promptInput.value = message.text || ""; + promptInput.focus(); + }, + }); + } +}; + +const renderSessions = () => { + sessionList.innerHTML = ""; + for (const convo of conversations) { + const item = document.createElement("div"); + item.className = `session-item${convo.id === activeId ? " active" : ""}`; + item.textContent = convo.title || "新对话"; + item.addEventListener("click", () => { + activeId = convo.id; + clearPendingEdit(); + renderSessions(); + renderChat(convo); + }); + sessionList.appendChild(item); + } +}; + +const createConversation = () => { + const convo = { + id: createLocalId(), + serverId: "", + title: "新对话", + messages: [], + }; + conversations.unshift(convo); + activeId = convo.id; + clearPendingEdit(); + renderSessions(); + renderChat(convo); + return convo; +}; + +const getActiveConversation = () => { + let convo = conversations.find((c) => c.id === activeId); + if (!convo) { + convo = createConversation(); + } + return convo; +}; + +const streamChat = async (payload, bubble, convo, controller) => { + const res = await fetch(`${endpointInput.value}/v1/chat/completions`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ ...payload, stream: true }), + signal: controller.signal, + }); + + if (!res.ok || !res.body) { + throw new Error(`请求失败:${res.status}`); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder("utf-8"); + let buffer = ""; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const parts = buffer.split("\n\n"); + buffer = parts.pop() || ""; + for (const part of parts) { + if (!part.startsWith("data: ")) continue; + const payload_str = part.slice(6).trim(); + if (payload_str === "[DONE]") return; + const data = JSON.parse(payload_str); + if (data.session_id && !convo.serverId) { + convo.serverId = data.session_id; + } + const delta = data.choices && data.choices[0] && data.choices[0].delta; + if (delta && delta.content) { + const raw = (bubble.dataset.raw || "") + delta.content; + renderAssistantBubble(bubble, raw); + } + if (data.choices && data.choices[0] && data.choices[0].finish_reason) { + return; + } + } + } +}; + +form.addEventListener("submit", async (event) => { + event.preventDefault(); + const prompt = promptInput.value.trim(); + if (!prompt) return; + + const activeConvo = getActiveConversation(); + let convo = activeConvo; + let payloadSessionId = activeConvo.serverId; + let payloadEditFrom = ""; + let payloadEditIndex = -1; + const editState = pendingEdit; + const isForkEdit = Boolean(editState && editState.sourceServerId); + + if (isForkEdit) { + const sourceConvo = conversations.find((c) => c.id === editState.sourceLocalId); + if (!sourceConvo || !sourceConvo.serverId) { + clearPendingEdit(); + return; + } + convo = createConversation(); + convo.title = `${sourceConvo.title || "新对话"} (分叉)`; + const prefix = sourceConvo.messages.slice(0, editState.editMessageIndex + 1).map((m) => ({ ...m })); + if ( + prefix.length === 0 || + prefix[prefix.length - 1].role !== "user" + ) { + clearPendingEdit(); + return; + } + prefix[prefix.length - 1].text = prompt; + convo.messages = prefix; + renderSessions(); + renderChat(convo); + payloadEditFrom = sourceConvo.serverId; + payloadEditIndex = editState.editMessageIndex; + payloadSessionId = ""; + } + + if (!isForkEdit) { + convo.messages.push({ role: "user", text: prompt }); + appendBubble(prompt, "user", { canEdit: false }); + } + promptInput.value = ""; + const assistantBubble = appendBubble("", "assistant"); + convo.messages.push({ role: "assistant", text: "" }); + + sendButton.disabled = true; + stopButton.disabled = false; + activeStreamController = new AbortController(); + const payload = { + prompt, + max_tokens: Number(maxTokensInput.value) || 128, + temperature: Number(temperatureInput.value) || 0, + top_k: Number(topKInput.value) || 1, + top_p: Number(topPInput.value) || 0, + seed: Number(seedInput.value) || 0, + }; + if (samplingModeInput.value) { + payload.sampling = samplingModeInput.value; + } + if (payloadSessionId) { + payload.session_id = payloadSessionId; + } + if (payloadEditFrom && payloadEditIndex >= 0) { + payload.edit_from_session_id = payloadEditFrom; + payload.edit_message_index = payloadEditIndex; + } + + try { + await streamChat(payload, assistantBubble, convo, activeStreamController); + convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || ""; + if (convo.title === "新对话") { + convo.title = prompt.slice(0, 12); + renderSessions(); + } + } catch (err) { + if (err && err.name === "AbortError") { + convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || ""; + return; + } + renderAssistantBubble(assistantBubble, `请求失败:${err.message}`); + convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || ""; + } finally { + clearPendingEdit(); + activeStreamController = null; + stopButton.disabled = true; + sendButton.disabled = false; + } +}); + +newChatButton.addEventListener("click", () => { + createConversation(); +}); + +stopButton.addEventListener("click", async () => { + if (activeStreamController) { + activeStreamController.abort(); + } + const convo = getActiveConversation(); + if (!convo.serverId) { + return; + } + try { + await fetch(`${endpointInput.value}/chat/stop`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ session_id: convo.serverId }), + }); + } catch (_) { + // no-op + } +}); + +document.addEventListener("keydown", (event) => { + if (event.key === "Escape" && pendingEdit) { + clearPendingEdit(); + } +}); + +createConversation(); diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 000000000..fd94944a8 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,76 @@ + + + + + + LLAISYS Chat + + + +
+ + +
+
+

对话

+
+ +
+
+ +
+ +
+ + +
+
+ + + + + + +
+
+ + +
+
+
+
+
+ + + + diff --git a/frontend/style.css b/frontend/style.css new file mode 100644 index 000000000..1f8da19cb --- /dev/null +++ b/frontend/style.css @@ -0,0 +1,245 @@ +* { + box-sizing: border-box; +} + +body { + margin: 0; + font-family: "Segoe UI", "Microsoft YaHei", sans-serif; + background: #0f1115; + color: #e6e6e6; +} + +.container { + max-width: 960px; + margin: 0 auto; + padding: 24px; + display: grid; + grid-template-columns: 220px 1fr; + gap: 16px; + min-height: 100vh; +} + +.sidebar { + background: #141820; + border: 1px solid #252a35; + border-radius: 10px; + padding: 16px; + display: flex; + flex-direction: column; + gap: 12px; + height: fit-content; +} + +.brand { + font-size: 20px; + font-weight: 700; +} + +.new-chat { + width: 100%; +} + +.session-label { + font-size: 12px; + color: #9aa4b2; + text-transform: uppercase; + letter-spacing: 0.08em; +} + +.session-list { + display: flex; + flex-direction: column; + gap: 8px; +} + +.session-item { + padding: 8px 10px; + border-radius: 8px; + background: #1b1f2a; + cursor: pointer; + border: 1px solid transparent; +} + +.session-item.active { + border-color: #1f6feb; + background: #1c273d; +} + +.panel { + display: flex; + flex-direction: column; + gap: 16px; +} + +.header { + display: flex; + flex-direction: column; + gap: 12px; +} + +.header h1 { + margin: 0; +} + +.meta input { + width: 360px; + max-width: 100%; + padding: 6px 8px; + border-radius: 6px; + border: 1px solid #2b2f3a; + background: #171a21; + color: inherit; +} + +.chat { + flex: 1; + background: #141820; + border: 1px solid #252a35; + border-radius: 10px; + padding: 16px; + overflow-y: auto; + min-height: 360px; + display: flex; + flex-direction: column; +} + +.bubble { + padding: 10px 12px; + border-radius: 10px; + margin-bottom: 12px; + white-space: pre-wrap; + line-height: 1.5; +} + +.bubble.user { + background: #1f6feb; + color: white; + align-self: flex-end; + display: flex; + gap: 8px; + align-items: flex-start; +} + +.user-content { + white-space: pre-wrap; +} + +.bubble-edit { + background: rgba(255, 255, 255, 0.18); + border: 1px solid rgba(255, 255, 255, 0.45); + color: #fff; + padding: 2px 8px; + border-radius: 6px; + font-size: 12px; + line-height: 1.4; + cursor: pointer; + flex: 0 0 auto; +} + +.bubble.assistant { + background: #222836; + padding: 8px 10px; +} + +.assistant-thinking { + margin-bottom: 6px; + padding: 6px 8px; + border-left: 3px solid #6b7280; + background: rgba(255, 255, 255, 0.03); +} + +.assistant-thinking-label { + font-size: 12px; + color: #a7b0bf; + margin-bottom: 4px; +} + +.assistant-thinking-content { + font-size: 13px; + line-height: 1.45; + color: #c9d1dd; + white-space: pre-wrap; + max-height: 120px; + overflow-y: auto; +} + +.assistant-answer { + font-size: 18px; + line-height: 1.55; + color: #f2f5fb; + white-space: pre-wrap; +} + +.composer { + display: flex; + flex-direction: column; + gap: 12px; +} + +.edit-hint { + padding: 8px 10px; + border-radius: 8px; + border: 1px solid #375a8f; + background: #18263d; + color: #cfe3ff; + font-size: 13px; +} + +textarea { + width: 100%; + padding: 12px; + border-radius: 10px; + border: 1px solid #2b2f3a; + background: #171a21; + color: inherit; + resize: vertical; +} + +.actions { + display: flex; + justify-content: space-between; + align-items: center; + gap: 12px; +} + +.action-buttons { + display: flex; + gap: 8px; + align-items: center; +} + +.controls { + display: flex; + flex-wrap: wrap; + gap: 12px; + align-items: center; +} + +.actions input, +.actions select { + width: 110px; + padding: 6px 8px; + border-radius: 6px; + border: 1px solid #2b2f3a; + background: #171a21; + color: inherit; +} + +button { + padding: 8px 18px; + border: none; + border-radius: 8px; + background: #1f6feb; + color: white; + font-weight: 600; + cursor: pointer; +} + +button:disabled { + opacity: 0.6; + cursor: not-allowed; +} + +button.secondary { + background: #3a4456; +} diff --git a/include/llaisys.h b/include/llaisys.h index 73ca7eead..3bd516920 100644 --- a/include/llaisys.h +++ b/include/llaisys.h @@ -24,6 +24,7 @@ typedef enum { LLAISYS_DEVICE_CPU = 0, //// TODO: Add more device types here. Numbers need to be consecutive. LLAISYS_DEVICE_NVIDIA = 1, + LLAISYS_DEVICE_ILUVATAR = 2, LLAISYS_DEVICE_TYPE_COUNT } llaisysDeviceType_t; diff --git a/include/llaisys/comm.h b/include/llaisys/comm.h new file mode 100644 index 000000000..0afefa82b --- /dev/null +++ b/include/llaisys/comm.h @@ -0,0 +1,50 @@ +#ifndef LLAISYS_COMM_H +#define LLAISYS_COMM_H + +#include "../llaisys.h" + +__C { + // Communication Types + typedef void *llaisysComm_t; + + typedef enum { + LLAISYS_COMM_NCCL = 0, + LLAISYS_COMM_IXCCL = 1, + LLAISYS_COMM_MPI = 2, + } llaisysCommBackend_t; + + typedef enum { + LLAISYS_REDUCE_SUM = 0, + LLAISYS_REDUCE_PROD = 1, + LLAISYS_REDUCE_MIN = 2, + LLAISYS_REDUCE_MAX = 3, + } llaisysReduceOp_t; + + #define LLAISYS_COMM_UNIQUE_ID_MAX_SIZE 128 + + // Communication API Functions + typedef int (*comm_init_api)(llaisysComm_t *, int, int, const void *); + typedef void (*comm_destroy_api)(llaisysComm_t); + typedef int (*comm_get_rank_api)(llaisysComm_t); + typedef int (*comm_get_size_api)(llaisysComm_t); + typedef void (*comm_allreduce_api)(const void *, void *, size_t, llaisysDataType_t, llaisysReduceOp_t, llaisysComm_t, llaisysStream_t); + typedef void (*comm_broadcast_api)(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t); + typedef void (*comm_send_api)(const void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t); + typedef void (*comm_recv_api)(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t); + + struct LlaisysCommAPI { + comm_init_api init; + comm_destroy_api destroy; + comm_get_rank_api get_rank; + comm_get_size_api get_size; + comm_allreduce_api allreduce; + comm_broadcast_api broadcast; + comm_send_api send; + comm_recv_api recv; + }; + + __export const LlaisysCommAPI *llaisysGetCommAPI(llaisysCommBackend_t); + __export int llaisysCommGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size); +} + +#endif // LLAISYS_COMM_H diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d4..f18d09a11 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -2,15 +2,22 @@ #define LLAISYS_MODELS_QWEN2_H #include "../tensor.h" +#include "../comm.h" __C { + //千问2模型元信息 struct LlaisysQwen2Meta { + //数据类型 llaisysDataType_t dtype; + //模型参数 size_t nlayer, hs, nh, nkvh, dh, di, maxseq, voc; + //其他参数 float epsilon, theta; + //特殊token int64_t end_token; }; + //千问2模型权重 struct LlaisysQwen2Weights { llaisysTensor_t in_embed; llaisysTensor_t out_embed; @@ -29,14 +36,135 @@ __C { llaisysTensor_t *mlp_down_w; }; + // 采样参数 + struct LlaisysSamplingParams { + int32_t top_k; // <=1 表示贪心 + float top_p; // (0,1],<=0 表示不启用 + float temperature; // <=0 表示禁用温度缩放 + uint32_t seed; // 0 表示随机 + }; + + //千问2模型 struct LlaisysQwen2Model; + // KV block / context (experimental) + struct LlaisysQwen2KVBlock; + struct LlaisysQwen2KVContext; + + struct LlaisysQwen2KVBlockMeta { + llaisysDataType_t dtype; + size_t nlayer, nh, nkvh, dh; + size_t max_tokens; + }; + //创建千问2模型实例 __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); + //销毁千问2模型实例 __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model); + //获取千问2模型权重 __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); + //执行千问2模型推理(兼容接口,建议改用 Prefill/Step) __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + //执行千问2模型预填充(prefill) + __export int64_t llaisysQwen2ModelPrefill(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + //执行千问2模型单步解码(step) + __export int64_t llaisysQwen2ModelStep(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + //执行千问2模型批量预填充(packed prompts) + // token_offsets 长度为 nseq + 1,且 token_offsets[0]=0, token_offsets[nseq]=ntoken + // out_next_tokens 需为长度 nseq 的可写缓冲区 + __export int32_t llaisysQwen2ModelPrefillPacked(struct LlaisysQwen2Model *model, + int64_t *token_ids, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens); + //执行千问2模型批量解码(packed,当前为过渡语义,详见实现注释) + __export int32_t llaisysQwen2ModelStepPacked(struct LlaisysQwen2Model *model, + int64_t *token_ids, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens); + + //执行千问2模型预填充(prefill,带采样参数) + __export int64_t llaisysQwen2ModelPrefillSampling(struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + const struct LlaisysSamplingParams *params); + + //执行千问2模型单步解码(step,带采样参数) + __export int64_t llaisysQwen2ModelStepSampling(struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + const struct LlaisysSamplingParams *params); + + //执行千问2模型推理(带采样参数) + __export int64_t llaisysQwen2ModelInferSampling(struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + const struct LlaisysSamplingParams *params); + + //执行千问2模型推理(带采样参数,按值传递) + __export int64_t llaisysQwen2ModelInferSamplingEx(struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + int32_t top_k, + float top_p, + float temperature, + uint32_t seed); + + //重置千问2模型的 KV-cache + __export void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model * model); + + //启用/禁用 KV-cache + __export void llaisysQwen2ModelSetKVCacheEnabled(struct LlaisysQwen2Model * model, uint8_t enabled); + + //设置张量并行参数 + __export int32_t llaisysQwen2ModelSetTensorParallel(struct LlaisysQwen2Model *model, + llaisysComm_t comm, + llaisysStream_t stream, + int tp_size); + + // ===== Experimental KV block/context APIs ===== + __export struct LlaisysQwen2KVBlock *llaisysQwen2KVBlockCreate( + const struct LlaisysQwen2KVBlockMeta *meta, + llaisysDeviceType_t device, + int device_id); + __export void llaisysQwen2KVBlockRetain(struct LlaisysQwen2KVBlock *block); + __export void llaisysQwen2KVBlockRelease(struct LlaisysQwen2KVBlock *block); + __export int32_t llaisysQwen2KVBlockSetTokenCount(struct LlaisysQwen2KVBlock *block, size_t used_tokens); + __export size_t llaisysQwen2KVBlockTokenCount(const struct LlaisysQwen2KVBlock *block); + __export llaisysTensor_t llaisysQwen2KVBlockKeyTensor(struct LlaisysQwen2KVBlock *block, size_t layer); + __export llaisysTensor_t llaisysQwen2KVBlockValueTensor(struct LlaisysQwen2KVBlock *block, size_t layer); + + __export struct LlaisysQwen2KVContext *llaisysQwen2KVContextCreate( + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id, + size_t nlayer, + size_t nh, + size_t nkvh, + size_t dh); + __export void llaisysQwen2KVContextRetain(struct LlaisysQwen2KVContext *ctx); + __export void llaisysQwen2KVContextRelease(struct LlaisysQwen2KVContext *ctx); + __export int32_t llaisysQwen2KVContextAttachBlock( + struct LlaisysQwen2KVContext *ctx, + struct LlaisysQwen2KVBlock *block); + __export void llaisysQwen2KVContextDetachAll(struct LlaisysQwen2KVContext *ctx); + __export size_t llaisysQwen2KVContextBlockCount(const struct LlaisysQwen2KVContext *ctx); + __export size_t llaisysQwen2KVContextTokenCount(const struct LlaisysQwen2KVContext *ctx); + + __export int32_t llaisysQwen2ModelSetKVContext( + struct LlaisysQwen2Model *model, + struct LlaisysQwen2KVContext *ctx); + __export struct LlaisysQwen2KVContext *llaisysQwen2ModelGetKVContext( + struct LlaisysQwen2Model *model); + __export int32_t llaisysQwen2ModelExportKVContext( + struct LlaisysQwen2Model *model, + struct LlaisysQwen2KVContext *ctx, + size_t block_tokens); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h index ddb3be246..b79f074ca 100644 --- a/include/llaisys/ops.h +++ b/include/llaisys/ops.h @@ -12,6 +12,16 @@ __C { __export void llaisysRmsNorm(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, float eps); __export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta); __export void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale); + // Segmented self-attention for packed batches. + // q_offsets/kv_offsets must both have length nseg + 1 and be non-decreasing. + __export void llaisysSelfAttentionSegmented(llaisysTensor_t attn_val, + llaisysTensor_t q, + llaisysTensor_t k, + llaisysTensor_t v, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg); __export void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up); } diff --git a/include/llaisys/tokenizer.h b/include/llaisys/tokenizer.h new file mode 100644 index 000000000..d7ff6a6c1 --- /dev/null +++ b/include/llaisys/tokenizer.h @@ -0,0 +1,45 @@ +#ifndef LLAISYS_TOKENIZER_H +#define LLAISYS_TOKENIZER_H + +#include "../llaisys.h" + +__C { + struct LlaisysTokenizer; + + // Create a SentencePiece tokenizer from model file path. + __export struct LlaisysTokenizer *llaisysTokenizerCreateSentencePiece(const char *model_path); + + // Destroy tokenizer instance. + __export void llaisysTokenizerDestroy(struct LlaisysTokenizer *tokenizer); + + // Encode text into token ids. + // If out_ids is null or max_ids is 0, returns the required length. + // On error returns -1. + __export int llaisysTokenizerEncode(struct LlaisysTokenizer *tokenizer, + const char *text, + int64_t *out_ids, + size_t max_ids); + + // Decode token ids into text. + // If out_text is null or max_len is 0, returns the required length (including null terminator). + // On error returns -1. + __export int llaisysTokenizerDecode(struct LlaisysTokenizer *tokenizer, + const int64_t *ids, + size_t len, + char *out_text, + size_t max_len); + + // Map a single token string to its id. Returns -1 if not found. + __export int64_t llaisysTokenizerTokenToId(struct LlaisysTokenizer *tokenizer, const char *token); + + // Map a token id to its string. + // If out_token is null or max_len is 0, returns the required length (including null terminator). + // On error returns -1. + __export int llaisysTokenizerIdToToken(struct LlaisysTokenizer *tokenizer, + int64_t id, + char *out_token, + size_t max_len); + +} + +#endif // LLAISYS_TOKENIZER_H diff --git a/python/llaisys/__init__.py b/python/llaisys/__init__.py index de8d99f48..69ca3476e 100644 --- a/python/llaisys/__init__.py +++ b/python/llaisys/__init__.py @@ -5,6 +5,7 @@ from .libllaisys import llaisysStream_t as Stream from .tensor import Tensor from .ops import Ops +from .tokenizer import Tokenizer from . import models from .models import * @@ -16,5 +17,6 @@ "Stream", "Tensor", "Ops", + "Tokenizer", "models", ] diff --git a/python/llaisys/interfaces.py b/python/llaisys/interfaces.py new file mode 100644 index 000000000..6e8461958 --- /dev/null +++ b/python/llaisys/interfaces.py @@ -0,0 +1,194 @@ +"""接口定义 - 解耦调度器、服务、KVCache 池之间的依赖""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence + +if TYPE_CHECKING: + from llaisys.kv_cache_pool import AcquireResult + + +class IKVCachePool(ABC): + """KVCache 池接口 + + 调度器可以通过此接口查询 KV 状态,而不需要知道具体实现。 + """ + + @property + @abstractmethod + def block_size(self) -> int: + """每个 block 的 token 数量""" + pass + + @abstractmethod + def query_prefix_len(self, tokens: Sequence[int]) -> int: + """查询前缀命中长度(只读,不修改状态) + + Args: + tokens: 待查询的 token 序列 + + Returns: + 命中的前缀长度(token 数量) + """ + pass + + @abstractmethod + def acquire_context(self, context_id: str, tokens: Sequence[int]) -> "AcquireResult": + """获取/创建上下文,返回匹配的前缀长度 + + Args: + context_id: 上下文/会话 ID + tokens: 当前请求的完整 token 序列 + + Returns: + AcquireResult,包含 context_id 和 prefix_len + """ + pass + + @abstractmethod + def update_context(self, context_id: str, tokens: Sequence[int]) -> None: + """更新上下文的 token 序列(生成结束后调用)""" + pass + + @abstractmethod + def release_context(self, context_id: str) -> None: + """释放上下文""" + pass + + @abstractmethod + def memory_pressure(self) -> float: + """返回 KV 内存压力值 (0.0~1.0) + + 取 used_blocks/max_blocks 和 used_bytes/max_bytes 的较大值。 + 调度器可用此值做流控决策。 + """ + pass + + @abstractmethod + def snapshot_stats(self) -> Dict[str, float]: + """获取统计信息快照""" + pass + + +class IInferenceService(ABC): + """推理服务接口 + + 调度器依赖此接口进行任务分发和执行。 + """ + + @abstractmethod + def generate(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """非流式生成 + + Args: + payload: 请求参数(prompt, session_id, max_new_tokens 等) + + Returns: + 生成结果(response, session_id, usage 等) + """ + pass + + @abstractmethod + def stream(self, payload: Dict[str, Any]) -> Iterable[Dict[str, Any]]: + """流式生成 + + Args: + payload: 请求参数 + + Yields: + 流式输出的每个 chunk(delta, done, session_id 等) + """ + pass + + @abstractmethod + def request_stop(self, session_id: str) -> bool: + """请求停止生成 + + Args: + session_id: 要停止的会话 ID + + Returns: + 是否成功发送停止信号 + """ + pass + + @abstractmethod + def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]: + """获取 KV 调试快照 + + Args: + session_id: 可选,指定会话 ID + + Returns: + 调试信息字典 + """ + pass + + @property + @abstractmethod + def kv_pool(self) -> IKVCachePool: + """暴露 KVCache 池给调度器查询""" + pass + + def generate_packed_non_stream( + self, payloads: Sequence[Dict[str, Any]] + ) -> Optional[Sequence[Dict[str, Any]]]: + """批量非流式生成(可选实现) + + Args: + payloads: 多个请求参数 + + Returns: + 批量生成结果,如果不支持则返回 None + """ + return None + + def tokenize_for_routing( + self, payload: Dict[str, Any] + ) -> Optional[Sequence[int]]: + """为 KV 感知路由进行轻量级 tokenize(可选实现) + + 调度器可以调用此方法将请求转换为 token 序列, + 用于查询各 worker 的 KV 命中情况。 + + Args: + payload: 请求参数(prompt, messages 等) + + Returns: + token ids 序列,如果无法 tokenize 则返回 None + """ + return None + + def prepare_batch( + self, payloads: Sequence[Dict[str, Any]] + ) -> Optional[Any]: + """准备���式批处理:prefill 所有序列,返回 BatchState(可选实现) + + Args: + payloads: 多个请求参数 + + Returns: + BatchState 对象,如果不支持则返回 None + """ + return None + + def step_batch(self, state: Any) -> Optional[Sequence[Any]]: + """执行一步批量 decode,返回每个序列的 StepResult(可选实现) + + Args: + state: prepare_batch 返回的 BatchState + + Returns: + StepResult 列表,如果不支持则返回 None + """ + return None + + def finalize_sequence(self, state: Any, seq_index: int) -> None: + """完成单个序列:保存消息历史,清理状态(可选实现) + + Args: + state: BatchState 对象 + seq_index: 序列在 batch 中的索引 + """ + pass diff --git a/python/llaisys/kv_cache_pool.py b/python/llaisys/kv_cache_pool.py new file mode 100644 index 000000000..7f6d952c4 --- /dev/null +++ b/python/llaisys/kv_cache_pool.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from dataclasses import dataclass +import threading +import time +from typing import Dict, List, Optional, Sequence, Tuple + +from llaisys.interfaces import IKVCachePool + + +@dataclass +class KVBlock: + block_id: int + generation: int + parent_id: Optional[int] + tokens: Tuple[int, ...] + sealed: bool + ref_count: int + last_access: float + prefix_key: Optional[Tuple[int, ...]] + + @property + def size_bytes(self) -> int: + # int64 token ids + return len(self.tokens) * 8 + + +@dataclass +class ContextState: + block_ids: List[int] + tokens: Tuple[int, ...] + updated_at: float + + +@dataclass +class AcquireResult: + context_id: str + prefix_len: int + + +class KVCachePool(IKVCachePool): + """In-memory token-block cache pool with reference counting. + + Notes: + - Only sealed (full) blocks are indexed for cross-context sharing. + - Block IDs are monotonic and never reused. + """ + + def __init__( + self, + block_size: int = 64, + max_blocks: int = 4096, + max_bytes: int = 256 * 1024 * 1024, + ) -> None: + if block_size <= 0: + raise ValueError("block_size must be > 0") + self._block_size = int(block_size) + self.max_blocks = int(max_blocks) + self.max_bytes = int(max_bytes) + + self._lock = threading.Lock() + self._next_block_id = 1 + self._blocks: Dict[int, KVBlock] = {} + self._contexts: Dict[str, ContextState] = {} + # prefix(tuple(tokens up to this block)) -> (block_id, generation) + self._prefix_index: Dict[Tuple[int, ...], Tuple[int, int]] = {} + self._total_bytes = 0 + self._acquire_count = 0 + self._prefix_hit_count = 0 + self._matched_tokens_total = 0 + + @property + def block_size(self) -> int: + return self._block_size + + def acquire_context(self, context_id: str, tokens: Sequence[int]) -> AcquireResult: + """Bind context to current prompt tokens. + + Returns matched prefix length for runtime reuse decision. + """ + token_tuple = tuple(int(t) for t in tokens) + with self._lock: + _, matched_len = self._build_or_replace_context(context_id, token_tuple) + self._acquire_count += 1 + self._matched_tokens_total += matched_len + if matched_len > 0: + self._prefix_hit_count += 1 + return AcquireResult(context_id=context_id, prefix_len=matched_len) + + def update_context(self, context_id: str, tokens: Sequence[int]) -> None: + """Update context after generation to preserve longer prefixes.""" + token_tuple = tuple(int(t) for t in tokens) + with self._lock: + self._build_or_replace_context(context_id, token_tuple) + + def release_context(self, context_id: str) -> None: + with self._lock: + old_state = self._contexts.pop(context_id, None) + if not old_state: + return + self._decref_chain(old_state.block_ids) + self._evict_if_needed() + + def _build_or_replace_context(self, context_id: str, tokens: Tuple[int, ...]) -> Tuple[List[int], int]: + old_state = self._contexts.get(context_id) + old_block_ids = list(old_state.block_ids) if old_state else [] + + matched_block_ids, matched_len = self._find_longest_sealed_prefix(tokens) + new_block_ids = list(matched_block_ids) + created_block_ids: List[int] = [] + incref_applied: List[int] = [] + + try: + # First, acquire refs to reused blocks. + for bid in matched_block_ids: + self._incref_block(bid) + incref_applied.append(bid) + + parent_id = new_block_ids[-1] if new_block_ids else None + cursor = matched_len + current_prefix = tuple(tokens[:matched_len]) + while cursor < len(tokens): + chunk = tuple(tokens[cursor : cursor + self.block_size]) + sealed = len(chunk) == self.block_size + block_id = self._create_block(parent_id, chunk, sealed, current_prefix) + created_block_ids.append(block_id) + incref_applied.append(block_id) + new_block_ids.append(block_id) + parent_id = block_id + current_prefix = current_prefix + chunk + cursor += len(chunk) + + # Commit context first, then release old refs. + self._contexts[context_id] = ContextState( + block_ids=new_block_ids, + tokens=tokens, + updated_at=time.time(), + ) + self._decref_chain(old_block_ids) + self._evict_if_needed() + return new_block_ids, matched_len + except Exception: + # Rollback ref changes and newly created blocks. + self._safe_rollback(incref_applied, created_block_ids) + # Keep old state untouched. + if old_state is None: + self._contexts.pop(context_id, None) + else: + self._contexts[context_id] = old_state + raise + + def _safe_rollback(self, incref_applied: List[int], created_block_ids: List[int]) -> None: + # Rollback refs idempotently. + seen = set() + for bid in reversed(incref_applied): + if bid in seen: + continue + seen.add(bid) + block = self._blocks.get(bid) + if not block: + continue + if block.ref_count > 0: + block.ref_count -= 1 + block.last_access = time.time() + # Remove newly created zero-ref blocks. + for bid in created_block_ids: + block = self._blocks.get(bid) + if block and block.ref_count == 0: + self._remove_block(bid) + + def _find_longest_sealed_prefix(self, tokens: Tuple[int, ...]) -> Tuple[List[int], int]: + matched_block_ids: List[int] = [] + matched_len = 0 + parent_id: Optional[int] = None + cursor = 0 + prefix: Tuple[int, ...] = () + + while cursor + self.block_size <= len(tokens): + chunk = tuple(tokens[cursor : cursor + self.block_size]) + prefix = prefix + chunk + indexed = self._prefix_index.get(prefix) + if not indexed: + break + bid, generation = indexed + block = self._blocks.get(bid) + if ( + block is None + or block.generation != generation + or not block.sealed + or block.parent_id != parent_id + or block.tokens != chunk + ): + break + matched_block_ids.append(bid) + matched_len += self.block_size + parent_id = bid + cursor += self.block_size + return matched_block_ids, matched_len + + def _create_block( + self, + parent_id: Optional[int], + tokens: Tuple[int, ...], + sealed: bool, + current_prefix: Tuple[int, ...], + ) -> int: + block_id = self._next_block_id + self._next_block_id += 1 + generation = 1 + prefix_key = current_prefix + tokens if sealed else None + block = KVBlock( + block_id=block_id, + generation=generation, + parent_id=parent_id, + tokens=tokens, + sealed=sealed, + ref_count=1, + last_access=time.time(), + prefix_key=prefix_key, + ) + self._blocks[block_id] = block + self._total_bytes += block.size_bytes + if sealed and prefix_key is not None: + self._prefix_index[prefix_key] = (block_id, generation) + return block_id + + def _incref_block(self, block_id: int) -> None: + block = self._blocks.get(block_id) + if not block: + raise RuntimeError(f"missing block {block_id}") + block.ref_count += 1 + block.last_access = time.time() + + def _decref_chain(self, block_ids: List[int]) -> None: + # Idempotent-ish: never below zero. + for bid in block_ids: + block = self._blocks.get(bid) + if not block: + continue + if block.ref_count > 0: + block.ref_count -= 1 + block.last_access = time.time() + + def _evict_if_needed(self) -> None: + while len(self._blocks) > self.max_blocks or self._total_bytes > self.max_bytes: + evict_candidates = [b for b in self._blocks.values() if b.ref_count == 0] + if not evict_candidates: + break + victim = min(evict_candidates, key=lambda b: b.last_access) + self._remove_block(victim.block_id) + + def _remove_block(self, block_id: int) -> None: + block = self._blocks.pop(block_id, None) + if not block: + return + self._total_bytes = max(0, self._total_bytes - block.size_bytes) + if block.prefix_key is not None: + indexed = self._prefix_index.get(block.prefix_key) + if indexed and indexed[0] == block_id: + self._prefix_index.pop(block.prefix_key, None) + + def memory_pressure(self) -> float: + """返回 KV 内存压力值 (0.0~1.0)""" + with self._lock: + block_ratio = len(self._blocks) / self.max_blocks if self.max_blocks > 0 else 0.0 + byte_ratio = self._total_bytes / self.max_bytes if self.max_bytes > 0 else 0.0 + return max(block_ratio, byte_ratio) + + def snapshot_stats(self) -> Dict[str, float]: + """Return lightweight stats for verification and debugging.""" + with self._lock: + zero_ref_blocks = sum(1 for b in self._blocks.values() if b.ref_count == 0) + shared_blocks = sum(1 for b in self._blocks.values() if b.ref_count > 1) + total_refs = sum(b.ref_count for b in self._blocks.values()) + hit_rate = ( + float(self._prefix_hit_count) / float(self._acquire_count) + if self._acquire_count > 0 + else 0.0 + ) + avg_matched_tokens = ( + float(self._matched_tokens_total) / float(self._acquire_count) + if self._acquire_count > 0 + else 0.0 + ) + return { + "contexts": float(len(self._contexts)), + "blocks": float(len(self._blocks)), + "prefix_entries": float(len(self._prefix_index)), + "total_bytes": float(self._total_bytes), + "zero_ref_blocks": float(zero_ref_blocks), + "shared_blocks": float(shared_blocks), + "total_refs": float(total_refs), + "acquire_count": float(self._acquire_count), + "prefix_hit_count": float(self._prefix_hit_count), + "prefix_hit_rate": hit_rate, + "avg_matched_tokens": avg_matched_tokens, + } + + def query_prefix_len(self, tokens: Sequence[int]) -> int: + """查询前缀命中长度(只读,不修改状态) + + 调度器可以用此方法查询某个 token 序列在当前池中的命中情况, + 用于做 KV 感知的路由决策。 + + Args: + tokens: 待查询的 token 序列 + + Returns: + 命中的前缀长度(token 数量),0 表示无命中 + """ + token_tuple = tuple(int(t) for t in tokens) + with self._lock: + _, matched_len = self._find_longest_sealed_prefix(token_tuple) + return matched_len + + def debug_context(self, context_id: str) -> Optional[Dict[str, object]]: + """Return context chain snapshot for tests and diagnostics.""" + with self._lock: + state = self._contexts.get(context_id) + if state is None: + return None + return { + "context_id": context_id, + "tokens": list(state.tokens), + "block_ids": list(state.block_ids), + "updated_at": state.updated_at, + } diff --git a/python/llaisys/kv_runtime_bridge.py b/python/llaisys/kv_runtime_bridge.py new file mode 100644 index 000000000..c6433cae2 --- /dev/null +++ b/python/llaisys/kv_runtime_bridge.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import threading +from typing import Any, Dict, List, Optional, Tuple + + +class KVRuntimeBridge: + """Bridge for native C++ KV context lifecycle (bind, export, find, release, debug).""" + + def __init__(self, model: Any, enabled: bool = False) -> None: + self._model = model + self._enabled = bool(enabled) + self._lock = threading.Lock() + self._native_kv_contexts: Dict[str, Any] = {} + self._native_kv_tokens: Dict[str, Tuple[int, ...]] = {} + self._last_kv_bind_debug: Dict[str, Dict[str, Any]] = {} + + @property + def enabled(self) -> bool: + return self._enabled + + def bind_for_request( + self, + context_id: str, + prompt_ids: List[int], + prefix_len: int, + ) -> None: + """Bind the best KV context to the model for the current request. + + Search order: + 1. Same context_id native context + 2. Prefix-matching donor context + 3. No match -> set_kv_context(None) + """ + debug: Dict[str, Any] = { + "enabled": self._enabled, + "session_id": context_id, + "prefix_len": int(prefix_len), + "bound": False, + "source_session_id": None, + "set_kv_context_rc": None, + } + if not self._enabled or prefix_len <= 0: + self._model.set_kv_context(None) + with self._lock: + self._last_kv_bind_debug[context_id] = debug + return + with self._lock: + ctx = self._native_kv_contexts.get(context_id) + source_session_id: Optional[str] = context_id if ctx else None + if not ctx: + source_session_id, ctx = self._find_for_prefix(prompt_ids, prefix_len) + if not ctx: + self._model.set_kv_context(None) + with self._lock: + self._last_kv_bind_debug[context_id] = debug + return + rc = self._model.set_kv_context(ctx) + debug["set_kv_context_rc"] = int(rc) + debug["source_session_id"] = source_session_id + if rc != 0: + self._model.set_kv_context(None) + else: + debug["bound"] = True + with self._lock: + self._last_kv_bind_debug[context_id] = debug + + def export_after_request( + self, + context_id: str, + tokens: List[int], + block_size: int, + ) -> None: + """Export KV context after request completion for future reuse.""" + if not self._enabled: + return + with self._lock: + ctx = self._native_kv_contexts.get(context_id) + if not ctx: + ctx = self._model.kv_context_create() + if not ctx: + return + with self._lock: + self._native_kv_contexts[context_id] = ctx + rc = self._model.export_kv_context(ctx, block_size) + if rc == 0: + with self._lock: + self._native_kv_tokens[context_id] = tuple(int(t) for t in tokens) + + def release(self, context_id: str) -> None: + """Release native KV context for a given session.""" + with self._lock: + ctx = self._native_kv_contexts.pop(context_id, None) + self._native_kv_tokens.pop(context_id, None) + self._last_kv_bind_debug.pop(context_id, None) + if ctx: + self._model.kv_context_release(ctx) + + def debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]: + """Return KV runtime debug information.""" + with self._lock: + if session_id: + last_bind = dict(self._last_kv_bind_debug.get(session_id, {})) + native_tokens = len(self._native_kv_tokens.get(session_id, ())) + has_native_ctx = session_id in self._native_kv_contexts + else: + last_bind = {} + native_tokens = 0 + has_native_ctx = False + native_contexts = len(self._native_kv_contexts) + tracked_token_sessions = len(self._native_kv_tokens) + return { + "session_id": session_id, + "has_native_context": has_native_ctx, + "native_tokens": native_tokens, + "native_contexts": native_contexts, + "tracked_token_sessions": tracked_token_sessions, + "last_bind": last_bind, + } + + def _find_for_prefix( + self, prompt_ids: List[int], prefix_len: int + ) -> Tuple[Optional[str], Any]: + """Find native KV context matching the given prefix.""" + if prefix_len <= 0: + return None, None + prompt_prefix = tuple(prompt_ids[:prefix_len]) + with self._lock: + best_sid: Optional[str] = None + best_ctx: Any = None + best_len = -1 + for sid, ctx in self._native_kv_contexts.items(): + tokens = self._native_kv_tokens.get(sid, ()) + tlen = len(tokens) + if tlen < prefix_len: + continue + if tuple(tokens[:prefix_len]) != prompt_prefix: + continue + if tlen > best_len: + best_len = tlen + best_sid = sid + best_ctx = ctx + return best_sid, best_ctx diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb527..5d51a2b5d 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,6 +12,20 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops +from .models import load_models, load_comm +from .models import ( + LlaisysQwen2Meta, + LlaisysQwen2Weights, + LlaisysQwen2Model, + LlaisysSamplingParams, + LlaisysQwen2KVBlockMeta, + LlaisysQwen2KVBlock, + LlaisysQwen2KVContext, + LlaisysCommAPI, + llaisysComm_t, + LLAISYS_COMM_UNIQUE_ID_MAX_SIZE, +) +from .tokenizer import load_tokenizer, LlaisysTokenizer def load_shared_library(): @@ -38,6 +52,9 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_models(LIB_LLAISYS) +load_comm(LIB_LLAISYS) +load_tokenizer(LIB_LLAISYS) __all__ = [ @@ -52,4 +69,15 @@ def load_shared_library(): "llaisysMemcpyKind_t", "MemcpyKind", "llaisysStream_t", + "LlaisysQwen2Meta", + "LlaisysQwen2Weights", + "LlaisysQwen2Model", + "LlaisysSamplingParams", + "LlaisysQwen2KVBlockMeta", + "LlaisysQwen2KVBlock", + "LlaisysQwen2KVContext", + "LlaisysCommAPI", + "llaisysComm_t", + "LLAISYS_COMM_UNIQUE_ID_MAX_SIZE", + "LlaisysTokenizer", ] diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b4679..84c761b73 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -6,7 +6,8 @@ class DeviceType(IntEnum): CPU = 0 NVIDIA = 1 - COUNT = 2 + ILUVATAR = 2 + COUNT = 3 llaisysDeviceType_t = ctypes.c_int diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 000000000..419c0f3e9 --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,287 @@ +from ctypes import Structure, POINTER, CFUNCTYPE, c_size_t, c_int, c_float, c_int64, c_uint32, c_void_p, c_int32 + +from .llaisys_types import llaisysDeviceType_t, llaisysDataType_t, llaisysStream_t +from .tensor import llaisysTensor_t + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + + +class LlaisysQwen2Weights(Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + +class LlaisysSamplingParams(Structure): + _fields_ = [ + ("top_k", c_int), + ("top_p", c_float), + ("temperature", c_float), + ("seed", c_uint32), + ] + + +class LlaisysQwen2KVBlockMeta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("max_tokens", c_size_t), + ] + + +LlaisysQwen2Model = c_void_p +LlaisysQwen2KVBlock = c_void_p +LlaisysQwen2KVContext = c_void_p + + +def load_models(lib): + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + llaisysDeviceType_t, + POINTER(c_int), + c_int, + ] + lib.llaisysQwen2ModelCreate.restype = LlaisysQwen2Model + + lib.llaisysQwen2ModelDestroy.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelDestroy.restype = None + + lib.llaisysQwen2ModelWeights.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + lib.llaisysQwen2ModelInfer.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t] + lib.llaisysQwen2ModelInfer.restype = c_int64 + + lib.llaisysQwen2ModelPrefill.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t] + lib.llaisysQwen2ModelPrefill.restype = c_int64 + + lib.llaisysQwen2ModelStep.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t] + lib.llaisysQwen2ModelStep.restype = c_int64 + if hasattr(lib, "llaisysQwen2ModelPrefillPacked"): + lib.llaisysQwen2ModelPrefillPacked.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + POINTER(c_int64), + ] + lib.llaisysQwen2ModelPrefillPacked.restype = c_int32 + if hasattr(lib, "llaisysQwen2ModelStepPacked"): + lib.llaisysQwen2ModelStepPacked.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + POINTER(c_int64), + ] + lib.llaisysQwen2ModelStepPacked.restype = c_int32 + + if hasattr(lib, "llaisysQwen2ModelPrefillPackedSampling"): + lib.llaisysQwen2ModelPrefillPackedSampling.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + POINTER(LlaisysSamplingParams), + POINTER(c_int64), + ] + lib.llaisysQwen2ModelPrefillPackedSampling.restype = c_int32 + + if hasattr(lib, "llaisysQwen2ModelStepPackedSampling"): + lib.llaisysQwen2ModelStepPackedSampling.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + POINTER(LlaisysSamplingParams), + POINTER(c_int64), + ] + lib.llaisysQwen2ModelStepPackedSampling.restype = c_int32 + + lib.llaisysQwen2ModelPrefillSampling.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + c_size_t, + POINTER(LlaisysSamplingParams), + ] + lib.llaisysQwen2ModelPrefillSampling.restype = c_int64 + + lib.llaisysQwen2ModelStepSampling.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + c_size_t, + POINTER(LlaisysSamplingParams), + ] + lib.llaisysQwen2ModelStepSampling.restype = c_int64 + + lib.llaisysQwen2ModelInferSampling.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + c_size_t, + POINTER(LlaisysSamplingParams), + ] + lib.llaisysQwen2ModelInferSampling.restype = c_int64 + + lib.llaisysQwen2ModelInferSamplingEx.argtypes = [ + LlaisysQwen2Model, + POINTER(c_int64), + c_size_t, + c_int, + c_float, + c_float, + c_uint32, + ] + lib.llaisysQwen2ModelInferSamplingEx.restype = c_int64 + + lib.llaisysQwen2ModelResetKVCache.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelResetKVCache.restype = None + + lib.llaisysQwen2ModelSetKVCacheEnabled.argtypes = [LlaisysQwen2Model, c_int] + lib.llaisysQwen2ModelSetKVCacheEnabled.restype = None + + # Experimental KV block/context APIs + lib.llaisysQwen2KVBlockCreate.argtypes = [ + POINTER(LlaisysQwen2KVBlockMeta), + llaisysDeviceType_t, + c_int, + ] + lib.llaisysQwen2KVBlockCreate.restype = LlaisysQwen2KVBlock + lib.llaisysQwen2KVBlockRetain.argtypes = [LlaisysQwen2KVBlock] + lib.llaisysQwen2KVBlockRetain.restype = None + lib.llaisysQwen2KVBlockRelease.argtypes = [LlaisysQwen2KVBlock] + lib.llaisysQwen2KVBlockRelease.restype = None + lib.llaisysQwen2KVBlockSetTokenCount.argtypes = [LlaisysQwen2KVBlock, c_size_t] + lib.llaisysQwen2KVBlockSetTokenCount.restype = c_int32 + lib.llaisysQwen2KVBlockTokenCount.argtypes = [LlaisysQwen2KVBlock] + lib.llaisysQwen2KVBlockTokenCount.restype = c_size_t + lib.llaisysQwen2KVBlockKeyTensor.argtypes = [LlaisysQwen2KVBlock, c_size_t] + lib.llaisysQwen2KVBlockKeyTensor.restype = llaisysTensor_t + lib.llaisysQwen2KVBlockValueTensor.argtypes = [LlaisysQwen2KVBlock, c_size_t] + lib.llaisysQwen2KVBlockValueTensor.restype = llaisysTensor_t + + lib.llaisysQwen2KVContextCreate.argtypes = [ + llaisysDataType_t, + llaisysDeviceType_t, + c_int, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + ] + lib.llaisysQwen2KVContextCreate.restype = LlaisysQwen2KVContext + lib.llaisysQwen2KVContextRetain.argtypes = [LlaisysQwen2KVContext] + lib.llaisysQwen2KVContextRetain.restype = None + lib.llaisysQwen2KVContextRelease.argtypes = [LlaisysQwen2KVContext] + lib.llaisysQwen2KVContextRelease.restype = None + lib.llaisysQwen2KVContextAttachBlock.argtypes = [LlaisysQwen2KVContext, LlaisysQwen2KVBlock] + lib.llaisysQwen2KVContextAttachBlock.restype = c_int32 + lib.llaisysQwen2KVContextDetachAll.argtypes = [LlaisysQwen2KVContext] + lib.llaisysQwen2KVContextDetachAll.restype = None + lib.llaisysQwen2KVContextBlockCount.argtypes = [LlaisysQwen2KVContext] + lib.llaisysQwen2KVContextBlockCount.restype = c_size_t + lib.llaisysQwen2KVContextTokenCount.argtypes = [LlaisysQwen2KVContext] + lib.llaisysQwen2KVContextTokenCount.restype = c_size_t + + lib.llaisysQwen2ModelSetKVContext.argtypes = [LlaisysQwen2Model, LlaisysQwen2KVContext] + lib.llaisysQwen2ModelSetKVContext.restype = c_int32 + lib.llaisysQwen2ModelGetKVContext.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelGetKVContext.restype = LlaisysQwen2KVContext + lib.llaisysQwen2ModelExportKVContext.argtypes = [LlaisysQwen2Model, LlaisysQwen2KVContext, c_size_t] + lib.llaisysQwen2ModelExportKVContext.restype = c_int32 + + if hasattr(lib, "llaisysQwen2ModelSetTensorParallel"): + lib.llaisysQwen2ModelSetTensorParallel.argtypes = [LlaisysQwen2Model, c_void_p, c_void_p, c_int] + lib.llaisysQwen2ModelSetTensorParallel.restype = c_int32 + + +# --- Comm API ctypes --- + +llaisysComm_t = c_void_p + +LLAISYS_COMM_UNIQUE_ID_MAX_SIZE = 128 + +comm_init_api = CFUNCTYPE(c_int, POINTER(llaisysComm_t), c_int, c_int, c_void_p) +comm_destroy_api = CFUNCTYPE(None, llaisysComm_t) +comm_get_rank_api = CFUNCTYPE(c_int, llaisysComm_t) +comm_get_size_api = CFUNCTYPE(c_int, llaisysComm_t) +comm_allreduce_api = CFUNCTYPE( + None, c_void_p, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t, +) +comm_broadcast_api = CFUNCTYPE( + None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t, +) +comm_send_api = CFUNCTYPE( + None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t, +) +comm_recv_api = CFUNCTYPE( + None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t, +) + + +class LlaisysCommAPI(Structure): + _fields_ = [ + ("init", comm_init_api), + ("destroy", comm_destroy_api), + ("get_rank", comm_get_rank_api), + ("get_size", comm_get_size_api), + ("allreduce", comm_allreduce_api), + ("broadcast", comm_broadcast_api), + ("send", comm_send_api), + ("recv", comm_recv_api), + ] + + +def load_comm(lib): + if hasattr(lib, "llaisysGetCommAPI"): + lib.llaisysGetCommAPI.argtypes = [c_int] + lib.llaisysGetCommAPI.restype = POINTER(LlaisysCommAPI) + if hasattr(lib, "llaisysCommGenerateUniqueId"): + lib.llaisysCommGenerateUniqueId.argtypes = [c_int, c_void_p, POINTER(c_size_t)] + lib.llaisysCommGenerateUniqueId.restype = c_int + + +__all__ = [ + "LlaisysQwen2Meta", + "LlaisysQwen2Weights", + "LlaisysSamplingParams", + "LlaisysQwen2KVBlockMeta", + "LlaisysQwen2Model", + "LlaisysQwen2KVBlock", + "LlaisysQwen2KVContext", + "load_models", + "LlaisysCommAPI", + "llaisysComm_t", + "LLAISYS_COMM_UNIQUE_ID_MAX_SIZE", + "load_comm", +] diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py index 5be095eff..a1ee3a8cb 100644 --- a/python/llaisys/libllaisys/ops.py +++ b/python/llaisys/libllaisys/ops.py @@ -1,5 +1,5 @@ from .tensor import llaisysTensor_t -from ctypes import c_float +from ctypes import c_float, c_int64, c_size_t, POINTER def load_ops(lib): lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] @@ -32,5 +32,18 @@ def load_ops(lib): ] lib.llaisysSelfAttention.restype = None + if hasattr(lib, "llaisysSelfAttentionSegmented"): + lib.llaisysSelfAttentionSegmented.argtypes = [ + llaisysTensor_t, # attn_val + llaisysTensor_t, # q + llaisysTensor_t, # k + llaisysTensor_t, # v + c_float, # scale + POINTER(c_int64), # q_offsets ptr + POINTER(c_int64), # kv_offsets ptr + c_size_t, # nseg + ] + lib.llaisysSelfAttentionSegmented.restype = None + lib.llaisysSwiGLU.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] lib.llaisysSwiGLU.restype = None diff --git a/python/llaisys/libllaisys/tokenizer.py b/python/llaisys/libllaisys/tokenizer.py new file mode 100644 index 000000000..27d6df499 --- /dev/null +++ b/python/llaisys/libllaisys/tokenizer.py @@ -0,0 +1,37 @@ +from ctypes import POINTER, c_char_p, c_int, c_int64, c_size_t, c_void_p + + +LlaisysTokenizer = c_void_p + + +def load_tokenizer(lib): + lib.llaisysTokenizerCreateSentencePiece.argtypes = [c_char_p] + lib.llaisysTokenizerCreateSentencePiece.restype = LlaisysTokenizer + + lib.llaisysTokenizerDestroy.argtypes = [LlaisysTokenizer] + lib.llaisysTokenizerDestroy.restype = None + + lib.llaisysTokenizerEncode.argtypes = [ + LlaisysTokenizer, + c_char_p, + POINTER(c_int64), + c_size_t, + ] + lib.llaisysTokenizerEncode.restype = c_int + + lib.llaisysTokenizerDecode.argtypes = [ + LlaisysTokenizer, + POINTER(c_int64), + c_size_t, + c_char_p, + c_size_t, + ] + lib.llaisysTokenizerDecode.restype = c_int + + lib.llaisysTokenizerTokenToId.argtypes = [LlaisysTokenizer, c_char_p] + lib.llaisysTokenizerTokenToId.restype = c_int64 + lib.llaisysTokenizerIdToToken.argtypes = [LlaisysTokenizer, c_int64, c_char_p, c_size_t] + lib.llaisysTokenizerIdToToken.restype = c_int + + +__all__ = ["LlaisysTokenizer", "load_tokenizer"] diff --git a/python/llaisys/models/__init__.py b/python/llaisys/models/__init__.py index af9918b0d..242720675 100644 --- a/python/llaisys/models/__init__.py +++ b/python/llaisys/models/__init__.py @@ -1 +1 @@ -from .qwen2 import Qwen2 +from .qwen2 import Qwen2, format_chat_prompt diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..b53b4c54c 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,599 @@ -from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType - +from typing import Sequence, Iterable, Mapping, Optional +import warnings +from ctypes import byref, c_int, c_size_t, c_float, c_int64, c_uint32, c_void_p +import json from pathlib import Path + +import numpy as np import safetensors +from ..libllaisys import ( + LIB_LLAISYS, + DeviceType, + DataType, + llaisysDeviceType_t, + llaisysDataType_t, + LlaisysQwen2Meta, + LlaisysSamplingParams, + LlaisysQwen2KVBlockMeta, +) + + +def format_chat_prompt( + messages: Iterable[Mapping[str, str]], + system_prompt: Optional[str] = None, + add_generation_prompt: bool = True, +) -> str: + lines: list[str] = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = str(msg.get("role", "")).strip().lower() + content = str(msg.get("content", "")).strip() + if not role or content == "": + raise ValueError("Each message must have non-empty role and content") + if role == "system": + label = "System" + elif role == "assistant": + label = "Assistant" + else: + label = "User" + lines.append(f"{label}: {content}") + if add_generation_prompt: + if not lines or not lines[-1].startswith("Assistant:"): + lines.append("Assistant: ") + return "\n".join(lines) + class Qwen2: + @staticmethod + def build_prompt( + messages: Iterable[Mapping[str, str]], + system_prompt: Optional[str] = None, + add_generation_prompt: bool = True, + ) -> str: + return format_chat_prompt(messages, system_prompt, add_generation_prompt) - def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + def __init__(self, model_path, device: DeviceType = DeviceType.CPU): model_path = Path(model_path) + self._device = device + # 实例化模型元信息 + config_path = model_path / "config.json" + # 如果config.json不存在,则递归查找 + if not config_path.exists(): + candidates = list(model_path.rglob("config.json")) + if not candidates: + raise FileNotFoundError("config.json not found under model_path") + config_path = candidates[0] + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + # 解析数据类型 + torch_dtype = str(cfg.get("torch_dtype", "bfloat16")).lower() + if "float32" in torch_dtype or torch_dtype in {"fp32", "f32"}: + dtype = DataType.F32 + elif "float16" in torch_dtype or torch_dtype in {"fp16", "f16"}: + dtype = DataType.F16 + else: + dtype = DataType.BF16 + # 统一用 torch 读取,避免 numpy->torch 混合加载路径在 Windows 上触发崩溃 + # (历史上 safetensors 在该切换路径会触发访问冲突)。 + use_torch_loader = True + if dtype == DataType.BF16: + dtype = DataType.F16 + # 解析模型参数 + nlayer = int(cfg.get("num_hidden_layers", 0)) + hs = int(cfg.get("hidden_size", 0)) + nh = int(cfg.get("num_attention_heads", 0)) + nkvh = int(cfg.get("num_key_value_heads", nh)) + di = int(cfg.get("intermediate_size", 0)) + maxseq = int(cfg.get("max_position_embeddings", 0)) + voc = int(cfg.get("vocab_size", 0)) + epsilon = float(cfg.get("rms_norm_eps", 1e-6)) + theta = float(cfg.get("rope_theta", 10000.0)) + eos = cfg.get("eos_token_id", -1) + # 解析结束token + if isinstance(eos, list): + end_token = int(eos[0]) if eos else -1 + else: + end_token = int(eos) + # 解析head_dim + dh = int(cfg.get("head_dim", hs // nh if nh else 0)) + # 创建模型元信息结构体 + model_meta = LlaisysQwen2Meta( + llaisysDataType_t(dtype), + c_size_t(nlayer), + c_size_t(hs), + c_size_t(nh), + c_size_t(nkvh), + c_size_t(dh), + c_size_t(di), + c_size_t(maxseq), + c_size_t(voc), + c_float(epsilon), + c_float(theta), + c_int64(end_token), + ) + # 创建模型实例 + device_ids = (c_int * 1)(0) + self._model = LIB_LLAISYS.llaisysQwen2ModelCreate( + byref(model_meta), + llaisysDeviceType_t(device), + device_ids, + 1, + ) + if not self._model: + raise RuntimeError("llaisysQwen2ModelCreate failed") + self._model_weights = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model) + self._meta = model_meta + + # 默认开启 KV-cache + LIB_LLAISYS.llaisysQwen2ModelSetKVCacheEnabled(self._model, c_int(1)) + # + def _dtype_to_llaisys(dtype: np.dtype) -> DataType: + name = getattr(dtype, "name", str(dtype)).lower() + if name in {"float32", "f4"}: + return DataType.F32 + if name in {"float16", "f2"}: + return DataType.F16 + if name in {"bfloat16", "bf16"}: + return DataType.BF16 + if name in {"int64", "i8"}: + return DataType.I64 + if name in {"int32", "i4"}: + return DataType.I32 + if name in {"int16", "i2"}: + return DataType.I16 + if name in {"int8", "i1"}: + return DataType.I8 + if name in {"uint8", "u1"}: + return DataType.U8 + raise ValueError(f"Unsupported dtype: {dtype}") + + def _create_tensor_from_numpy(arr: np.ndarray): + arr = np.ascontiguousarray(arr) + _shape = (c_size_t * arr.ndim)(*arr.shape) + _dtype = _dtype_to_llaisys(arr.dtype) + tensor = LIB_LLAISYS.tensorCreate( + _shape, + c_size_t(arr.ndim), + llaisysDataType_t(_dtype), + llaisysDeviceType_t(device), + c_int(0), + ) + LIB_LLAISYS.tensorLoad(tensor, c_void_p(arr.ctypes.data)) + return tensor + + # 加载模型权重 for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") + import torch + data_ = safetensors.safe_open(file, framework="pt", device="cpu") for name_ in data_.keys(): ## TODO: load the model weights - pass + arr = data_.get_tensor(name_) + if use_torch_loader: + if arr.dtype == torch.bfloat16: + arr = arr.to(torch.float16) + arr = arr.cpu().numpy() + tensor = _create_tensor_from_numpy(arr) + w = self._model_weights.contents + + if name_ in {"model.embed_tokens.weight", "transformer.wte.weight"}: + w.in_embed = tensor + continue + if name_ in {"lm_head.weight", "model.lm_head.weight"}: + w.out_embed = tensor + continue + if name_ in {"model.norm.weight", "transformer.ln_f.weight"}: + w.out_norm_w = tensor + continue + + if name_.startswith("model.layers."): + parts = name_.split(".") + if len(parts) < 4: + continue + layer = int(parts[2]) + sub = ".".join(parts[3:]) + + if sub == "input_layernorm.weight": + w.attn_norm_w[layer] = tensor + elif sub == "self_attn.q_proj.weight": + w.attn_q_w[layer] = tensor + elif sub == "self_attn.q_proj.bias": + w.attn_q_b[layer] = tensor + elif sub == "self_attn.k_proj.weight": + w.attn_k_w[layer] = tensor + elif sub == "self_attn.k_proj.bias": + w.attn_k_b[layer] = tensor + elif sub == "self_attn.v_proj.weight": + w.attn_v_w[layer] = tensor + elif sub == "self_attn.v_proj.bias": + w.attn_v_b[layer] = tensor + elif sub == "self_attn.o_proj.weight": + w.attn_o_w[layer] = tensor + elif sub == "post_attention_layernorm.weight": + w.mlp_norm_w[layer] = tensor + elif sub == "mlp.gate_proj.weight": + w.mlp_gate_w[layer] = tensor + elif sub == "mlp.up_proj.weight": + w.mlp_up_w[layer] = tensor + elif sub == "mlp.down_proj.weight": + w.mlp_down_w[layer] = tensor + w = self._model_weights.contents + if not w.out_embed and w.in_embed: + w.out_embed = w.in_embed + + def generate( self, + # 输入数组 inputs: Sequence[int], + # 最大token数 max_new_tokens: int = None, + # top-k 采样,1 表示贪心 top_k: int = 1, + # top-p 核采样阈值 top_p: float = 0.8, + # 温度系数,越小越保守 temperature: float = 0.8, + # 随机种子,0 表示随机 + seed: int = 0, ): + tokens = list(inputs) + if max_new_tokens is None: + max_new_tokens = 128 + use_sampling = temperature > 0 or top_k > 1 or top_p > 0 + + # prefill with full prompt + if use_sampling: + next_token = int( + self.prefill_sampling( + tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + seed=seed, + ) + ) + else: + token_buf = (c_int64 * len(tokens))(*tokens) + next_token = int( + LIB_LLAISYS.llaisysQwen2ModelPrefill( + self._model, + token_buf, + c_size_t(len(tokens)), + ) + ) + if next_token < 0: + return tokens + tokens.append(next_token) + if self._meta.end_token >= 0 and next_token == self._meta.end_token: + return tokens + + remaining = max_new_tokens - 1 + if remaining <= 0: + return tokens + + # step with newly generated tokens only + for _ in range(remaining): + if next_token < 0: + break + if self._meta.end_token >= 0 and next_token == self._meta.end_token: + break + if use_sampling: + next_token = int( + self.step_sampling( + [next_token], + top_k=top_k, + top_p=top_p, + temperature=temperature, + seed=seed, + ) + ) + else: + token_buf = (c_int64 * 1)(next_token) + next_token = int( + LIB_LLAISYS.llaisysQwen2ModelStep( + self._model, + token_buf, + c_size_t(1), + ) + ) + if next_token < 0: + break + tokens.append(next_token) + + return tokens + + def prefill(self, inputs: Sequence[int]) -> int: + tokens = list(inputs) + token_buf = (c_int64 * len(tokens))(*tokens) + return int( + LIB_LLAISYS.llaisysQwen2ModelPrefill( + self._model, + token_buf, + c_size_t(len(tokens)), + ) + ) + + def step(self, new_tokens: Sequence[int]) -> int: + tokens = list(new_tokens) + token_buf = (c_int64 * len(tokens))(*tokens) + return int( + LIB_LLAISYS.llaisysQwen2ModelStep( + self._model, + token_buf, + c_size_t(len(tokens)), + ) + ) + + def prefill_packed(self, sequences: Sequence[Sequence[int]]) -> list[int]: + seqs = [list(s) for s in sequences] + if not seqs: + return [] + if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelPrefillPacked"): + raise RuntimeError("llaisysQwen2ModelPrefillPacked is unavailable in current llaisys.dll") + offsets = [0] + flat: list[int] = [] + for s in seqs: + if not s: + raise ValueError("each packed sequence must be non-empty") + flat.extend(int(x) for x in s) + offsets.append(len(flat)) + token_buf = (c_int64 * len(flat))(*flat) + off_buf = (c_int64 * len(offsets))(*offsets) + out_buf = (c_int64 * len(seqs))() + ret = int( + LIB_LLAISYS.llaisysQwen2ModelPrefillPacked( + self._model, + token_buf, + off_buf, + c_size_t(len(seqs)), + out_buf, + ) + ) + if ret != 0: + raise RuntimeError(f"llaisysQwen2ModelPrefillPacked failed with code {ret}") + return [int(out_buf[i]) for i in range(len(seqs))] + + def step_packed(self, sequences: Sequence[Sequence[int]]) -> list[int]: + seqs = [list(s) for s in sequences] + if not seqs: + return [] + if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelStepPacked"): + raise RuntimeError("llaisysQwen2ModelStepPacked is unavailable in current llaisys.dll") + offsets = [0] + flat: list[int] = [] + for s in seqs: + if not s: + raise ValueError("each packed sequence must be non-empty") + flat.extend(int(x) for x in s) + offsets.append(len(flat)) + token_buf = (c_int64 * len(flat))(*flat) + off_buf = (c_int64 * len(offsets))(*offsets) + out_buf = (c_int64 * len(seqs))() + ret = int( + LIB_LLAISYS.llaisysQwen2ModelStepPacked( + self._model, + token_buf, + off_buf, + c_size_t(len(seqs)), + out_buf, + ) + ) + if ret != 0: + raise RuntimeError(f"llaisysQwen2ModelStepPacked failed with code {ret}") + return [int(out_buf[i]) for i in range(len(seqs))] + + def prefill_packed_sampling( + self, + sequences: Sequence[Sequence[int]], + params_list: Sequence[LlaisysSamplingParams], + ) -> list[int]: + seqs = [list(s) for s in sequences] + if not seqs: + return [] + if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelPrefillPackedSampling"): + raise RuntimeError("llaisysQwen2ModelPrefillPackedSampling is unavailable in current llaisys.dll") + if len(params_list) != len(seqs): + raise ValueError("params_list length must match sequences length") + offsets = [0] + flat: list[int] = [] + for s in seqs: + if not s: + raise ValueError("each packed sequence must be non-empty") + flat.extend(int(x) for x in s) + offsets.append(len(flat)) + token_buf = (c_int64 * len(flat))(*flat) + off_buf = (c_int64 * len(offsets))(*offsets) + params_buf = (LlaisysSamplingParams * len(seqs))(*params_list) + out_buf = (c_int64 * len(seqs))() + ret = int( + LIB_LLAISYS.llaisysQwen2ModelPrefillPackedSampling( + self._model, + token_buf, + off_buf, + c_size_t(len(seqs)), + params_buf, + out_buf, + ) + ) + if ret != 0: + raise RuntimeError(f"llaisysQwen2ModelPrefillPackedSampling failed with code {ret}") + return [int(out_buf[i]) for i in range(len(seqs))] + + def step_packed_sampling( + self, + sequences: Sequence[Sequence[int]], + params_list: Sequence[LlaisysSamplingParams], + ) -> list[int]: + seqs = [list(s) for s in sequences] + if not seqs: + return [] + if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelStepPackedSampling"): + raise RuntimeError("llaisysQwen2ModelStepPackedSampling is unavailable in current llaisys.dll") + if len(params_list) != len(seqs): + raise ValueError("params_list length must match sequences length") + offsets = [0] + flat: list[int] = [] + for s in seqs: + if not s: + raise ValueError("each packed sequence must be non-empty") + flat.extend(int(x) for x in s) + offsets.append(len(flat)) + token_buf = (c_int64 * len(flat))(*flat) + off_buf = (c_int64 * len(offsets))(*offsets) + params_buf = (LlaisysSamplingParams * len(seqs))(*params_list) + out_buf = (c_int64 * len(seqs))() + ret = int( + LIB_LLAISYS.llaisysQwen2ModelStepPackedSampling( + self._model, + token_buf, + off_buf, + c_size_t(len(seqs)), + params_buf, + out_buf, + ) + ) + if ret != 0: + raise RuntimeError(f"llaisysQwen2ModelStepPackedSampling failed with code {ret}") + return [int(out_buf[i]) for i in range(len(seqs))] + + def prefill_sampling( + self, + inputs: Sequence[int], + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 0.0, + seed: int = 0, + ) -> int: + tokens = list(inputs) + token_buf = (c_int64 * len(tokens))(*tokens) + params = LlaisysSamplingParams( + c_int(top_k), + c_float(top_p), + c_float(temperature), + c_uint32(seed), + ) + return int( + LIB_LLAISYS.llaisysQwen2ModelPrefillSampling( + self._model, + token_buf, + c_size_t(len(tokens)), + byref(params), + ) + ) + + def step_sampling( + self, + new_tokens: Sequence[int], + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 0.0, + seed: int = 0, + ) -> int: + tokens = list(new_tokens) + token_buf = (c_int64 * len(tokens))(*tokens) + params = LlaisysSamplingParams( + c_int(top_k), + c_float(top_p), + c_float(temperature), + c_uint32(seed), + ) + return int( + LIB_LLAISYS.llaisysQwen2ModelStepSampling( + self._model, + token_buf, + c_size_t(len(tokens)), + byref(params), + ) + ) + + def infer(self, inputs: Sequence[int]) -> int: + warnings.warn( + "Qwen2.infer is deprecated; use prefill()/step() instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.prefill(inputs) + + def reset_kv_cache(self): + LIB_LLAISYS.llaisysQwen2ModelResetKVCache(self._model) + + # ===== Experimental KV block/context wrappers ===== + def kv_context_create(self): + return LIB_LLAISYS.llaisysQwen2KVContextCreate( + llaisysDataType_t(self._meta.dtype), + llaisysDeviceType_t(self._device), + c_int(0), + c_size_t(self._meta.nlayer), + c_size_t(self._meta.nh), + c_size_t(self._meta.nkvh), + c_size_t(self._meta.dh), + ) + + def kv_context_release(self, ctx): + LIB_LLAISYS.llaisysQwen2KVContextRelease(ctx) + + def kv_context_attach_block(self, ctx, block): + return int(LIB_LLAISYS.llaisysQwen2KVContextAttachBlock(ctx, block)) + + def kv_context_detach_all(self, ctx): + LIB_LLAISYS.llaisysQwen2KVContextDetachAll(ctx) + + def kv_context_block_count(self, ctx) -> int: + return int(LIB_LLAISYS.llaisysQwen2KVContextBlockCount(ctx)) + + def kv_context_token_count(self, ctx) -> int: + return int(LIB_LLAISYS.llaisysQwen2KVContextTokenCount(ctx)) + + def kv_block_create(self, max_tokens: int): + meta = LlaisysQwen2KVBlockMeta( + llaisysDataType_t(self._meta.dtype), + c_size_t(self._meta.nlayer), + c_size_t(self._meta.nh), + c_size_t(self._meta.nkvh), + c_size_t(self._meta.dh), + c_size_t(max_tokens), + ) + return LIB_LLAISYS.llaisysQwen2KVBlockCreate( + byref(meta), + llaisysDeviceType_t(self._device), + c_int(0), + ) + + def kv_block_retain(self, block): + LIB_LLAISYS.llaisysQwen2KVBlockRetain(block) + + def kv_block_release(self, block): + LIB_LLAISYS.llaisysQwen2KVBlockRelease(block) + + def kv_block_token_count(self, block) -> int: + return int(LIB_LLAISYS.llaisysQwen2KVBlockTokenCount(block)) + + def kv_block_set_token_count(self, block, used_tokens: int) -> int: + return int(LIB_LLAISYS.llaisysQwen2KVBlockSetTokenCount(block, c_size_t(int(used_tokens)))) + + def kv_block_key_tensor(self, block, layer: int): + return LIB_LLAISYS.llaisysQwen2KVBlockKeyTensor(block, c_size_t(int(layer))) + + def kv_block_value_tensor(self, block, layer: int): + return LIB_LLAISYS.llaisysQwen2KVBlockValueTensor(block, c_size_t(int(layer))) + + def set_kv_context(self, ctx) -> int: + return int(LIB_LLAISYS.llaisysQwen2ModelSetKVContext(self._model, ctx)) - # TODO: Implement generate function + def get_kv_context(self): + return LIB_LLAISYS.llaisysQwen2ModelGetKVContext(self._model) - return [] + def export_kv_context(self, ctx, block_tokens: int) -> int: + return int( + LIB_LLAISYS.llaisysQwen2ModelExportKVContext( + self._model, + ctx, + c_size_t(int(block_tokens)), + ) + ) diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py index ed0180bc8..5921339b6 100644 --- a/python/llaisys/ops.py +++ b/python/llaisys/ops.py @@ -1,6 +1,6 @@ from .libllaisys import LIB_LLAISYS from .tensor import Tensor -from ctypes import c_float, c_int +from ctypes import c_float, c_int, c_int64, c_size_t class Ops: @@ -50,6 +50,35 @@ def self_attention(attn_val: Tensor, q: Tensor, k: Tensor, v: Tensor, scale: flo c_float(scale), ) + @staticmethod + def self_attention_segmented( + attn_val: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, + scale: float, + q_offsets: list[int], + kv_offsets: list[int], + ): + if len(q_offsets) != len(kv_offsets): + raise ValueError("q_offsets and kv_offsets must have same length") + if len(q_offsets) < 2: + raise ValueError("offsets must contain at least start/end") + if not hasattr(LIB_LLAISYS, "llaisysSelfAttentionSegmented"): + raise RuntimeError("llaisysSelfAttentionSegmented is unavailable in current llaisys.dll") + q_buf = (c_int64 * len(q_offsets))(*[int(x) for x in q_offsets]) + kv_buf = (c_int64 * len(kv_offsets))(*[int(x) for x in kv_offsets]) + LIB_LLAISYS.llaisysSelfAttentionSegmented( + attn_val.lib_tensor(), + q.lib_tensor(), + k.lib_tensor(), + v.lib_tensor(), + c_float(scale), + q_buf, + kv_buf, + c_size_t(len(q_offsets) - 1), + ) + @staticmethod def swiglu(out: Tensor, gate: Tensor, up: Tensor): LIB_LLAISYS.llaisysSwiGLU(out.lib_tensor(), gate.lib_tensor(), up.lib_tensor()) diff --git a/python/llaisys/scheduler.py b/python/llaisys/scheduler.py new file mode 100644 index 000000000..399347081 --- /dev/null +++ b/python/llaisys/scheduler.py @@ -0,0 +1,910 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +import queue +import threading +import time +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple +from collections import OrderedDict, deque + +if TYPE_CHECKING: + from llaisys.interfaces import IInferenceService + +logger = logging.getLogger(__name__) + +_END = object() + + +@dataclass +class InferenceTask: + payload: Dict[str, Any] + stream: bool + output_queue: "queue.Queue[Any]" + deadline_at: Optional[float] + + +@dataclass +class _ActiveTask: + task: InferenceTask + iterator: Any + emitted_any: bool = False + + +class SchedulerQueueFullError(RuntimeError): + pass + + +class TaskTimeoutError(RuntimeError): + pass + + +class TaskHandle: + def __init__(self, output_queue: "queue.Queue[Any]") -> None: + self._q = output_queue + + def get_result(self, timeout: Optional[float] = None) -> Dict[str, Any]: + while True: + try: + item = self._q.get(timeout=timeout) + except queue.Empty as exc: + raise TaskTimeoutError("task result timeout") from exc + if item is _END: + raise RuntimeError("task ended without result") + if isinstance(item, dict): + return item + raise RuntimeError("unexpected task result type") + + def iter_stream(self, timeout: Optional[float] = None) -> Iterable[Dict[str, Any]]: + while True: + try: + item = self._q.get(timeout=timeout) + except queue.Empty as exc: + raise TaskTimeoutError("task stream timeout") from exc + if item is _END: + break + if isinstance(item, dict): + yield item + else: + raise RuntimeError("unexpected stream item type") + + +class InferenceScheduler: + """In-process scheduler with per-worker queues and session stickiness.""" + + def __init__( + self, + services: "List[IInferenceService]", + queue_size: int = 128, + request_timeout_ms: int = 120000, + continuous_batching: bool = False, + kv_aware_routing: bool = False, + max_sticky_sessions: int = 10000, + max_batch_size: int = 8, + kv_memory_threshold: float = 0.0, + ) -> None: + if not services: + raise ValueError("services must not be empty") + self._services: "List[IInferenceService]" = list(services) + self._queue_size = max(1, int(queue_size)) + self._request_timeout_ms = max(0, int(request_timeout_ms)) + self._continuous_batching = bool(continuous_batching) + self._kv_aware_routing = bool(kv_aware_routing) + self._max_sticky_sessions = max(100, int(max_sticky_sessions)) + self._max_batch_size = max(1, int(max_batch_size)) + self._kv_memory_threshold = float(kv_memory_threshold) + self._queues: List["queue.Queue[Optional[InferenceTask]]"] = [ + queue.Queue(maxsize=self._queue_size) for _ in self._services + ] + self._threads: List[threading.Thread] = [] + self._stop = threading.Event() + self._lock = threading.Lock() + self._session_worker: "OrderedDict[str, int]" = OrderedDict() + self._rr = 0 + self._packed_prefill_last_error: str = "" + self._metrics: Dict[str, float] = { + "submitted": 0.0, + "completed": 0.0, + "cancelled": 0.0, + "failed": 0.0, + "timed_out": 0.0, + "queue_full": 0.0, + "stop_requests": 0.0, + "batch_rounds": 0.0, + "batch_active_sum": 0.0, + "batch_last_active": 0.0, + "prefill_rounds": 0.0, + "decode_rounds": 0.0, + "prefill_last_active": 0.0, + "decode_last_active": 0.0, + "packed_prefill_batches": 0.0, + "packed_prefill_tasks": 0.0, + "packed_prefill_attempts": 0.0, + "packed_prefill_candidate_tasks": 0.0, + "packed_prefill_none_returns": 0.0, + "packed_prefill_exceptions": 0.0, + # KV 感知路由指标 + "kv_aware_routing_attempts": 0.0, + "kv_aware_routing_hits": 0.0, + "kv_aware_routing_best_prefix_len_sum": 0.0, + # 流式批处理指标 + "stream_batch_prefill_batches": 0.0, + "stream_batch_prefill_tasks": 0.0, + "stream_batch_decode_rounds": 0.0, + "stream_batch_decode_active_sum": 0.0, + "stream_batch_shrink_events": 0.0, + "stream_batch_fallback_tasks": 0.0, + # KV 内存流控指标 + "kv_memory_rejected": 0.0, + } + + def start(self) -> None: + if self._threads: + return + self._stop.clear() + for idx in range(len(self._services)): + t = threading.Thread(target=self._worker_loop, args=(idx,), daemon=True) + t.start() + self._threads.append(t) + + def stop(self) -> None: + self._stop.set() + for q in self._queues: + try: + q.put_nowait(None) + except queue.Full: + pass + for t in self._threads: + t.join(timeout=1.0) + self._threads.clear() + + def submit(self, payload: Dict[str, Any], stream: bool) -> TaskHandle: + payload = dict(payload) # shallow copy to avoid mutating caller's dict + + # KV 内存感知流控:超过阈值时拒绝新请求 + if self._kv_memory_threshold > 0: + try: + pressure = max( + svc.kv_pool.memory_pressure() for svc in self._services + ) + except Exception: + pressure = 0.0 + if pressure > self._kv_memory_threshold: + with self._lock: + self._metrics["kv_memory_rejected"] += 1.0 + raise SchedulerQueueFullError("KV memory pressure too high") + + # 自动 tokenize:如果启用了 KV 感知路由且 payload 中没有 _prompt_tokens + if ( + self._kv_aware_routing + and "_prompt_tokens" not in payload + and len(self._services) > 1 + ): + try: + # 使用第一个服务进行 tokenize(所有服务使用相同的 tokenizer) + svc = self._services[0] + if hasattr(svc, "tokenize_for_routing"): + tokens = svc.tokenize_for_routing(payload) + if tokens: + payload["_prompt_tokens"] = tokens + except Exception: + logger.debug("tokenize_for_routing failed, falling back to default routing", exc_info=True) + + worker_idx = self._choose_worker(payload) + + # 清理路由专用的内部字段,不传递给下游 + payload.pop("_prompt_tokens", None) + + out_q: "queue.Queue[Any]" = queue.Queue() + deadline_at: Optional[float] = None + if self._request_timeout_ms > 0: + deadline_at = time.time() + self._request_timeout_ms / 1000.0 + task = InferenceTask(payload=payload, stream=bool(stream), output_queue=out_q, deadline_at=deadline_at) + try: + self._queues[worker_idx].put_nowait(task) + except queue.Full: + with self._lock: + self._metrics["queue_full"] += 1.0 + raise SchedulerQueueFullError("scheduler queue is full") + with self._lock: + self._metrics["submitted"] += 1.0 + return TaskHandle(out_q) + + def request_stop(self, session_id: str) -> bool: + sid = str(session_id or "").strip() + if not sid: + return False + with self._lock: + self._metrics["stop_requests"] += 1.0 + idx = self._session_worker.get(sid) + if idx is not None: + return bool(self._services[idx].request_stop(sid)) + ok = False + for svc in self._services: + ok = bool(svc.request_stop(sid)) or ok + return ok + + def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]: + sid = str(session_id or "").strip() + if sid: + with self._lock: + idx = self._session_worker.get(sid) + if idx is not None: + snap = self._services[idx].kv_debug_snapshot(sid) + snap["worker"] = idx + return snap + for idx2, svc in enumerate(self._services): + snap = svc.kv_debug_snapshot(sid) + if snap.get("has_native_context") or snap.get("last_bind"): + snap["worker"] = idx2 + return snap + return self._services[0].kv_debug_snapshot(sid) + + merged = { + "session_id": None, + "workers": len(self._services), + "queue_size": self._queue_size, + "queues": [q.qsize() for q in self._queues], + "kv_pool": { + "contexts": 0.0, + "blocks": 0.0, + "prefix_entries": 0.0, + "total_bytes": 0.0, + "zero_ref_blocks": 0.0, + "shared_blocks": 0.0, + "total_refs": 0.0, + "acquire_count": 0.0, + "prefix_hit_count": 0.0, + "prefix_hit_rate": 0.0, + "avg_matched_tokens": 0.0, + }, + } + + # 共享 KV 池优化:所有 worker 共享同一个池时只查询一次 + first_pool = getattr(self._services[0], "kv_pool", None) + shared_pool = first_pool is not None and all( + getattr(svc, "kv_pool", None) is first_pool for svc in self._services[1:] + ) + + if shared_pool: + snap = self._services[0].kv_debug_snapshot(None) + pool = snap.get("kv_pool", {}) + for k in merged["kv_pool"]: + merged["kv_pool"][k] = float(pool.get(k, 0.0)) + else: + hit_rate_numer = 0.0 + hit_rate_denom = 0.0 + matched_numer = 0.0 + matched_denom = 0.0 + for svc in self._services: + snap = svc.kv_debug_snapshot(None) + pool = snap.get("kv_pool", {}) + for k in ("contexts", "blocks", "prefix_entries", "total_bytes", "zero_ref_blocks", "shared_blocks", "total_refs", "acquire_count", "prefix_hit_count"): + merged["kv_pool"][k] += float(pool.get(k, 0.0)) + hit_rate_numer += float(pool.get("prefix_hit_count", 0.0)) + hit_rate_denom += float(pool.get("acquire_count", 0.0)) + matched_numer += float(pool.get("avg_matched_tokens", 0.0)) * float(pool.get("acquire_count", 0.0)) + matched_denom += float(pool.get("acquire_count", 0.0)) + merged["kv_pool"]["prefix_hit_rate"] = hit_rate_numer / hit_rate_denom if hit_rate_denom > 0 else 0.0 + merged["kv_pool"]["avg_matched_tokens"] = matched_numer / matched_denom if matched_denom > 0 else 0.0 + return merged + + def debug_snapshot(self) -> Dict[str, Any]: + with self._lock: + metrics = dict(self._metrics) + packed_prefill_last_error = self._packed_prefill_last_error + sticky_sessions = len(self._session_worker) + avg_batch_active = ( + metrics.get("batch_active_sum", 0.0) / metrics.get("batch_rounds", 1.0) + if metrics.get("batch_rounds", 0.0) > 0 + else 0.0 + ) + kv_routing_attempts = metrics.get("kv_aware_routing_attempts", 0.0) + kv_routing_hit_rate = ( + metrics.get("kv_aware_routing_hits", 0.0) / kv_routing_attempts + if kv_routing_attempts > 0 + else 0.0 + ) + kv_routing_avg_prefix_len = ( + metrics.get("kv_aware_routing_best_prefix_len_sum", 0.0) / metrics.get("kv_aware_routing_hits", 1.0) + if metrics.get("kv_aware_routing_hits", 0.0) > 0 + else 0.0 + ) + # KV memory pressure snapshot + try: + kv_memory_pressure = max( + svc.kv_pool.memory_pressure() for svc in self._services + ) + except Exception: + kv_memory_pressure = 0.0 + return { + "workers": len(self._services), + "queue_size": self._queue_size, + "queues": [q.qsize() for q in self._queues], + "request_timeout_ms": self._request_timeout_ms, + "continuous_batching": self._continuous_batching, + "kv_aware_routing": self._kv_aware_routing, + "kv_routing_hit_rate": kv_routing_hit_rate, + "kv_routing_avg_prefix_len": kv_routing_avg_prefix_len, + "kv_memory_threshold": self._kv_memory_threshold, + "kv_memory_pressure": kv_memory_pressure, + "max_batch_size": self._max_batch_size, + "avg_batch_active": avg_batch_active, + "sticky_sessions": sticky_sessions, + "packed_prefill_last_error": packed_prefill_last_error, + "metrics": metrics, + } + + def request_timeout_seconds(self) -> Optional[float]: + if self._request_timeout_ms <= 0: + return None + return self._request_timeout_ms / 1000.0 + + def _touch_session(self, sid: str, worker_idx: int) -> None: + """Record/update session->worker mapping with LRU eviction. Caller must hold self._lock.""" + if sid in self._session_worker: + self._session_worker.move_to_end(sid) + self._session_worker[sid] = worker_idx + while len(self._session_worker) > self._max_sticky_sessions: + self._session_worker.popitem(last=False) + + def _choose_worker(self, payload: Dict[str, Any]) -> int: + sid = str(payload.get("session_id") or payload.get("edit_from_session_id") or "").strip() + + # 1. 会话粘性:已绑定的 session 优先路由到原 worker + with self._lock: + if sid and sid in self._session_worker: + self._session_worker.move_to_end(sid) + return self._session_worker[sid] + + # 2. KV 感知路由:查询各 worker 的 KV 命中情况 + # KV 感知路由是 best-effort:查询到入队之间 KV 状态可能变化, + # 最坏情况是路由到非最优 worker,不影响正确性。 + prompt_tokens: Optional[Sequence[int]] = payload.get("_prompt_tokens") + if self._kv_aware_routing and prompt_tokens and len(self._services) > 1: + best_worker = -1 + best_prefix_len = -1 + + # 共享 KV 池优化:所有 worker 共享同一个池时只查询一次 + first_pool = getattr(self._services[0], "kv_pool", None) + shared_pool = first_pool is not None and all( + getattr(svc, "kv_pool", None) is first_pool for svc in self._services[1:] + ) + + if shared_pool: + try: + prefix_len = first_pool.query_prefix_len(prompt_tokens) + if prefix_len > 0: + best_prefix_len = prefix_len + # 共享池模式下选负载最轻的 worker + best_worker = min(range(len(self._queues)), key=lambda i: self._queues[i].qsize()) + except Exception: + pass + else: + for idx, svc in enumerate(self._services): + try: + kv_pool = getattr(svc, "kv_pool", None) + if kv_pool is None: + continue + prefix_len = kv_pool.query_prefix_len(prompt_tokens) + if prefix_len > best_prefix_len: + best_prefix_len = prefix_len + best_worker = idx + except Exception: + # 查询失败,跳过该 worker + continue + + with self._lock: + self._metrics["kv_aware_routing_attempts"] += 1.0 + if best_prefix_len > 0: + self._metrics["kv_aware_routing_hits"] += 1.0 + self._metrics["kv_aware_routing_best_prefix_len_sum"] += float(best_prefix_len) + + if best_worker >= 0 and best_prefix_len > 0: + if sid: + with self._lock: + self._touch_session(sid, best_worker) + return best_worker + + # 3. Fallback:hash 或轮询 + with self._lock: + if sid: + idx = hash(sid) % len(self._services) + self._touch_session(sid, idx) + return idx + idx = self._rr % len(self._services) + self._rr = (self._rr + 1) % len(self._services) + return idx + + def _bind_session(self, session_id: Optional[str], worker_idx: int) -> None: + sid = str(session_id or "").strip() + if not sid: + return + with self._lock: + self._touch_session(sid, worker_idx) + + def _worker_loop(self, idx: int) -> None: + if self._continuous_batching: + self._worker_loop_continuous(idx) + return + + svc = self._services[idx] + q = self._queues[idx] + while not self._stop.is_set(): + task = q.get() + if task is None: + q.task_done() + continue + try: + if task.deadline_at is not None and time.time() > task.deadline_at: + with self._lock: + self._metrics["timed_out"] += 1.0 + if task.stream: + task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True}) + else: + task.output_queue.put({"error": "request timeout", "code": "timeout"}) + task.output_queue.put(_END) + continue + if task.stream: + try: + for item in svc.stream(task.payload): + if isinstance(item, dict): + self._bind_session(item.get("session_id"), idx) + if item.get("done") and item.get("stopped"): + with self._lock: + self._metrics["cancelled"] += 1.0 + task.output_queue.put(item) + with self._lock: + self._metrics["completed"] += 1.0 + except Exception as exc: + with self._lock: + self._metrics["failed"] += 1.0 + task.output_queue.put({"error": str(exc), "done": True}) + finally: + task.output_queue.put(_END) + else: + try: + result = svc.generate(task.payload) + if isinstance(result, dict): + self._bind_session(result.get("session_id"), idx) + task.output_queue.put(result) + with self._lock: + self._metrics["completed"] += 1.0 + if isinstance(result, dict) and result.get("stopped"): + self._metrics["cancelled"] += 1.0 + except Exception as exc: + with self._lock: + self._metrics["failed"] += 1.0 + task.output_queue.put({"error": str(exc)}) + finally: + task.output_queue.put(_END) + finally: + q.task_done() + + def _worker_loop_continuous(self, idx: int) -> None: + svc = self._services[idx] + q = self._queues[idx] + # Raw tasks waiting for prefill (not yet started) + prefill_pending: "deque[InferenceTask]" = deque() + # Fallback: legacy iterator-based active tasks + fallback_prefill: "deque[_ActiveTask]" = deque() + fallback_decode: List[_ActiveTask] = [] + # Batch-driven decode state (from prepare_batch) + batch_state: Optional[Any] = None + batch_tasks: List[InferenceTask] = [] # parallel to batch_state.sequences + + def _append_from_queue(block: bool) -> None: + while True: + try: + task = q.get(block=block, timeout=0.1 if block else 0.0) + except queue.Empty: + return + if task is None: + q.task_done() + return + prefill_pending.append(task) + q.task_done() + block = False + + def _emit_chunk(task: InferenceTask, chunk: dict) -> None: + """Send a stream chunk or accumulate for non-stream.""" + task.output_queue.put(chunk) + + def _emit_final_stream(task: InferenceTask, context_id: str, + finish_reason: str, prompt_len: int, + gen_len: int, stopped: bool) -> None: + usage = { + "prompt_tokens": prompt_len, + "completion_tokens": gen_len, + "total_tokens": prompt_len + gen_len, + } + from llaisys.server import _wrap_chunk + chunk = _wrap_chunk(context_id, None, finish_reason, usage=usage, stopped=stopped) + task.output_queue.put(chunk) + task.output_queue.put(_END) + + def _emit_final_non_stream(task: InferenceTask, context_id: str, + content: str, finish_reason: str, + prompt_len: int, gen_len: int, + stopped: bool) -> None: + usage = { + "prompt_tokens": prompt_len, + "completion_tokens": gen_len, + "total_tokens": prompt_len + gen_len, + } + from llaisys.server import _wrap_completion + result = _wrap_completion(context_id, content, finish_reason, usage, stopped=stopped) + task.output_queue.put(result) + task.output_queue.put(_END) + + # --- Fallback helpers (legacy iterator path) --- + def _step_once(state: _ActiveTask) -> str: + task = state.task + it = state.iterator + if task.deadline_at is not None and time.time() > task.deadline_at: + with self._lock: + self._metrics["timed_out"] += 1.0 + if task.stream: + task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True}) + else: + task.output_queue.put({"error": "request timeout", "code": "timeout"}) + task.output_queue.put(_END) + return "done" + try: + item = next(it) + if isinstance(item, dict): + self._bind_session(item.get("session_id"), idx) + + def _is_final(d: dict) -> bool: + if d.get("done"): + return True + choices = d.get("choices") + if choices and isinstance(choices, list) and len(choices) > 0: + if choices[0].get("finish_reason") is not None: + return True + return False + + def _is_stopped(d: dict) -> bool: + if d.get("stopped"): + return True + choices = d.get("choices") + if choices and isinstance(choices, list) and len(choices) > 0: + if choices[0].get("finish_reason") == "stop": + return True + return False + + if task.stream: + if not isinstance(item, dict): + raise RuntimeError("stream item must be dict") + task.output_queue.put(item) + state.emitted_any = True + if _is_final(item): + with self._lock: + self._metrics["completed"] += 1.0 + if _is_stopped(item): + self._metrics["cancelled"] += 1.0 + task.output_queue.put(_END) + return "done" + return "keep" + if isinstance(item, dict) and _is_final(item): + if item.get("error"): + with self._lock: + self._metrics["failed"] += 1.0 + task.output_queue.put({"error": str(item.get("error"))}) + else: + result = dict(item) + choices = result.get("choices") + if choices and isinstance(choices, list) and len(choices) > 0: + c = dict(choices[0]) + acc = getattr(state, "accumulated_content", "") + delta = c.pop("delta", {}) + final_content = acc + delta.get("content", "") + c["message"] = {"role": "assistant", "content": final_content} + result["choices"] = [c] + if result.get("object") == "chat.completion.chunk": + result["object"] = "chat.completion" + task.output_queue.put(result) + with self._lock: + self._metrics["completed"] += 1.0 + if _is_stopped(item): + self._metrics["cancelled"] += 1.0 + task.output_queue.put(_END) + return "done" + choices = item.get("choices") + if choices and isinstance(choices, list) and len(choices) > 0: + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + if content: + state.accumulated_content = getattr(state, "accumulated_content", "") + content + return "keep" + except StopIteration: + with self._lock: + self._metrics["failed"] += 1.0 + if task.stream: + task.output_queue.put({"error": "stream ended unexpectedly", "done": True}) + else: + task.output_queue.put({"error": "task ended unexpectedly"}) + task.output_queue.put(_END) + return "done" + except Exception as exc: + with self._lock: + self._metrics["failed"] += 1.0 + if task.stream: + task.output_queue.put({"error": str(exc), "done": True}) + else: + task.output_queue.put({"error": str(exc)}) + task.output_queue.put(_END) + return "done" + + while not self._stop.is_set(): + has_work = ( + prefill_pending or fallback_prefill or fallback_decode + or (batch_state is not None) + ) + if not has_work: + _append_from_queue(block=True) + if not prefill_pending: + continue + else: + _append_from_queue(block=False) + + with self._lock: + total_active = ( + len(prefill_pending) + len(fallback_prefill) + len(fallback_decode) + + (len([s for s in batch_state.sequences if not s.finished]) if batch_state else 0) + ) + self._metrics["batch_rounds"] += 1.0 + self._metrics["batch_active_sum"] += float(total_active) + self._metrics["batch_last_active"] = float(total_active) + self._metrics["prefill_last_active"] = float(len(prefill_pending) + len(fallback_prefill)) + decode_count = len(fallback_decode) + ( + len([s for s in batch_state.sequences if not s.finished]) if batch_state else 0 + ) + self._metrics["decode_last_active"] = float(decode_count) + + # ============================================================ + # P stage: try batch prefill for pending tasks + # ============================================================ + if prefill_pending and batch_state is None: + with self._lock: + self._metrics["prefill_rounds"] += 1.0 + + # Collect candidates up to max_batch_size + batch_candidates: List[InferenceTask] = [] + remaining: "deque[InferenceTask]" = deque() + decode_active_count = len(fallback_decode) + slots = self._max_batch_size - decode_active_count + + while prefill_pending and len(batch_candidates) < slots: + task = prefill_pending.popleft() + # Check deadline + if task.deadline_at is not None and time.time() > task.deadline_at: + with self._lock: + self._metrics["timed_out"] += 1.0 + if task.stream: + task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True}) + else: + task.output_queue.put({"error": "request timeout", "code": "timeout"}) + task.output_queue.put(_END) + continue + batch_candidates.append(task) + + if not batch_candidates: + pass # all timed out + elif len(batch_candidates) >= 1 and hasattr(svc, "prepare_batch"): + # Try batch path + try: + payloads = [t.payload for t in batch_candidates] + result = svc.prepare_batch(payloads) + except Exception as exc: + logger.debug("prepare_batch failed: %s", exc, exc_info=True) + result = None + + if result is not None: + batch_state = result + batch_tasks = list(batch_candidates) + with self._lock: + self._metrics["stream_batch_prefill_batches"] += 1.0 + self._metrics["stream_batch_prefill_tasks"] += float(len(batch_candidates)) + + # Emit first token chunks for each sequence + from llaisys.server import _wrap_chunk + for i, seq in enumerate(batch_state.sequences): + task = batch_tasks[i] + self._bind_session(seq.context_id, idx) + if seq.filtered_text and task.stream: + chunk = _wrap_chunk(seq.context_id, seq.filtered_text, None) + task.output_queue.put(chunk) + # If already finished after prefill + if seq.finished: + if task.stream: + _emit_final_stream( + task, seq.context_id, seq.finish_reason or "stop", + len(seq.prompt_ids), len(seq.generated_ids), + stopped=bool(seq.cancel_event and seq.cancel_event.is_set()), + ) + else: + _emit_final_non_stream( + task, seq.context_id, seq.filtered_text, + seq.finish_reason or "stop", + len(seq.prompt_ids), len(seq.generated_ids), + stopped=bool(seq.cancel_event and seq.cancel_event.is_set()), + ) + with self._lock: + self._metrics["completed"] += 1.0 + svc.finalize_sequence(batch_state, i) + # If all finished after prefill, clear batch + if all(s.finished for s in batch_state.sequences): + batch_state = None + batch_tasks = [] + else: + # Fallback: push to legacy iterator path + with self._lock: + self._metrics["stream_batch_fallback_tasks"] += float(len(batch_candidates)) + for task in batch_candidates: + try: + it = svc.stream(task.payload) + fallback_prefill.append(_ActiveTask(task=task, iterator=it)) + except Exception as exc: + if task.stream: + task.output_queue.put({"error": str(exc), "done": True}) + else: + task.output_queue.put({"error": str(exc)}) + task.output_queue.put(_END) + with self._lock: + self._metrics["failed"] += 1.0 + else: + # No prepare_batch available, use legacy path + with self._lock: + self._metrics["stream_batch_fallback_tasks"] += float(len(batch_candidates)) + for task in batch_candidates: + try: + it = svc.stream(task.payload) + fallback_prefill.append(_ActiveTask(task=task, iterator=it)) + except Exception as exc: + if task.stream: + task.output_queue.put({"error": str(exc), "done": True}) + else: + task.output_queue.put({"error": str(exc)}) + task.output_queue.put(_END) + with self._lock: + self._metrics["failed"] += 1.0 + + # Legacy P stage: step fallback prefill tasks one at a time + if fallback_prefill: + with self._lock: + if not prefill_pending: + self._metrics["prefill_rounds"] += 1.0 + + # Try packed prefill for non-stream fallback tasks + packed_candidates: List[_ActiveTask] = [] + for state in fallback_prefill: + if state.task.stream: + continue + packed_candidates.append(state) + if len(packed_candidates) >= self._max_batch_size: + break + if len(packed_candidates) >= 2 and ( + hasattr(svc, "generate_packed_non_stream") or hasattr(svc, "generate_packed_once") + ): + packed_exception = False + with self._lock: + self._metrics["packed_prefill_attempts"] += 1.0 + self._metrics["packed_prefill_candidate_tasks"] += float(len(packed_candidates)) + try: + packed_payloads = [st.task.payload for st in packed_candidates] + if hasattr(svc, "generate_packed_non_stream"): + packed_results = svc.generate_packed_non_stream(packed_payloads) + else: + packed_results = svc.generate_packed_once(packed_payloads) + except Exception as exc: + packed_exception = True + with self._lock: + self._metrics["packed_prefill_exceptions"] += 1.0 + self._packed_prefill_last_error = str(exc) + packed_results = None + if isinstance(packed_results, list) and len(packed_results) == len(packed_candidates): + packed_ids = {id(st) for st in packed_candidates} + fallback_prefill = deque([st for st in fallback_prefill if id(st) not in packed_ids]) + for st, result in zip(packed_candidates, packed_results): + st.task.output_queue.put(result) + st.task.output_queue.put(_END) + with self._lock: + self._metrics["completed"] += float(len(packed_candidates)) + self._metrics["packed_prefill_batches"] += 1.0 + self._metrics["packed_prefill_tasks"] += float(len(packed_candidates)) + self._packed_prefill_last_error = "" + # Skip single step below if we consumed all + if not fallback_prefill: + pass + elif not packed_exception: + with self._lock: + self._metrics["packed_prefill_none_returns"] += 1.0 + + if fallback_prefill: + state = fallback_prefill.popleft() + status = _step_once(state) + if status == "keep": + fallback_decode.append(state) + + # ============================================================ + # D stage: batch decode + # ============================================================ + if batch_state is not None: + active_before = len([s for s in batch_state.sequences if not s.finished]) + with self._lock: + self._metrics["decode_rounds"] += 1.0 + self._metrics["stream_batch_decode_rounds"] += 1.0 + self._metrics["stream_batch_decode_active_sum"] += float(active_before) + + try: + step_results = svc.step_batch(batch_state) + except Exception as exc: + logger.debug("step_batch failed: %s", exc, exc_info=True) + # Mark all active as failed + for i, seq in enumerate(batch_state.sequences): + if not seq.finished: + seq.finished = True + task = batch_tasks[i] + if task.stream: + task.output_queue.put({"error": str(exc), "done": True}) + else: + task.output_queue.put({"error": str(exc)}) + task.output_queue.put(_END) + with self._lock: + self._metrics["failed"] += 1.0 + batch_state = None + batch_tasks = [] + step_results = None + + if step_results is not None: + from llaisys.server import _wrap_chunk + for sr in step_results: + task = batch_tasks[sr.seq_index] + seq = batch_state.sequences[sr.seq_index] + + if sr.delta_text and task.stream: + chunk = _wrap_chunk(seq.context_id, sr.delta_text, None) + task.output_queue.put(chunk) + + if sr.finished: + if task.stream: + _emit_final_stream( + task, seq.context_id, sr.finish_reason or "stop", + len(seq.prompt_ids), len(seq.generated_ids), + stopped=sr.stopped, + ) + else: + _emit_final_non_stream( + task, seq.context_id, seq.filtered_text, + sr.finish_reason or "stop", + len(seq.prompt_ids), len(seq.generated_ids), + stopped=sr.stopped, + ) + with self._lock: + self._metrics["completed"] += 1.0 + if sr.stopped: + self._metrics["cancelled"] += 1.0 + svc.finalize_sequence(batch_state, sr.seq_index) + + # Check for shrink events + active_after = len([s for s in batch_state.sequences if not s.finished]) + if active_after < active_before and active_after > 0: + with self._lock: + self._metrics["stream_batch_shrink_events"] += 1.0 + + # Clear batch if all done + if all(s.finished for s in batch_state.sequences): + batch_state = None + batch_tasks = [] + + # Legacy D stage: iterate fallback decode tasks + if fallback_decode: + with self._lock: + self._metrics["decode_rounds"] += 1.0 + next_decode: List[_ActiveTask] = [] + for state in fallback_decode: + status = _step_once(state) + if status == "keep": + next_decode.append(state) + fallback_decode = next_decode diff --git a/python/llaisys/server.py b/python/llaisys/server.py new file mode 100644 index 000000000..6c73af966 --- /dev/null +++ b/python/llaisys/server.py @@ -0,0 +1,1102 @@ +from __future__ import annotations + +import argparse +import json +import re +import threading +from dataclasses import dataclass, field +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple +from urllib.parse import parse_qs, urlparse + +import llaisys +from llaisys.interfaces import IInferenceService +from llaisys.kv_cache_pool import KVCachePool +from llaisys.kv_runtime_bridge import KVRuntimeBridge +from llaisys.libllaisys import LlaisysSamplingParams +from llaisys.models import Qwen2 +from llaisys.scheduler import InferenceScheduler, SchedulerQueueFullError, TaskTimeoutError + + +# --------------------------------------------------------------------------- +# Streaming batch data structures +# --------------------------------------------------------------------------- + +@dataclass +class BatchSequenceState: + index: int + context_id: str + messages: List[Dict[str, str]] + prompt_ids: List[int] + generated_ids: List[int] = field(default_factory=list) + filtered_text: str = "" + max_new_tokens: int = 128 + sampling: Dict[str, Any] = field(default_factory=dict) + sampling_params: Optional[LlaisysSamplingParams] = None + use_sampling: bool = False + cancel_event: Optional[threading.Event] = None + finished: bool = False + finish_reason: Optional[str] = None + + +@dataclass +class BatchState: + sequences: List[BatchSequenceState] + any_sampling: bool + eos_token: int + + +@dataclass +class StepResult: + seq_index: int + delta_text: str + finished: bool + finish_reason: Optional[str] + stopped: bool = False +from llaisys.session_manager import SessionManager + + +def _wrap_completion( + session_id: str, + content: str, + finish_reason: str, + usage: Dict[str, int], + stopped: bool = False, +) -> Dict[str, Any]: + result: Dict[str, Any] = { + "id": f"chatcmpl-{session_id}", + "object": "chat.completion", + "model": "qwen2", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": finish_reason, + } + ], + "usage": usage, + "session_id": session_id, + } + if stopped: + result["stopped"] = True + return result + + +def _wrap_chunk( + session_id: str, + delta_content: Optional[str], + finish_reason: Optional[str], + usage: Optional[Dict[str, int]] = None, + stopped: bool = False, +) -> Dict[str, Any]: + delta: Dict[str, str] = {} + if delta_content is not None: + delta["content"] = delta_content + chunk: Dict[str, Any] = { + "id": f"chatcmpl-{session_id}", + "object": "chat.completion.chunk", + "model": "qwen2", + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + "session_id": session_id, + } + if usage is not None: + chunk["usage"] = usage + if stopped: + chunk["stopped"] = True + return chunk + + +def _wrap_error(message: str, error_type: str = "server_error", code: str = "") -> Dict[str, Any]: + err: Dict[str, Any] = {"error": {"message": message, "type": error_type}} + if code: + err["error"]["code"] = code + return err + + +class ChatService(IInferenceService): + def __init__( + self, + model: Qwen2, + tokenizer: llaisys.Tokenizer, + model_path: Optional[str] = None, + enable_kv_runtime_reuse: bool = False, + block_size: int = 64, + max_blocks: int = 4096, + max_bytes: int = 256 * 1024 * 1024, + model_lock: Optional[threading.RLock] = None, + kv_pool: Optional[KVCachePool] = None, + kv_bridge: Optional[KVRuntimeBridge] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self._enable_kv_runtime_reuse = bool(enable_kv_runtime_reuse) + # RLock allows cooperative iterator-level scheduling in continuous-batching mode. + self._model_lock = model_lock if model_lock is not None else threading.RLock() + + # Delegated components + self._session_mgr = SessionManager() + self._kv_bridge = kv_bridge if kv_bridge is not None else KVRuntimeBridge(model, enabled=enable_kv_runtime_reuse) + self._kv_pool = kv_pool if kv_pool is not None else KVCachePool( + block_size=block_size, + max_blocks=max_blocks, + max_bytes=max_bytes, + ) + self._active_tokens: List[int] = [] + + # Text processing + self._chat_template_tokenizer = self._init_chat_template_tokenizer(model_path) + self._filter_tokens = ("<|end_of_sentence|>",) + self._filter_patterns = [ + re.compile(r"<\s*\|\s*end_of_sentence\s*\|\s*>", re.IGNORECASE), + re.compile(r"<\s*\|[^>]*\|\s*>"), + re.compile(r"<\s*[\|\uFF5C][^>]*[\|\uFF5C]\s*>"), + re.compile( + r"<\s*[\|\uFF5C]\s*end[\s_\u2581]*of[\s_\u2581]*sentence\s*[\|\uFF5C]\s*>", + re.IGNORECASE, + ), + ] + + @property + def kv_pool(self) -> KVCachePool: + """暴露 KVCache 池给调度器查询""" + return self._kv_pool + + def tokenize_for_routing(self, payload: Dict[str, Any]) -> Optional[List[int]]: + """为 KV 感知路由进行轻量级 tokenize + + 尝试从 payload 构建 prompt 并 tokenize,用于调度器查询 KV 命中。 + 失败时返回 None,不影响正常请求处理。 + + Args: + payload: 请求参数 + + Returns: + token ids 列表,或 None(如果无法 tokenize) + """ + try: + # 尝试提取 messages + messages = payload.get("messages") + prompt_text = payload.get("prompt") + system_prompt = payload.get("system_prompt") + + if messages is not None: + if not isinstance(messages, list): + return None + prompt = self._render_prompt(list(messages), str(system_prompt) if system_prompt else None) + elif prompt_text is not None: + # 简单 prompt,尝试获取历史 + session_id = str(payload.get("session_id") or "").strip() + history = self._session_mgr.get_messages(session_id) + history.append({"role": "user", "content": str(prompt_text)}) + prompt = self._render_prompt(history, str(system_prompt) if system_prompt else None) + else: + return None + + return self.tokenizer.encode(prompt) + except Exception: + return None + + @staticmethod + def _init_chat_template_tokenizer(model_path: Optional[str]): + if not model_path: + return None + try: + from transformers import AutoTokenizer + except Exception: + return None + try: + return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + except Exception: + return None + + def _postprocess_text(self, text: str) -> str: + for token in self._filter_tokens: + text = text.replace(token, "") + for pattern in self._filter_patterns: + text = pattern.sub("", text) + return text + + def request_stop(self, context_id: str) -> bool: + return self._session_mgr.request_stop(context_id) + + def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]: + snapshot = self._kv_bridge.debug_snapshot(session_id) + snapshot["kv_pool"] = self._kv_pool.snapshot_stats() + return snapshot + + def _render_prompt(self, messages: List[Dict[str, str]], system_prompt: Optional[str]) -> str: + templated_messages: List[Dict[str, str]] = [] + if system_prompt: + templated_messages.append({"role": "system", "content": str(system_prompt)}) + templated_messages.extend(messages) + + if self._chat_template_tokenizer is not None: + try: + return self._chat_template_tokenizer.apply_chat_template( + templated_messages, + add_generation_prompt=True, + tokenize=False, + ) + except Exception: + pass + return Qwen2.build_prompt( + messages, + system_prompt=str(system_prompt) if system_prompt else None, + add_generation_prompt=True, + ) + + def _eos_token(self) -> int: + eos = getattr(self.model, "_meta", None) + if eos is None: + return -1 + end_token = getattr(eos, "end_token", -1) + return int(getattr(end_token, "value", end_token)) + + def _decode_next( + self, + token_ids: List[int], + use_sampling: bool, + sampling: Dict[str, Any], + ) -> int: + top_k = int(sampling.get("top_k", 1)) + top_p = float(sampling.get("top_p", 0.0)) + temperature = float(sampling.get("temperature", 0.0)) + seed = int(sampling.get("seed", 0)) + if use_sampling: + return int( + self.model.step_sampling( + token_ids, + top_k=top_k, + top_p=top_p, + temperature=temperature, + seed=seed, + ) + ) + return int(self.model.step(token_ids)) + + def _prefill_next( + self, + prompt_ids: List[int], + use_sampling: bool, + sampling: Dict[str, Any], + ) -> int: + top_k = int(sampling.get("top_k", 1)) + top_p = float(sampling.get("top_p", 0.0)) + temperature = float(sampling.get("temperature", 0.0)) + seed = int(sampling.get("seed", 0)) + if use_sampling: + return int( + self.model.prefill_sampling( + prompt_ids, + top_k=top_k, + top_p=top_p, + temperature=temperature, + seed=seed, + ) + ) + return int(self.model.prefill(prompt_ids)) + + def _iter_generate_ids( + self, + prompt_ids: List[int], + max_new_tokens: int, + sampling: Dict[str, Any], + prefix_len: int, + cancel_event: threading.Event, + ) -> Iterable[int]: + mode = str(sampling.get("mode", "")).strip().lower() + top_k = int(sampling.get("top_k", 1)) + top_p = float(sampling.get("top_p", 0.0)) + temperature = float(sampling.get("temperature", 0.0)) + if mode == "argmax": + use_sampling = False + elif mode == "sample": + use_sampling = True + else: + use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0 + + if cancel_event.is_set(): + return + + can_reuse_active_prefix = ( + self._enable_kv_runtime_reuse + and prefix_len > 0 + and len(self._active_tokens) == prefix_len + and self._active_tokens[:prefix_len] == prompt_ids[:prefix_len] + and len(prompt_ids) > prefix_len + ) + if can_reuse_active_prefix: + next_token = self._decode_next(prompt_ids[prefix_len:], use_sampling, sampling) + self._active_tokens = list(prompt_ids) + else: + self.model.reset_kv_cache() + next_token = self._prefill_next(prompt_ids, use_sampling, sampling) + self._active_tokens = list(prompt_ids) + + if next_token < 0: + return + + eos = self._eos_token() + yield next_token + self._active_tokens.append(next_token) + for _ in range(max_new_tokens - 1): + if cancel_event.is_set(): + break + if eos >= 0 and next_token == eos: + break + next_token = self._decode_next([next_token], use_sampling, sampling) + if next_token < 0: + break + yield next_token + self._active_tokens.append(next_token) + + def _prepare_request(self, payload: Dict[str, Any]) -> Tuple[str, List[Dict[str, str]], List[int], Dict[str, Any], int]: + system_prompt = payload.get("system_prompt") + # Accept OpenAI's max_tokens as alias; prefer it over max_new_tokens + if "max_tokens" in payload: + max_new_tokens = int(payload["max_tokens"]) + else: + max_new_tokens = int(payload.get("max_new_tokens", 128)) + # model field accepted and ignored + sampling = { + "mode": payload.get("sampling"), + "top_k": payload.get("top_k", 1), + "top_p": payload.get("top_p", 0.0), + "temperature": payload.get("temperature", 0.0), + "seed": payload.get("seed", 0), + } + + context_id, messages = self._session_mgr.extract_messages(payload) + prompt = self._render_prompt(messages, str(system_prompt) if system_prompt else None) + prompt_ids = self.tokenizer.encode(prompt) + return context_id, messages, prompt_ids, sampling, max_new_tokens + + def generate_packed_non_stream(self, payloads: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]: + """Best-effort packed non-stream path (greedy + sampling). + + Current safe scope: + - non-stream requests only + - greedy and sampling requests (mixed batches supported) + - no history-edit branching fields + + When any request uses sampling, the batch is routed through the + packed-sampling C API. If that API is unavailable (old DLL), sampling + requests fall back to ``None`` so the scheduler handles them one by one. + Pure-greedy batches still use the original fast ``prefill_packed`` path. + """ + if not payloads: + return [] + if not hasattr(self.model, "prefill_packed") or not hasattr(self.model, "step_packed"): + return None + + prepared: List[Tuple[str, List[Dict[str, str]], List[int], Dict[str, Any], int]] = [] + any_sampling = False + sampling_params_list: List[LlaisysSamplingParams] = [] + for payload in payloads: + if payload.get("stream", False): + return None + # History editing introduces branch semantics; keep packed path conservative for now. + if payload.get("edit_from_session_id"): + return None + try: + context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload) + except Exception: + return None + if max_new_tokens <= 0: + return None + mode = str(sampling.get("mode", "")).strip().lower() + top_k = int(sampling.get("top_k", 1)) + top_p = float(sampling.get("top_p", 0.0)) + temperature = float(sampling.get("temperature", 0.0)) + seed = int(sampling.get("seed", 0)) + if mode == "argmax": + use_sampling = False + elif mode == "sample": + use_sampling = True + else: + use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0 + if use_sampling: + any_sampling = True + sampling_params_list.append(LlaisysSamplingParams( + top_k=top_k, top_p=top_p, + temperature=temperature, seed=seed, + )) + prepared.append((context_id, messages, prompt_ids, sampling, max_new_tokens)) + + # If any request needs sampling, check for the packed-sampling API. + # Fall back to None (single-request path) when the new DLL is absent. + if any_sampling: + if not hasattr(self.model, "prefill_packed_sampling") or not hasattr(self.model, "step_packed_sampling"): + return None + + prompts = [it[2] for it in prepared] + generated_all: List[List[int]] = [[] for _ in prepared] + last_step_inputs: List[int] = [int(p[-1]) if p else 0 for p in prompts] + max_new_tokens_list = [int(it[4]) for it in prepared] + eos = self._eos_token() + with self._model_lock: + self.model.reset_kv_cache() + if any_sampling: + next_tokens = self.model.prefill_packed_sampling(prompts, sampling_params_list) + else: + next_tokens = self.model.prefill_packed(prompts) + if len(next_tokens) != len(prepared): + return None + for i, tok in enumerate(next_tokens): + t = int(tok) + if t >= 0: + generated_all[i].append(t) + last_step_inputs[i] = t + # Continue decode rounds for unfinished requests (dynamic shrinking). + while True: + active_indices: List[int] = [] + decode_inputs: List[List[int]] = [] + active_sp: List[LlaisysSamplingParams] = [] + for i in range(len(generated_all)): + gen = generated_all[i] + if not gen: + continue + if len(gen) >= max_new_tokens_list[i]: + continue + if eos >= 0 and gen[-1] == eos: + continue + active_indices.append(i) + decode_inputs.append([int(last_step_inputs[i])]) + active_sp.append(sampling_params_list[i]) + if not active_indices: + break + if any_sampling: + step_tokens = self.model.step_packed_sampling(decode_inputs, active_sp) + else: + step_tokens = self.model.step_packed(decode_inputs) + if len(step_tokens) != len(active_indices): + return None + for j, tok in enumerate(step_tokens): + ai = active_indices[j] + t = int(tok) + if t >= 0: + generated_all[ai].append(t) + last_step_inputs[ai] = t + + out: List[Dict[str, Any]] = [] + for i, (context_id, messages, prompt_ids, _sampling, _max_new_tokens) in enumerate(prepared): + generated_ids = list(generated_all[i]) + response_text = self._postprocess_text(self.tokenizer.decode(generated_ids)) + messages2 = list(messages) + messages2.append({"role": "assistant", "content": response_text}) + self._session_mgr.save_messages(context_id, messages2) + self._session_mgr.clear_stop(context_id) + usage = { + "prompt_tokens": len(prompt_ids), + "completion_tokens": len(generated_ids), + "total_tokens": len(prompt_ids) + len(generated_ids), + } + hit_limit = len(generated_ids) >= _max_new_tokens + finish_reason = "length" if hit_limit else "stop" + out.append(_wrap_completion(context_id, response_text, finish_reason, usage)) + return out + + # Backward-compatible alias used by scheduler tests/mocks. + def generate_packed_once(self, payloads: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]: + return self.generate_packed_non_stream(payloads) + + def generate(self, payload: Dict[str, Any]) -> Dict[str, Any]: + context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload) + cancel_event = self._session_mgr.get_cancel_event(context_id) + self._session_mgr.clear_stop(context_id) + + with self._model_lock: + acquire = self._kv_pool.acquire_context(context_id, prompt_ids) + self._kv_bridge.bind_for_request(context_id, prompt_ids, acquire.prefix_len) + generated_ids: List[int] = [] + try: + for token_id in self._iter_generate_ids( + prompt_ids=prompt_ids, + max_new_tokens=max_new_tokens, + sampling=sampling, + prefix_len=acquire.prefix_len, + cancel_event=cancel_event, + ): + generated_ids.append(int(token_id)) + cancelled = cancel_event.is_set() + if cancelled: + self._active_tokens = list(prompt_ids) + self._kv_pool.update_context(context_id, prompt_ids) + else: + self._kv_pool.update_context(context_id, self._active_tokens) + self._kv_bridge.export_after_request( + context_id, self._active_tokens, self._kv_pool.block_size + ) + except Exception: + self._kv_pool.release_context(context_id) + self._kv_bridge.release(context_id) + raise + + response_text = self._postprocess_text(self.tokenizer.decode(generated_ids)) + usage = { + "prompt_tokens": len(prompt_ids), + "completion_tokens": len(generated_ids), + "total_tokens": len(prompt_ids) + len(generated_ids), + } + if cancel_event.is_set(): + self._session_mgr.clear_stop(context_id) + return _wrap_completion(context_id, response_text, "stop", usage, stopped=True) + hit_limit = len(generated_ids) >= max_new_tokens + finish_reason = "length" if hit_limit else "stop" + messages = list(messages) + messages.append({"role": "assistant", "content": response_text}) + self._session_mgr.save_messages(context_id, messages) + self._session_mgr.clear_stop(context_id) + return _wrap_completion(context_id, response_text, finish_reason, usage) + + def stream(self, payload: Dict[str, Any]) -> Iterable[Dict[str, Any]]: + context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload) + cancel_event = self._session_mgr.get_cancel_event(context_id) + self._session_mgr.clear_stop(context_id) + + generated_ids: List[int] = [] + filtered = "" + with self._model_lock: + acquire = self._kv_pool.acquire_context(context_id, prompt_ids) + self._kv_bridge.bind_for_request(context_id, prompt_ids, acquire.prefix_len) + try: + for token_id in self._iter_generate_ids( + prompt_ids=prompt_ids, + max_new_tokens=max_new_tokens, + sampling=sampling, + prefix_len=acquire.prefix_len, + cancel_event=cancel_event, + ): + generated_ids.append(int(token_id)) + new_text = self.tokenizer.decode(generated_ids) + new_filtered = self._postprocess_text(new_text) + delta = new_filtered[len(filtered) :] + filtered = new_filtered + if delta: + yield _wrap_chunk(context_id, delta, None) + cancelled = cancel_event.is_set() + if cancelled: + self._active_tokens = list(prompt_ids) + self._kv_pool.update_context(context_id, prompt_ids) + else: + self._kv_pool.update_context(context_id, self._active_tokens) + self._kv_bridge.export_after_request( + context_id, self._active_tokens, self._kv_pool.block_size + ) + except Exception: + self._kv_pool.release_context(context_id) + self._kv_bridge.release(context_id) + raise + + if cancel_event.is_set(): + self._session_mgr.clear_stop(context_id) + usage = { + "prompt_tokens": len(prompt_ids), + "completion_tokens": len(generated_ids), + "total_tokens": len(prompt_ids) + len(generated_ids), + } + yield _wrap_chunk(context_id, None, "stop", usage=usage, stopped=True) + return + + messages = list(messages) + messages.append({"role": "assistant", "content": filtered}) + self._session_mgr.save_messages(context_id, messages) + self._session_mgr.clear_stop(context_id) + hit_limit = len(generated_ids) >= max_new_tokens + finish_reason = "length" if hit_limit else "stop" + usage = { + "prompt_tokens": len(prompt_ids), + "completion_tokens": len(generated_ids), + "total_tokens": len(prompt_ids) + len(generated_ids), + } + yield _wrap_chunk(context_id, None, finish_reason, usage=usage) + + # ------------------------------------------------------------------ + # Streaming batch API (Phase 1) + # ------------------------------------------------------------------ + + def prepare_batch(self, payloads: List[Dict[str, Any]]) -> Optional[BatchState]: + """Prefill all sequences in a batch, return BatchState or None to fall back.""" + if not payloads: + return None + if not hasattr(self.model, "prefill_packed") or not hasattr(self.model, "step_packed"): + return None + + sequences: List[BatchSequenceState] = [] + any_sampling = False + sampling_params_list: List[LlaisysSamplingParams] = [] + + for i, payload in enumerate(payloads): + # Edit-fork requests are not supported in batch path + if payload.get("edit_from_session_id"): + return None + try: + context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload) + except Exception: + return None + if max_new_tokens <= 0: + return None + + mode = str(sampling.get("mode", "")).strip().lower() + top_k = int(sampling.get("top_k", 1)) + top_p = float(sampling.get("top_p", 0.0)) + temperature = float(sampling.get("temperature", 0.0)) + seed = int(sampling.get("seed", 0)) + if mode == "argmax": + use_sampling = False + elif mode == "sample": + use_sampling = True + else: + use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0 + if use_sampling: + any_sampling = True + + sp = LlaisysSamplingParams(top_k=top_k, top_p=top_p, temperature=temperature, seed=seed) + cancel_event = self._session_mgr.get_cancel_event(context_id) + self._session_mgr.clear_stop(context_id) + + sequences.append(BatchSequenceState( + index=i, + context_id=context_id, + messages=messages, + prompt_ids=prompt_ids, + generated_ids=[], + filtered_text="", + max_new_tokens=max_new_tokens, + sampling=sampling, + sampling_params=sp, + use_sampling=use_sampling, + cancel_event=cancel_event, + finished=False, + finish_reason=None, + )) + sampling_params_list.append(sp) + + # Check for packed-sampling API if needed + if any_sampling: + if not hasattr(self.model, "prefill_packed_sampling") or not hasattr(self.model, "step_packed_sampling"): + return None + + prompts = [seq.prompt_ids for seq in sequences] + eos = self._eos_token() + + with self._model_lock: + self.model.reset_kv_cache() + if any_sampling: + next_tokens = self.model.prefill_packed_sampling(prompts, sampling_params_list) + else: + next_tokens = self.model.prefill_packed(prompts) + + if len(next_tokens) != len(sequences): + return None + + for i, tok in enumerate(next_tokens): + t = int(tok) + if t >= 0: + sequences[i].generated_ids.append(t) + # Decode and compute initial filtered text + new_text = self.tokenizer.decode(sequences[i].generated_ids) + sequences[i].filtered_text = self._postprocess_text(new_text) + else: + sequences[i].finished = True + sequences[i].finish_reason = "stop" + + # Check immediate termination + if not sequences[i].finished: + if eos >= 0 and t == eos: + sequences[i].finished = True + sequences[i].finish_reason = "stop" + elif len(sequences[i].generated_ids) >= sequences[i].max_new_tokens: + sequences[i].finished = True + sequences[i].finish_reason = "length" + elif sequences[i].cancel_event and sequences[i].cancel_event.is_set(): + sequences[i].finished = True + sequences[i].finish_reason = "stop" + + return BatchState(sequences=sequences, any_sampling=any_sampling, eos_token=eos) + + def step_batch(self, state: BatchState) -> List[StepResult]: + """Execute one decode step for all active sequences. Dynamic shrinking: skip finished.""" + results: List[StepResult] = [] + active_indices: List[int] = [] + decode_inputs: List[List[int]] = [] + sampling_params_active: List[LlaisysSamplingParams] = [] + + for i, seq in enumerate(state.sequences): + if seq.finished: + continue + if seq.cancel_event and seq.cancel_event.is_set(): + seq.finished = True + seq.finish_reason = "stop" + results.append(StepResult( + seq_index=i, delta_text="", finished=True, + finish_reason="stop", stopped=True, + )) + continue + active_indices.append(i) + last_tok = seq.generated_ids[-1] if seq.generated_ids else 0 + decode_inputs.append([last_tok]) + if seq.sampling_params is not None: + sampling_params_active.append(seq.sampling_params) + + if not active_indices: + return results + + with self._model_lock: + if state.any_sampling: + step_tokens = self.model.step_packed_sampling(decode_inputs, sampling_params_active) + else: + step_tokens = self.model.step_packed(decode_inputs) + + if len(step_tokens) != len(active_indices): + # Model returned unexpected count; mark all active as finished + for ai in active_indices: + seq = state.sequences[ai] + seq.finished = True + seq.finish_reason = "stop" + results.append(StepResult( + seq_index=ai, delta_text="", finished=True, + finish_reason="stop", stopped=False, + )) + return results + + for j, ai in enumerate(active_indices): + seq = state.sequences[ai] + t = int(step_tokens[j]) + + if t < 0: + seq.finished = True + seq.finish_reason = "stop" + results.append(StepResult( + seq_index=ai, delta_text="", finished=True, + finish_reason="stop", stopped=False, + )) + continue + + seq.generated_ids.append(t) + new_text = self.tokenizer.decode(seq.generated_ids) + new_filtered = self._postprocess_text(new_text) + delta = new_filtered[len(seq.filtered_text):] + seq.filtered_text = new_filtered + + # Check termination + finished = False + finish_reason = None + stopped = False + + if state.eos_token >= 0 and t == state.eos_token: + finished = True + finish_reason = "stop" + elif len(seq.generated_ids) >= seq.max_new_tokens: + finished = True + finish_reason = "length" + elif seq.cancel_event and seq.cancel_event.is_set(): + finished = True + finish_reason = "stop" + stopped = True + + if finished: + seq.finished = True + seq.finish_reason = finish_reason + + results.append(StepResult( + seq_index=ai, delta_text=delta, finished=finished, + finish_reason=finish_reason, stopped=stopped, + )) + + return results + + def finalize_sequence(self, state: BatchState, seq_index: int) -> None: + """Save session history and clean up for a completed sequence.""" + seq = state.sequences[seq_index] + if seq.cancel_event and seq.cancel_event.is_set(): + self._session_mgr.clear_stop(seq.context_id) + return + messages = list(seq.messages) + messages.append({"role": "assistant", "content": seq.filtered_text}) + self._session_mgr.save_messages(seq.context_id, messages) + self._session_mgr.clear_stop(seq.context_id) + + +class ChatHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + scheduler: InferenceScheduler + + def _set_cors_headers(self) -> None: + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + + def _send_json(self, code: int, payload: Dict[str, Any]) -> None: + data = json.dumps(payload, ensure_ascii=False).encode("utf-8") + self.send_response(code) + self.send_header("Content-Type", "application/json; charset=utf-8") + self._set_cors_headers() + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def _write_chunk(self, data: bytes) -> bool: + try: + self.wfile.write(f"{len(data):X}\r\n".encode("ascii")) + self.wfile.write(data) + self.wfile.write(b"\r\n") + self.wfile.flush() + return True + except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError): + return False + + def do_GET(self) -> None: + parsed = urlparse(self.path) + if parsed.path == "/health": + self._send_json(200, {"status": "ok"}) + return + if parsed.path == "/debug/kv": + query = parse_qs(parsed.query) + session_id = str((query.get("session_id") or [""])[0]).strip() or None + payload = self.scheduler.kv_debug_snapshot(session_id) + self._send_json(200, payload) + return + if parsed.path == "/debug/scheduler": + self._send_json(200, self.scheduler.debug_snapshot()) + return + self._send_json(404, _wrap_error("not found", "invalid_request_error", "not_found")) + + def do_OPTIONS(self) -> None: + self.send_response(204) + self._set_cors_headers() + self.send_header("Content-Length", "0") + self.end_headers() + + def do_POST(self) -> None: + if self.path not in ("/chat", "/v1/chat/completions", "/chat/stop"): + self._send_json(404, _wrap_error("not found", "invalid_request_error", "not_found")) + return + + length = int(self.headers.get("Content-Length", "0")) + body = self.rfile.read(length) if length > 0 else b"{}" + try: + payload = json.loads(body.decode("utf-8")) + except Exception: + self._send_json(400, _wrap_error("invalid JSON", "invalid_request_error", "invalid_json")) + return + + if self.path == "/chat/stop": + session_id = str(payload.get("session_id") or "").strip() + if not session_id: + self._send_json(400, _wrap_error("session_id is required", "invalid_request_error", "missing_field")) + return + self.scheduler.request_stop(session_id) + self._send_json(200, {"ok": True, "session_id": session_id}) + return + + stream = bool(payload.get("stream", False)) + if not stream: + try: + handle = self.scheduler.submit(payload, stream=False) + result = handle.get_result(timeout=self.scheduler.request_timeout_seconds()) + if isinstance(result, dict) and result.get("error"): + code = 504 if result.get("code") == "timeout" else 400 + err = result.get("error") + err_code = str(result.get("code", "")) or "server_error" + self._send_json(code, _wrap_error(str(err), "server_error", err_code)) + return + except SchedulerQueueFullError as exc: + self._send_json(429, _wrap_error(str(exc), "server_error", "queue_full")) + return + except TaskTimeoutError as exc: + self._send_json(504, _wrap_error(str(exc), "server_error", "timeout")) + return + except RuntimeError as exc: + self._send_json(400, _wrap_error(str(exc), "server_error")) + return + self._send_json(200, result) + return + + self.send_response(200) + self.send_header("Content-Type", "text/event-stream; charset=utf-8") + self.send_header("Cache-Control", "no-cache") + self.send_header("Connection", "keep-alive") + self.send_header("Transfer-Encoding", "chunked") + self._set_cors_headers() + self.end_headers() + + current_session_id = "" + try: + handle = self.scheduler.submit(payload, stream=True) + for item in handle.iter_stream(timeout=self.scheduler.request_timeout_seconds()): + current_session_id = str(item.get("session_id") or current_session_id) + data = json.dumps(item, ensure_ascii=False).encode("utf-8") + if not self._write_chunk(b"data: " + data + b"\n\n"): + if current_session_id: + self.scheduler.request_stop(current_session_id) + return + self._write_chunk(b"data: [DONE]\n\n") + except SchedulerQueueFullError as exc: + err = _wrap_error(str(exc), "server_error", "queue_full") + data = json.dumps(err, ensure_ascii=False).encode("utf-8") + self._write_chunk(b"data: " + data + b"\n\n") + except TaskTimeoutError as exc: + if current_session_id: + self.scheduler.request_stop(current_session_id) + err = _wrap_error(str(exc), "server_error", "timeout") + data = json.dumps(err, ensure_ascii=False).encode("utf-8") + self._write_chunk(b"data: " + data + b"\n\n") + except Exception as exc: + if current_session_id: + self.scheduler.request_stop(current_session_id) + err = _wrap_error(str(exc), "server_error") + data = json.dumps(err, ensure_ascii=False).encode("utf-8") + self._write_chunk(b"data: " + data + b"\n\n") + finally: + self._write_chunk(b"") + + +def _resolve_tokenizer_path(model_path: str, tokenizer_path: Optional[str]) -> str: + if tokenizer_path: + return tokenizer_path + path = Path(model_path) + sp = path / "tokenizer.model" + if sp.exists(): + return str(sp) + hf = path / "tokenizer.json" + if hf.exists(): + return str(hf) + raise FileNotFoundError(f"No tokenizer.model or tokenizer.json found under: {path}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=str, help="model directory") + parser.add_argument("--tokenizer", required=False, type=str, help="tokenizer file path") + parser.add_argument("--host", default="127.0.0.1", type=str) + parser.add_argument("--port", default=8000, type=int) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"]) + parser.add_argument("--pool-size", default=1, type=int, help="deprecated") + parser.add_argument( + "--kv-runtime-reuse", + action="store_true", + help="enable experimental runtime KV reuse fast-path", + ) + parser.add_argument("--kv-block-size", default=64, type=int, help="kv block token size") + parser.add_argument("--kv-max-blocks", default=4096, type=int, help="kv max block count") + parser.add_argument("--kv-max-bytes", default=268435456, type=int, help="kv max bytes") + parser.add_argument("--workers", default=1, type=int, help="inference worker count") + parser.add_argument("--queue-size", default=128, type=int, help="max queued tasks per worker") + parser.add_argument("--request-timeout-ms", default=120000, type=int, help="scheduler request timeout in milliseconds") + parser.add_argument( + "--continuous-batching", + action="store_true", + help="enable minimal iteration-level continuous scheduling", + ) + parser.add_argument( + "--kv-aware-routing", + action="store_true", + help="enable KV-aware worker routing (query KV pool before dispatching)", + ) + parser.add_argument( + "--max-batch-size", + default=8, + type=int, + help="max sequences per streaming batch (default 8)", + ) + parser.add_argument( + "--shared-model", + action="store_true", + help="share a single model instance and KV pool across all workers", + ) + parser.add_argument( + "--kv-memory-threshold", + default=0.0, + type=float, + help="KV memory pressure threshold (0.0=disabled, 0.85=recommended)", + ) + args = parser.parse_args() + + tokenizer_path = _resolve_tokenizer_path(args.model, args.tokenizer) + worker_count = max(1, int(args.workers)) + services: List[ChatService] = [] + if args.shared_model: + # Shared mode: one model, one tokenizer, one KV pool, one KV bridge, one lock + model = Qwen2( + args.model, + llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA, + ) + tokenizer = llaisys.Tokenizer(tokenizer_path) + shared_lock = threading.RLock() + shared_kv_pool = KVCachePool( + block_size=args.kv_block_size, + max_blocks=args.kv_max_blocks, + max_bytes=args.kv_max_bytes, + ) + shared_kv_bridge = KVRuntimeBridge(model, enabled=args.kv_runtime_reuse) + for _ in range(worker_count): + services.append( + ChatService( + model, + tokenizer, + model_path=args.model, + enable_kv_runtime_reuse=args.kv_runtime_reuse, + block_size=args.kv_block_size, + max_blocks=args.kv_max_blocks, + max_bytes=args.kv_max_bytes, + model_lock=shared_lock, + kv_pool=shared_kv_pool, + kv_bridge=shared_kv_bridge, + ) + ) + else: + for _ in range(worker_count): + tokenizer = llaisys.Tokenizer(tokenizer_path) + model = Qwen2( + args.model, + llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA, + ) + services.append( + ChatService( + model, + tokenizer, + model_path=args.model, + enable_kv_runtime_reuse=args.kv_runtime_reuse, + block_size=args.kv_block_size, + max_blocks=args.kv_max_blocks, + max_bytes=args.kv_max_bytes, + ) + ) + scheduler = InferenceScheduler( + services, + queue_size=max(1, int(args.queue_size)), + request_timeout_ms=max(0, int(args.request_timeout_ms)), + continuous_batching=bool(args.continuous_batching), + kv_aware_routing=bool(args.kv_aware_routing), + max_batch_size=max(1, int(args.max_batch_size)), + kv_memory_threshold=float(args.kv_memory_threshold), + ) + scheduler.start() + + handler = ChatHandler + handler.scheduler = scheduler + server = ThreadingHTTPServer((args.host, args.port), handler) + server.daemon_threads = True + kv_routing_str = ", kv_aware_routing=on" if args.kv_aware_routing else "" + shared_str = ", shared_model=on" if args.shared_model else "" + kv_mem_str = f", kv_memory_threshold={args.kv_memory_threshold}" if args.kv_memory_threshold > 0 else "" + print( + f"LLAISYS chat server listening on http://{args.host}:{args.port} " + f"(workers={worker_count}, queue_size={max(1, int(args.queue_size))}{kv_routing_str}{shared_str}{kv_mem_str})" + ) + try: + server.serve_forever() + finally: + scheduler.stop() + + +if __name__ == "__main__": + main() diff --git a/python/llaisys/session_manager.py b/python/llaisys/session_manager.py new file mode 100644 index 000000000..c37aa8fbf --- /dev/null +++ b/python/llaisys/session_manager.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import threading +import uuid +from typing import Any, Dict, List, Tuple + + +class SessionManager: + """Session message history and cancellation event management.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._context_messages: Dict[str, List[Dict[str, str]]] = {} + self._cancel_events: Dict[str, threading.Event] = {} + + def extract_messages(self, payload: Dict[str, Any]) -> Tuple[str, List[Dict[str, str]]]: + """Extract context_id and message list from payload. + + Handles three input modes: + - edit_from_session_id: branch from history edit + - messages: direct message list + - prompt: append to existing session history + """ + context_id = str(payload.get("session_id") or "").strip() or str(uuid.uuid4()) + messages = payload.get("messages") + prompt = payload.get("prompt") + edit_from = str(payload.get("edit_from_session_id") or "").strip() + edit_index_raw = payload.get("edit_message_index") + + if edit_from: + with self._lock: + source = list(self._context_messages.get(edit_from, [])) + if not source: + raise ValueError("edit_from_session_id not found") + if prompt is None: + raise ValueError("prompt is required when editing history") + if edit_index_raw is None: + raise ValueError("edit_message_index is required when editing history") + edit_index = int(edit_index_raw) + if edit_index < 0 or edit_index >= len(source): + raise ValueError("edit_message_index out of range") + if source[edit_index].get("role") != "user": + raise ValueError("edit_message_index must point to a user message") + branched = source[: edit_index + 1] + branched[edit_index] = {"role": "user", "content": str(prompt)} + if not str(payload.get("session_id") or "").strip(): + context_id = str(uuid.uuid4()) + return context_id, branched + + if messages is not None: + if not isinstance(messages, list): + raise ValueError("messages must be a list") + return context_id, list(messages) + + if prompt is None: + raise ValueError("payload must include messages or prompt") + + with self._lock: + history = list(self._context_messages.get(context_id, [])) + history.append({"role": "user", "content": str(prompt)}) + return context_id, history + + def save_messages(self, context_id: str, messages: List[Dict[str, str]]) -> None: + """Save session message history.""" + with self._lock: + self._context_messages[context_id] = list(messages) + + def get_messages(self, context_id: str) -> List[Dict[str, str]]: + """Get session message history (returns a copy).""" + with self._lock: + return list(self._context_messages.get(context_id, [])) + + def get_cancel_event(self, context_id: str) -> threading.Event: + """Get or create a cancellation event for the given context.""" + with self._lock: + event = self._cancel_events.get(context_id) + if event is None: + event = threading.Event() + self._cancel_events[context_id] = event + return event + + def request_stop(self, context_id: str) -> bool: + """Set the cancellation event for the given context.""" + with self._lock: + event = self._cancel_events.get(context_id) + if event is None: + event = threading.Event() + self._cancel_events[context_id] = event + event.set() + return True + + def clear_stop(self, context_id: str) -> None: + """Clear the cancellation event for the given context.""" + with self._lock: + event = self._cancel_events.get(context_id) + if event: + event.clear() diff --git a/python/llaisys/tensor_parallel.py b/python/llaisys/tensor_parallel.py new file mode 100644 index 000000000..f21029dcd --- /dev/null +++ b/python/llaisys/tensor_parallel.py @@ -0,0 +1,64 @@ +"""Tensor parallel weight splitting for Qwen2 models (Megatron-style).""" + +import numpy as np + + +def split_column(tensor: np.ndarray, rank: int, world_size: int) -> np.ndarray: + """Split tensor along dim 0 (output features). For Q/K/V/gate/up weights and biases.""" + chunk = tensor.shape[0] // world_size + return tensor[rank * chunk : (rank + 1) * chunk].copy() + + +def split_row(tensor: np.ndarray, rank: int, world_size: int) -> np.ndarray: + """Split tensor along dim 1 (input features). For attn_o/down weights.""" + chunk = tensor.shape[1] // world_size + return tensor[:, rank * chunk : (rank + 1) * chunk].copy() + + +# Weight name patterns that get column-split (dim 0) +_COLUMN_SPLIT = { + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + "self_attn.q_proj.bias", + "self_attn.k_proj.bias", + "self_attn.v_proj.bias", + "mlp.gate_proj.weight", + "mlp.up_proj.weight", +} + +# Weight name patterns that get row-split (dim 1) +_ROW_SPLIT = { + "self_attn.o_proj.weight", + "mlp.down_proj.weight", +} + + +def shard_qwen2_weights( + weights_dict: dict[str, np.ndarray], rank: int, world_size: int +) -> dict[str, np.ndarray]: + """Shard Qwen2 model weights for tensor parallelism. + + Megatron-style: column split Q/K/V/gate/up, row split attn_o/down. + Replicate: embeddings, norms, everything else. + """ + if world_size <= 1: + return weights_dict + + out = {} + for name, tensor in weights_dict.items(): + # Extract the sub-key for layer weights (e.g. "self_attn.q_proj.weight") + sub = None + if name.startswith("model.layers."): + parts = name.split(".") + if len(parts) >= 4: + sub = ".".join(parts[3:]) + + if sub in _COLUMN_SPLIT: + out[name] = split_column(tensor, rank, world_size) + elif sub in _ROW_SPLIT: + out[name] = split_row(tensor, rank, world_size) + else: + # Replicate: embeddings, norms, lm_head + out[name] = tensor + return out diff --git a/python/llaisys/tokenizer.py b/python/llaisys/tokenizer.py new file mode 100644 index 000000000..12053bd12 --- /dev/null +++ b/python/llaisys/tokenizer.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from ctypes import POINTER, c_char_p, c_int64, c_size_t, create_string_buffer +from pathlib import Path +from typing import Iterable, List, Optional + +from .libllaisys import LIB_LLAISYS, LlaisysTokenizer + + +class Tokenizer: + def __init__(self, model_path: str): + self._backend: str = "sentencepiece" + self._tokenizer: Optional[LlaisysTokenizer] = None + self._hf_tokenizer = None + + tokenizer_path = self._resolve_tokenizer_path(model_path) + if tokenizer_path.suffix.lower() == ".json": + self._backend = "hf" + self._hf_tokenizer = self._load_hf_tokenizer(tokenizer_path) + else: + self._tokenizer = LIB_LLAISYS.llaisysTokenizerCreateSentencePiece( + c_char_p(str(tokenizer_path).encode("utf-8")) + ) + if not self._tokenizer: + raise RuntimeError("llaisysTokenizerCreateSentencePiece failed") + + def encode(self, text: str) -> List[int]: + if self._backend == "hf": + return list(self._hf_tokenizer.encode(text).ids) + data = text.encode("utf-8") + n = int( + LIB_LLAISYS.llaisysTokenizerEncode( + self._tokenizer, c_char_p(data), None, c_size_t(0) + ) + ) + if n < 0: + raise RuntimeError("llaisysTokenizerEncode failed") + if n == 0: + return [] + out_ids = (c_int64 * n)() + written = int( + LIB_LLAISYS.llaisysTokenizerEncode( + self._tokenizer, c_char_p(data), out_ids, c_size_t(n) + ) + ) + if written < 0: + raise RuntimeError("llaisysTokenizerEncode failed") + return [int(out_ids[i]) for i in range(written)] + + def decode(self, ids: Iterable[int]) -> str: + ids_list = list(ids) + n = len(ids_list) + if n == 0: + return "" + if self._backend == "hf": + return self._hf_tokenizer.decode(ids_list, skip_special_tokens=False) + buf = (c_int64 * n)(*ids_list) + max_len = int( + LIB_LLAISYS.llaisysTokenizerDecode( + self._tokenizer, buf, c_size_t(n), None, c_size_t(0) + ) + ) + if max_len < 0: + raise RuntimeError("llaisysTokenizerDecode failed") + out = create_string_buffer(max_len) + written = int( + LIB_LLAISYS.llaisysTokenizerDecode( + self._tokenizer, buf, c_size_t(n), out, c_size_t(max_len) + ) + ) + if written < 0: + raise RuntimeError("llaisysTokenizerDecode failed") + return out.value.decode("utf-8") + + def close(self) -> None: + if self._tokenizer: + LIB_LLAISYS.llaisysTokenizerDestroy(self._tokenizer) + self._tokenizer = None + + def __del__(self) -> None: + self.close() + + @staticmethod + def _resolve_tokenizer_path(model_path: str) -> Path: + path = Path(model_path) + if path.is_dir(): + sp = path / "tokenizer.model" + if sp.exists(): + return sp + hf = path / "tokenizer.json" + if hf.exists(): + return hf + raise FileNotFoundError( + f"No tokenizer.model or tokenizer.json found under: {path}" + ) + if not path.exists(): + raise FileNotFoundError(f"Tokenizer file not found: {path}") + return path + + @staticmethod + def _load_hf_tokenizer(path: Path): + try: + from tokenizers import Tokenizer as HFTokenizer + except Exception as exc: + raise RuntimeError( + "tokenizer.json requires the 'tokenizers' package. " + "Install with: pip install tokenizers" + ) from exc + return HFTokenizer.from_file(str(path)) + + +__all__ = ["Tokenizer"] diff --git a/scripts/_tp_worker.py b/scripts/_tp_worker.py new file mode 100644 index 000000000..40365f74b --- /dev/null +++ b/scripts/_tp_worker.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +"""TP worker process -- spawned by launch_tp.py. + +Reads env vars: RANK, WORLD_SIZE, CUDA_VISIBLE_DEVICES, TP_UID_FILE, +TP_MODEL_PATH, TP_DEVICE, TP_PROMPT, TP_MAX_TOKENS. +""" + +import os +import sys +import ctypes +import time +from pathlib import Path + +_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, _project_root) + +import numpy as np +import safetensors + +from llaisys.libllaisys import ( + LIB_LLAISYS, + LlaisysCommAPI, + llaisysComm_t, + LLAISYS_COMM_UNIQUE_ID_MAX_SIZE, + DeviceType, + DataType, + LlaisysQwen2Meta, + llaisysDeviceType_t, + llaisysDataType_t, + LlaisysSamplingParams, +) +from llaisys.tensor_parallel import shard_qwen2_weights + + +def main(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + uid_file = os.environ["TP_UID_FILE"] + model_path = Path(os.environ["TP_MODEL_PATH"]) + device_name = os.environ.get("TP_DEVICE", "nvidia") + prompt = os.environ.get("TP_PROMPT", "Hello") + max_tokens = int(os.environ.get("TP_MAX_TOKENS", "64")) + + device = DeviceType.NVIDIA if device_name == "nvidia" else DeviceType.ILUVATAR + backend = 0 # NCCL + + # Read unique ID + for _ in range(100): + if os.path.exists(uid_file) and os.path.getsize(uid_file) > 0: + break + time.sleep(0.1) + with open(uid_file, "rb") as f: + uid_bytes = f.read() + + # Init comm + api_ptr = LIB_LLAISYS.llaisysGetCommAPI(backend) + api = api_ptr.contents + comm = llaisysComm_t() + uid_buf = ctypes.create_string_buffer(uid_bytes, LLAISYS_COMM_UNIQUE_ID_MAX_SIZE) + ret = api.init(ctypes.byref(comm), rank, world_size, uid_buf) + if ret != 0: + raise RuntimeError(f"commInit failed: {ret}") + + # Load tokenizer (use transformers for HF tokenizer.json) + from transformers import AutoTokenizer + tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + + # Tokenize prompt + import json + config_path = model_path / "config.json" + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + + # Load and shard weights + weights = {} + for file in sorted(model_path.glob("*.safetensors")): + import torch + data_ = safetensors.safe_open(file, framework="pt", device="cpu") + for name in data_.keys(): + arr = data_.get_tensor(name) + if arr.dtype == torch.bfloat16: + arr = arr.to(torch.float16) + weights[name] = arr.cpu().numpy() + + weights = shard_qwen2_weights(weights, rank, world_size) + + # Build model meta + torch_dtype = str(cfg.get("torch_dtype", "bfloat16")).lower() + dtype = DataType.F16 # we convert bf16->f16 above + nlayer = int(cfg.get("num_hidden_layers", 0)) + hs = int(cfg.get("hidden_size", 0)) + nh = int(cfg.get("num_attention_heads", 0)) + nkvh = int(cfg.get("num_key_value_heads", nh)) + di = int(cfg.get("intermediate_size", 0)) + maxseq = int(cfg.get("max_position_embeddings", 0)) + voc = int(cfg.get("vocab_size", 0)) + epsilon = float(cfg.get("rms_norm_eps", 1e-6)) + theta = float(cfg.get("rope_theta", 10000.0)) + eos = cfg.get("eos_token_id", -1) + end_token = int(eos[0]) if isinstance(eos, list) else int(eos) + dh = int(cfg.get("head_dim", hs // nh if nh else 0)) + + # Adjust nh/nkvh for TP + tp_nh = nh // world_size + tp_nkvh = nkvh // world_size + + model_meta = LlaisysQwen2Meta( + llaisysDataType_t(dtype), + ctypes.c_size_t(nlayer), + ctypes.c_size_t(hs), + ctypes.c_size_t(tp_nh), + ctypes.c_size_t(tp_nkvh), + ctypes.c_size_t(dh), + ctypes.c_size_t(di // world_size), + ctypes.c_size_t(maxseq), + ctypes.c_size_t(voc), + ctypes.c_float(epsilon), + ctypes.c_float(theta), + ctypes.c_int64(end_token), + ) + + device_ids = (ctypes.c_int * 1)(0) + model = LIB_LLAISYS.llaisysQwen2ModelCreate( + ctypes.byref(model_meta), llaisysDeviceType_t(device), device_ids, 1 + ) + if not model: + raise RuntimeError("llaisysQwen2ModelCreate failed") + + LIB_LLAISYS.llaisysQwen2ModelSetKVCacheEnabled(model, ctypes.c_int(1)) + LIB_LLAISYS.llaisysQwen2ModelSetTensorParallel(model, comm, ctypes.c_void_p(None), world_size) + model_weights = LIB_LLAISYS.llaisysQwen2ModelWeights(model) + + # Upload sharded weights + def upload_tensor(arr): + arr = np.ascontiguousarray(arr) + shape = (ctypes.c_size_t * arr.ndim)(*arr.shape) + dt = DataType.F16 if "float16" in arr.dtype.name else DataType.F32 + tensor = LIB_LLAISYS.tensorCreate( + shape, ctypes.c_size_t(arr.ndim), + llaisysDataType_t(dt), llaisysDeviceType_t(device), ctypes.c_int(0), + ) + LIB_LLAISYS.tensorLoad(tensor, ctypes.c_void_p(arr.ctypes.data)) + return tensor + + w = model_weights.contents + for name, arr in weights.items(): + tensor = upload_tensor(arr) + if name in {"model.embed_tokens.weight", "transformer.wte.weight"}: + w.in_embed = tensor + elif name in {"lm_head.weight", "model.lm_head.weight"}: + w.out_embed = tensor + elif name in {"model.norm.weight", "transformer.ln_f.weight"}: + w.out_norm_w = tensor + elif name.startswith("model.layers."): + parts = name.split(".") + if len(parts) < 4: + continue + layer = int(parts[2]) + sub = ".".join(parts[3:]) + if sub == "input_layernorm.weight": + w.attn_norm_w[layer] = tensor + elif sub == "self_attn.q_proj.weight": + w.attn_q_w[layer] = tensor + elif sub == "self_attn.q_proj.bias": + w.attn_q_b[layer] = tensor + elif sub == "self_attn.k_proj.weight": + w.attn_k_w[layer] = tensor + elif sub == "self_attn.k_proj.bias": + w.attn_k_b[layer] = tensor + elif sub == "self_attn.v_proj.weight": + w.attn_v_w[layer] = tensor + elif sub == "self_attn.v_proj.bias": + w.attn_v_b[layer] = tensor + elif sub == "self_attn.o_proj.weight": + w.attn_o_w[layer] = tensor + elif sub == "post_attention_layernorm.weight": + w.mlp_norm_w[layer] = tensor + elif sub == "mlp.gate_proj.weight": + w.mlp_gate_w[layer] = tensor + elif sub == "mlp.up_proj.weight": + w.mlp_up_w[layer] = tensor + elif sub == "mlp.down_proj.weight": + w.mlp_down_w[layer] = tensor + + if not w.out_embed and w.in_embed: + w.out_embed = w.in_embed + + # Tokenize and run inference + input_ids = tok.encode(prompt, add_special_tokens=True) + + # Prefill + decode + token_buf = (ctypes.c_int64 * len(input_ids))(*input_ids) + params = LlaisysSamplingParams(ctypes.c_int(1), ctypes.c_float(0.0), ctypes.c_float(0.0), ctypes.c_uint32(0)) + next_token = int(LIB_LLAISYS.llaisysQwen2ModelPrefillSampling( + model, token_buf, ctypes.c_size_t(len(input_ids)), ctypes.byref(params), + )) + + generated = list(input_ids) + for step in range(max_tokens): + if next_token < 0 or next_token == end_token: + break + generated.append(next_token) + tb = (ctypes.c_int64 * 1)(next_token) + next_token = int(LIB_LLAISYS.llaisysQwen2ModelStepSampling( + model, tb, ctypes.c_size_t(1), ctypes.byref(params), + )) + + # Decode and print from rank 0 + if rank == 0: + output_text = tok.decode(generated, skip_special_tokens=True) + print(output_text) + + # Cleanup + LIB_LLAISYS.llaisysQwen2ModelDestroy(model) + api.destroy(comm) + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_chat_scheduler.py b/scripts/benchmark_chat_scheduler.py new file mode 100644 index 000000000..d9d8af43e --- /dev/null +++ b/scripts/benchmark_chat_scheduler.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed +import json +import math +import statistics +import time +import urllib.error +import urllib.request +import uuid +from typing import Any, Dict, List, Optional, Tuple + + +def _post_json(url: str, payload: Dict[str, Any], timeout: float) -> Tuple[int, Dict[str, Any], str]: + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + req = urllib.request.Request( + url, + data=body, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urllib.request.urlopen(req, timeout=timeout) as resp: + text = resp.read().decode("utf-8", errors="replace") + code = int(resp.status) + data = json.loads(text) if text else {} + return code, data, "" + except urllib.error.HTTPError as exc: + text = exc.read().decode("utf-8", errors="replace") + data = {} + try: + data = json.loads(text) if text else {} + except Exception: + pass + return int(exc.code), data, text or str(exc) + except Exception as exc: + return -1, {}, str(exc) + + +def _get_json(url: str, timeout: float) -> Dict[str, Any]: + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=timeout) as resp: + text = resp.read().decode("utf-8", errors="replace") + return json.loads(text) if text else {} + + +def _percentile(sorted_values: List[float], p: float) -> float: + if not sorted_values: + return 0.0 + if len(sorted_values) == 1: + return float(sorted_values[0]) + rank = p * (len(sorted_values) - 1) + low = int(math.floor(rank)) + high = int(math.ceil(rank)) + if low == high: + return float(sorted_values[low]) + w = rank - low + return float(sorted_values[low] * (1 - w) + sorted_values[high] * w) + + +def run_benchmark(args: argparse.Namespace) -> int: + endpoint = args.endpoint.rstrip("/") + chat_url = f"{endpoint}/chat" + scheduler_url = f"{endpoint}/debug/scheduler" + health_url = f"{endpoint}/health" + + try: + health = _get_json(health_url, timeout=args.timeout) + except Exception as exc: + print(f"[ERROR] health check failed: {exc}") + return 2 + + print(f"[INFO] health: {health}") + before_debug: Dict[str, Any] = {} + try: + before_debug = _get_json(scheduler_url, timeout=args.timeout) + except Exception: + before_debug = {} + + if args.warmup > 0: + print(f"[INFO] warmup requests: {args.warmup}") + for i in range(args.warmup): + payload: Dict[str, Any] = { + "prompt": f"{args.prompt} [warmup-{i}]", + "stream": False, + "max_new_tokens": args.max_new_tokens, + } + _post_json(chat_url, payload, timeout=args.timeout) + + total = int(args.total_requests) + concurrency = int(args.concurrency) + print(f"[INFO] start benchmark: total={total}, concurrency={concurrency}, endpoint={chat_url}") + + t0 = time.perf_counter() + latencies_ms: List[float] = [] + errors: List[str] = [] + status_count: Dict[int, int] = {} + + def _one_request(i: int) -> Tuple[float, int, str]: + payload: Dict[str, Any] = { + "prompt": f"{args.prompt} [req-{i}]", + "stream": False, + "max_new_tokens": args.max_new_tokens, + } + if args.session_mode == "shared": + payload["session_id"] = args.shared_session_id + elif args.session_mode == "unique": + payload["session_id"] = f"{args.session_prefix}-{uuid.uuid4()}" + if args.sampling: + payload["sampling"] = args.sampling + if args.temperature is not None: + payload["temperature"] = args.temperature + if args.top_k is not None: + payload["top_k"] = args.top_k + if args.top_p is not None: + payload["top_p"] = args.top_p + + s = time.perf_counter() + code, data, err = _post_json(chat_url, payload, timeout=args.timeout) + elapsed_ms = (time.perf_counter() - s) * 1000.0 + if code == 200 and not data.get("error"): + return elapsed_ms, code, "" + detail = err or str(data.get("error") or f"HTTP {code}") + return elapsed_ms, code, detail + + with ThreadPoolExecutor(max_workers=concurrency) as ex: + futures = [ex.submit(_one_request, i) for i in range(total)] + for fut in as_completed(futures): + elapsed_ms, code, detail = fut.result() + status_count[code] = status_count.get(code, 0) + 1 + if code == 200 and not detail: + latencies_ms.append(elapsed_ms) + else: + errors.append(f"[{code}] {detail}") + + total_elapsed_s = max(1e-9, time.perf_counter() - t0) + success = len(latencies_ms) + failed = len(errors) + throughput = total / total_elapsed_s + + latencies_sorted = sorted(latencies_ms) + p50 = _percentile(latencies_sorted, 0.50) + p95 = _percentile(latencies_sorted, 0.95) + p99 = _percentile(latencies_sorted, 0.99) + avg = statistics.mean(latencies_ms) if latencies_ms else 0.0 + + after_debug: Dict[str, Any] = {} + try: + after_debug = _get_json(scheduler_url, timeout=args.timeout) + except Exception: + after_debug = {} + + print("\n=== Benchmark Summary ===") + print(f"success: {success}/{total} ({(success / total) * 100:.1f}%)") + print(f"failed: {failed}") + print(f"elapsed_s: {total_elapsed_s:.3f}") + print(f"throughput_rps: {throughput:.2f}") + print(f"latency_ms: avg={avg:.1f}, p50={p50:.1f}, p95={p95:.1f}, p99={p99:.1f}") + print(f"status_count: {status_count}") + + if before_debug: + print("\n=== /debug/scheduler (before) ===") + print(json.dumps(before_debug, ensure_ascii=False, indent=2)) + if after_debug: + print("\n=== /debug/scheduler (after) ===") + print(json.dumps(after_debug, ensure_ascii=False, indent=2)) + + if errors: + print("\n=== Sample Errors (up to 10) ===") + for line in errors[:10]: + print(line) + return 0 if success > 0 else 1 + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="LLAISYS scheduler benchmark (non-stream chat requests)") + p.add_argument("--endpoint", default="http://127.0.0.1:8000", type=str) + p.add_argument("--total-requests", default=20, type=int) + p.add_argument("--concurrency", default=5, type=int) + p.add_argument("--prompt", default="请用一句话介绍北京", type=str) + p.add_argument("--max-new-tokens", default=32, type=int) + p.add_argument("--timeout", default=60.0, type=float, help="per-request timeout in seconds") + p.add_argument("--warmup", default=1, type=int) + p.add_argument("--session-mode", choices=["none", "shared", "unique"], default="none") + p.add_argument("--shared-session-id", default="bench-shared-session", type=str) + p.add_argument("--session-prefix", default="bench-session", type=str) + p.add_argument("--sampling", default="", type=str) + p.add_argument("--temperature", default=None, type=float) + p.add_argument("--top-k", default=None, type=int) + p.add_argument("--top-p", default=None, type=float) + return p + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + raise SystemExit(run_benchmark(args)) + + +if __name__ == "__main__": + main() diff --git a/scripts/launch_tp.py b/scripts/launch_tp.py new file mode 100644 index 000000000..a7ccf2788 --- /dev/null +++ b/scripts/launch_tp.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""Tensor-parallel multi-process launcher for llaisys inference. + +Rank 0 generates a NCCL unique ID, writes it to a temp file, then spawns +N subprocesses (one per GPU). Each subprocess loads sharded weights, inits +the communicator with the shared unique ID, and runs inference. Output is +printed from rank 0. + +Usage: + python scripts/launch_tp.py --model /path/to/qwen2 --nranks 2 --prompt "Hello" +""" + +import argparse +import os +import sys +import subprocess +import tempfile +import ctypes + +# Ensure project root is on sys.path +_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, _project_root) + + +def generate_unique_id(backend=0): + """Generate NCCL unique ID via llaisysCommGenerateUniqueId.""" + from llaisys.libllaisys import LIB_LLAISYS, LLAISYS_COMM_UNIQUE_ID_MAX_SIZE + id_buf = ctypes.create_string_buffer(LLAISYS_COMM_UNIQUE_ID_MAX_SIZE) + id_size = ctypes.c_size_t(0) + ret = LIB_LLAISYS.llaisysCommGenerateUniqueId( + backend, id_buf, ctypes.byref(id_size) + ) + if ret != 0: + raise RuntimeError(f"llaisysCommGenerateUniqueId failed: {ret}") + return id_buf.raw[: id_size.value] + + +def main(): + parser = argparse.ArgumentParser(description="Tensor-parallel launcher") + parser.add_argument("--model", required=True, help="Path to model directory") + parser.add_argument("--nranks", type=int, default=2, help="Number of TP ranks") + parser.add_argument("--device", default="nvidia", choices=["nvidia", "iluvatar"]) + parser.add_argument("--prompt", default="Hello", help="Input prompt") + parser.add_argument("--max-tokens", type=int, default=64) + args = parser.parse_args() + + # Generate unique ID on rank 0 process + uid_bytes = generate_unique_id() + + # Write unique ID to temp file for subprocesses + tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".uid") + tmp.write(uid_bytes) + tmp.close() + uid_path = tmp.name + + worker = os.path.join(os.path.dirname(__file__), "_tp_worker.py") + + procs = [] + for rank in range(args.nranks): + env = os.environ.copy() + env["RANK"] = str(rank) + env["WORLD_SIZE"] = str(args.nranks) + env["CUDA_VISIBLE_DEVICES"] = str(rank) + env["TP_UID_FILE"] = uid_path + env["TP_MODEL_PATH"] = args.model + env["TP_DEVICE"] = args.device + env["TP_PROMPT"] = args.prompt + env["TP_MAX_TOKENS"] = str(args.max_tokens) + + proc = subprocess.Popen( + [sys.executable, worker], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + procs.append((rank, proc)) + + # Wait for all and collect output + for rank, proc in procs: + stdout, stderr = proc.communicate() + if proc.returncode != 0: + print(f"[rank {rank}] FAILED (exit {proc.returncode})", file=sys.stderr) + if stderr: + print(stderr.decode(errors="replace"), file=sys.stderr) + elif rank == 0: + print(stdout.decode(errors="replace"), end="") + + # Cleanup + try: + os.unlink(uid_path) + except OSError: + pass + + +if __name__ == "__main__": + main() diff --git a/scripts/run_gpu.ps1 b/scripts/run_gpu.ps1 new file mode 100644 index 000000000..150f84486 --- /dev/null +++ b/scripts/run_gpu.ps1 @@ -0,0 +1,130 @@ +param( + [ValidateSet("build", "test", "server", "all")] + [string]$Mode = "all", + [string]$Model = "", + [string]$Device = "nvidia", + [string]$CondaEnv = "llaisys-gpu", + [string]$ConfigPath = "", + [switch]$SkipTests, + [switch]$ActivateConda +) + +$ErrorActionPreference = "Stop" + +function Write-Step([string]$Message) { + Write-Host "==> $Message" +} + +$PythonExe = "python" +function Resolve-PythonExe { + if ($env:CONDA_PREFIX) { + $candidate = Join-Path $env:CONDA_PREFIX "python.exe" + if (Test-Path $candidate) { + return $candidate + } + } + return "python" +} + +$RepoRoot = Resolve-Path (Join-Path $PSScriptRoot "..") +Set-Location $RepoRoot + +if ([string]::IsNullOrWhiteSpace($ConfigPath)) { + $ConfigPath = Join-Path $RepoRoot "scripts\run_gpu.config.json" +} + +$Config = $null +if (Test-Path $ConfigPath) { + try { + $Config = Get-Content $ConfigPath -Raw | ConvertFrom-Json + } catch { + throw "Failed to read config file: $ConfigPath" + } +} + +if ($Config -ne $null) { + if (-not $PSBoundParameters.ContainsKey("Model") -and $Config.model) { + $Model = $Config.model + } + if (-not $PSBoundParameters.ContainsKey("Device") -and $Config.device) { + $Device = $Config.device + } + if (-not $PSBoundParameters.ContainsKey("CondaEnv") -and $Config.conda_env) { + $CondaEnv = $Config.conda_env + } +} + +if ($ActivateConda) { + if (Get-Command conda -ErrorAction SilentlyContinue) { + Write-Step "Activating conda env: $CondaEnv" + conda activate $CondaEnv + } else { + throw "conda is not available in this shell. Run 'conda init powershell' and reopen PowerShell." + } +} +$PythonExe = Resolve-PythonExe + +function Build-Gpu { + Write-Step "Configuring xmake" + xmake f -m release --nv-gpu=y --vs=2022 + + Write-Step "Building" + xmake + + $dllSrc = Join-Path $RepoRoot "build\windows\x64\release\llaisys.dll" + $dllDst = Join-Path $RepoRoot "python\llaisys\libllaisys\llaisys.dll" + if (!(Test-Path $dllSrc)) { + throw "Build output not found: $dllSrc" + } + + Write-Step "Copying DLL to python package" + Copy-Item $dllSrc $dllDst -Force +} + +function Ensure-Dll { + $dllDst = Join-Path $RepoRoot "python\llaisys\libllaisys\llaisys.dll" + if (Test-Path $dllDst) { + return + } + $dllCandidates = @( + (Join-Path $RepoRoot "bin\llaisys.dll"), + (Join-Path $RepoRoot "build\windows\x64\release\llaisys.dll") + ) + foreach ($dllSrc in $dllCandidates) { + if (Test-Path $dllSrc) { + Write-Step "Copying DLL to python package" + Copy-Item $dllSrc $dllDst -Force + return + } + } + throw "Missing llaisys.dll. Run '-Mode build' or copy it to: $dllDst" +} + +function Test-Gpu { + Ensure-Dll + Write-Step "Running GPU op tests" + & $PythonExe test/ops_gpu/run_all.py +} + +function Run-Server { + if ([string]::IsNullOrWhiteSpace($Model)) { + throw "Model path is required. Provide -Model or set 'model' in $ConfigPath" + } + Ensure-Dll + Write-Step "Starting server on $Device" + & $PythonExe -m llaisys.server --model $Model --device $Device +} + +switch ($Mode) { + "build" { Build-Gpu } + "test" { Test-Gpu } + "server" { Run-Server } + "all" { + Build-Gpu + if (-not $SkipTests) { + Test-Gpu + } + Run-Server + } +} + diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp index 44894b9e7..cbcf1dc6b 100644 --- a/src/core/context/context.cpp +++ b/src/core/context/context.cpp @@ -3,7 +3,8 @@ #include namespace llaisys::core { - + +//构造函数,初始化运行时 Context::Context() { // All device types, put CPU at the end std::vector device_typs; @@ -31,6 +32,7 @@ Context::Context() { } } +//销毁上下文及其包含的运行时 Context::~Context() { // Destroy current runtime first. delete _current_runtime; @@ -49,6 +51,7 @@ Context::~Context() { _runtime_map.clear(); } +//设置当前设备 void Context::setDevice(llaisysDeviceType_t device_type, int device_id) { // If doest not match the current runtime. if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) { @@ -65,6 +68,7 @@ void Context::setDevice(llaisysDeviceType_t device_type, int device_id) { } } +//获取当前运行时 Runtime &Context::runtime() { ASSERT(_current_runtime != nullptr, "No runtime is activated, please call setDevice() first."); return *_current_runtime; diff --git a/src/core/context/context.hpp b/src/core/context/context.hpp index a3ebcdecf..bd9707263 100644 --- a/src/core/context/context.hpp +++ b/src/core/context/context.hpp @@ -27,6 +27,7 @@ class Context { Context(Context &&) = delete; Context &operator=(Context &&) = delete; + //设置当前设备 void setDevice(llaisysDeviceType_t device_type, int device_id); Runtime &runtime(); diff --git a/src/device/comm_api.cpp b/src/device/comm_api.cpp new file mode 100644 index 000000000..d70b4b0bf --- /dev/null +++ b/src/device/comm_api.cpp @@ -0,0 +1,89 @@ +#include "comm_api.hpp" + +namespace llaisys::device { + +int commInit(llaisysComm_t *, int, int, const void *) { + EXCEPTION_UNSUPPORTED_DEVICE; + return -1; +} + +void commDestroy(llaisysComm_t) { + EXCEPTION_UNSUPPORTED_DEVICE; +} + +int commGetRank(llaisysComm_t) { + EXCEPTION_UNSUPPORTED_DEVICE; + return -1; +} + +int commGetSize(llaisysComm_t) { + EXCEPTION_UNSUPPORTED_DEVICE; + return -1; +} + +void commAllreduce(const void *, void *, size_t, llaisysDataType_t, llaisysReduceOp_t, llaisysComm_t, llaisysStream_t) { + EXCEPTION_UNSUPPORTED_DEVICE; +} + +void commBroadcast(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) { + EXCEPTION_UNSUPPORTED_DEVICE; +} + +void commSend(const void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) { + EXCEPTION_UNSUPPORTED_DEVICE; +} + +void commRecv(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) { + EXCEPTION_UNSUPPORTED_DEVICE; +} + +static const LlaisysCommAPI NOOP_COMM_API = { + &commInit, + &commDestroy, + &commGetRank, + &commGetSize, + &commAllreduce, + &commBroadcast, + &commSend, + &commRecv}; + +const LlaisysCommAPI *getUnsupportedCommAPI() { + return &NOOP_COMM_API; +} + +const LlaisysCommAPI *getCommAPI(llaisysCommBackend_t backend) { + switch (backend) { + case LLAISYS_COMM_NCCL: +#ifdef ENABLE_NVIDIA_API + return llaisys::device::nccl::getCommAPI(); +#else + return getUnsupportedCommAPI(); +#endif + case LLAISYS_COMM_IXCCL: +#ifdef ENABLE_ILUVATAR_API + return llaisys::device::ixccl::getCommAPI(); +#else + return getUnsupportedCommAPI(); +#endif + case LLAISYS_COMM_MPI: + return getUnsupportedCommAPI(); + default: + return getUnsupportedCommAPI(); + } +} + +int commGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size) { + switch (backend) { + case LLAISYS_COMM_NCCL: +#ifdef ENABLE_NVIDIA_API + return llaisys::device::nccl::commGenerateUniqueId(id_out, id_size); +#else + EXCEPTION_UNSUPPORTED_DEVICE; + return -1; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + return -1; + } +} +} // namespace llaisys::device diff --git a/src/device/comm_api.hpp b/src/device/comm_api.hpp new file mode 100644 index 000000000..c0290e7d4 --- /dev/null +++ b/src/device/comm_api.hpp @@ -0,0 +1,25 @@ +#pragma once +#include "llaisys/comm.h" + +#include "../utils.hpp" + +namespace llaisys::device { +const LlaisysCommAPI *getCommAPI(llaisysCommBackend_t backend); +int commGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size); + +const LlaisysCommAPI *getUnsupportedCommAPI(); + +#ifdef ENABLE_NVIDIA_API +namespace nccl { +const LlaisysCommAPI *getCommAPI(); +int commGenerateUniqueId(void *id_out, size_t *id_size); +} +#endif + +#ifdef ENABLE_ILUVATAR_API +namespace ixccl { +const LlaisysCommAPI *getCommAPI(); +} +#endif + +} // namespace llaisys::device diff --git a/src/device/iluvatar/devlink_stub.cu b/src/device/iluvatar/devlink_stub.cu new file mode 100644 index 000000000..b64d3641b --- /dev/null +++ b/src/device/iluvatar/devlink_stub.cu @@ -0,0 +1,3 @@ +#include + +__global__ void llaisys_devlink_stub() {} diff --git a/src/device/iluvatar/iluvatar_resource.cu b/src/device/iluvatar/iluvatar_resource.cu new file mode 100644 index 000000000..67850c218 --- /dev/null +++ b/src/device/iluvatar/iluvatar_resource.cu @@ -0,0 +1,7 @@ +#include "iluvatar_resource.cuh" + +namespace llaisys::device::iluvatar { + +Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_ILUVATAR, device_id) {} + +} // namespace llaisys::device::iluvatar diff --git a/src/device/iluvatar/iluvatar_resource.cuh b/src/device/iluvatar/iluvatar_resource.cuh new file mode 100644 index 000000000..d3e637c39 --- /dev/null +++ b/src/device/iluvatar/iluvatar_resource.cuh @@ -0,0 +1,11 @@ +#pragma once + +#include "../device_resource.hpp" + +namespace llaisys::device::iluvatar { +class Resource : public llaisys::device::DeviceResource { +public: + Resource(int device_id); + ~Resource(); +}; +} // namespace llaisys::device::iluvatar diff --git a/src/device/iluvatar/iluvatar_runtime_api.cu b/src/device/iluvatar/iluvatar_runtime_api.cu new file mode 100644 index 000000000..24445f19d --- /dev/null +++ b/src/device/iluvatar/iluvatar_runtime_api.cu @@ -0,0 +1,119 @@ +#include "../runtime_api.hpp" +#include "iluvatar_utils.hpp" + +#include + +namespace llaisys::device::iluvatar { + +namespace runtime_api { +int getDeviceCount() { + int count = 0; + cuda_check(cudaGetDeviceCount(&count)); + return count; +} + +void setDevice(int device_id) { + cuda_check(cudaSetDevice(device_id)); +} + +void deviceSynchronize() { + cuda_check(cudaDeviceSynchronize()); +} + +llaisysStream_t createStream() { + cudaStream_t stream{}; + cuda_check(cudaStreamCreate(&stream)); + return reinterpret_cast(stream); +} + +void destroyStream(llaisysStream_t stream) { + cuda_check(cudaStreamDestroy(reinterpret_cast(stream))); +} +void streamSynchronize(llaisysStream_t stream) { + cuda_check(cudaStreamSynchronize(reinterpret_cast(stream))); +} + +void *mallocDevice(size_t size) { + void *ptr = nullptr; + cuda_check(cudaMalloc(&ptr, size)); + return ptr; +} + +void freeDevice(void *ptr) { + cuda_check(cudaFree(ptr)); +} + +void *mallocHost(size_t size) { + void *ptr = nullptr; + cuda_check(cudaMallocHost(&ptr, size)); + return ptr; +} + +void freeHost(void *ptr) { + cuda_check(cudaFreeHost(ptr)); +} + +void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { + cudaMemcpyKind cuda_kind = cudaMemcpyDefault; + switch (kind) { + case LLAISYS_MEMCPY_H2H: + cuda_kind = cudaMemcpyHostToHost; + break; + case LLAISYS_MEMCPY_H2D: + cuda_kind = cudaMemcpyHostToDevice; + break; + case LLAISYS_MEMCPY_D2H: + cuda_kind = cudaMemcpyDeviceToHost; + break; + case LLAISYS_MEMCPY_D2D: + cuda_kind = cudaMemcpyDeviceToDevice; + break; + default: + cuda_kind = cudaMemcpyDefault; + break; + } + cuda_check(cudaMemcpy(dst, src, size, cuda_kind)); +} + +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaMemcpyKind cuda_kind = cudaMemcpyDefault; + switch (kind) { + case LLAISYS_MEMCPY_H2H: + cuda_kind = cudaMemcpyHostToHost; + break; + case LLAISYS_MEMCPY_H2D: + cuda_kind = cudaMemcpyHostToDevice; + break; + case LLAISYS_MEMCPY_D2H: + cuda_kind = cudaMemcpyDeviceToHost; + break; + case LLAISYS_MEMCPY_D2D: + cuda_kind = cudaMemcpyDeviceToDevice; + break; + default: + cuda_kind = cudaMemcpyDefault; + break; + } + cuda_check(cudaMemcpyAsync(dst, src, size, cuda_kind, reinterpret_cast(stream))); +} + +static const LlaisysRuntimeAPI RUNTIME_API = { + &getDeviceCount, + &setDevice, + &deviceSynchronize, + &createStream, + &destroyStream, + &streamSynchronize, + &mallocDevice, + &freeDevice, + &mallocHost, + &freeHost, + &memcpySync, + &memcpyAsync}; + +} // namespace runtime_api + +const LlaisysRuntimeAPI *getRuntimeAPI() { + return &runtime_api::RUNTIME_API; +} +} // namespace llaisys::device::iluvatar diff --git a/src/device/iluvatar/iluvatar_utils.hpp b/src/device/iluvatar/iluvatar_utils.hpp new file mode 100644 index 000000000..5254b3c11 --- /dev/null +++ b/src/device/iluvatar/iluvatar_utils.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "../../utils/types.hpp" + +#include +#include +#include + +#include + +namespace llaisys::device::iluvatar { +inline void cuda_check(cudaError_t err) { + if (err == cudaSuccess) { + return; + } + if (err == cudaErrorCudartUnloading || err == cudaErrorContextIsDestroyed) { + return; + } + throw std::runtime_error(cudaGetErrorString(err)); +} + +template +struct ScalarOps; + +template <> +struct ScalarOps { + __device__ static inline float load(const float *ptr) { + return *ptr; + } + __device__ static inline void store(float *ptr, float v) { + *ptr = v; + } +}; + +template <> +struct ScalarOps { + __device__ static inline float load(const llaisys::fp16_t *ptr) { + return __half2float(*reinterpret_cast(ptr)); + } + __device__ static inline void store(llaisys::fp16_t *ptr, float v) { + *reinterpret_cast<__half *>(ptr) = __float2half(v); + } +}; + +template <> +struct ScalarOps { + __device__ static inline float load(const llaisys::bf16_t *ptr) { + return __bfloat162float(*reinterpret_cast(ptr)); + } + __device__ static inline void store(llaisys::bf16_t *ptr, float v) { + *reinterpret_cast<__nv_bfloat16 *>(ptr) = __float2bfloat16(v); + } +}; +} // namespace llaisys::device::iluvatar diff --git a/src/device/nvidia/cuda_utils.hpp b/src/device/nvidia/cuda_utils.hpp new file mode 100644 index 000000000..20522193f --- /dev/null +++ b/src/device/nvidia/cuda_utils.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "../../utils/types.hpp" + +#include +#include +#include + +#include + +namespace llaisys::device::nvidia { +inline void cuda_check(cudaError_t err) { + if (err == cudaSuccess) { + return; + } + // During process shutdown, CUDA may report unloading/destroyed context. + if (err == cudaErrorCudartUnloading || err == cudaErrorContextIsDestroyed) { + return; + } + throw std::runtime_error(cudaGetErrorString(err)); +} + +template +struct ScalarOps; + +template <> +struct ScalarOps { + __device__ static inline float load(const float *ptr) { + return *ptr; + } + __device__ static inline void store(float *ptr, float v) { + *ptr = v; + } +}; + +template <> +struct ScalarOps { + __device__ static inline float load(const llaisys::fp16_t *ptr) { + return __half2float(*reinterpret_cast(ptr)); + } + __device__ static inline void store(llaisys::fp16_t *ptr, float v) { + *reinterpret_cast<__half *>(ptr) = __float2half(v); + } +}; + +template <> +struct ScalarOps { + __device__ static inline float load(const llaisys::bf16_t *ptr) { + return __bfloat162float(*reinterpret_cast(ptr)); + } + __device__ static inline void store(llaisys::bf16_t *ptr, float v) { + *reinterpret_cast<__nv_bfloat16 *>(ptr) = __float2bfloat16(v); + } +}; +} // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/devlink_stub.cu b/src/device/nvidia/devlink_stub.cu new file mode 100644 index 000000000..6dc8ecc01 --- /dev/null +++ b/src/device/nvidia/devlink_stub.cu @@ -0,0 +1,4 @@ +#include + +__global__ void llaisys_devlink_stub() {} + diff --git a/src/device/nvidia/nvidia_comm.cu b/src/device/nvidia/nvidia_comm.cu new file mode 100644 index 000000000..f94691ed5 --- /dev/null +++ b/src/device/nvidia/nvidia_comm.cu @@ -0,0 +1,140 @@ +#include "../comm_api.hpp" +#include "cuda_utils.hpp" + +#include +#include +#include + +namespace llaisys::device::nvidia { + +inline void nccl_check(ncclResult_t result) { + if (result == ncclSuccess) { + return; + } + throw std::runtime_error(ncclGetErrorString(result)); +} + +inline ncclDataType_t to_nccl_dtype(llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: return ncclFloat32; + case LLAISYS_DTYPE_F16: return ncclFloat16; + case LLAISYS_DTYPE_BF16: return ncclBfloat16; + case LLAISYS_DTYPE_I32: return ncclInt32; + case LLAISYS_DTYPE_I8: return ncclInt8; + default: throw std::runtime_error("Unsupported data type"); + } +} + +inline ncclRedOp_t to_nccl_op(llaisysReduceOp_t op) { + switch (op) { + case LLAISYS_REDUCE_SUM: return ncclSum; + case LLAISYS_REDUCE_PROD: return ncclProd; + case LLAISYS_REDUCE_MIN: return ncclMin; + case LLAISYS_REDUCE_MAX: return ncclMax; + default: throw std::runtime_error("Unsupported reduce op"); + } +} + +namespace nccl { + +int commInit(llaisysComm_t *comm, int rank, int size, const void *unique_id) { + ncclComm_t nccl_comm; + ncclUniqueId id; + + if (unique_id) { + memcpy(&id, unique_id, sizeof(id)); + } else if (rank == 0) { + nccl_check(ncclGetUniqueId(&id)); + } + + nccl_check(ncclCommInitRank(&nccl_comm, size, id, rank)); + *comm = reinterpret_cast(nccl_comm); + return 0; +} + +int commGenerateUniqueId(void *id_out, size_t *id_size) { + ncclUniqueId id; + nccl_check(ncclGetUniqueId(&id)); + memcpy(id_out, &id, sizeof(id)); + *id_size = sizeof(id); + return 0; +} + +void commDestroy(llaisysComm_t comm) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + nccl_check(ncclCommDestroy(nccl_comm)); +} + +int commGetRank(llaisysComm_t comm) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + int rank; + nccl_check(ncclCommUserRank(nccl_comm, &rank)); + return rank; +} + +int commGetSize(llaisysComm_t comm) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + int size; + nccl_check(ncclCommCount(nccl_comm, &size)); + return size; +} + +void commAllreduce(const void *sendbuf, void *recvbuf, size_t count, + llaisysDataType_t dtype, llaisysReduceOp_t op, + llaisysComm_t comm, llaisysStream_t stream) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + cudaStream_t cuda_stream = reinterpret_cast(stream); + nccl_check(ncclAllReduce(sendbuf, recvbuf, count, to_nccl_dtype(dtype), + to_nccl_op(op), nccl_comm, cuda_stream)); +} + +void commBroadcast(void *buf, size_t count, llaisysDataType_t dtype, int root, + llaisysComm_t comm, llaisysStream_t stream) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + cudaStream_t cuda_stream = reinterpret_cast(stream); + nccl_check(ncclBroadcast(buf, buf, count, to_nccl_dtype(dtype), root, + nccl_comm, cuda_stream)); +} + +void commSend(const void *buf, size_t count, llaisysDataType_t dtype, int peer, + llaisysComm_t comm, llaisysStream_t stream) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + cudaStream_t cuda_stream = reinterpret_cast(stream); + nccl_check(ncclSend(buf, count, to_nccl_dtype(dtype), peer, nccl_comm, + cuda_stream)); +} + +void commRecv(void *buf, size_t count, llaisysDataType_t dtype, int peer, + llaisysComm_t comm, llaisysStream_t stream) { + ncclComm_t nccl_comm = reinterpret_cast(comm); + cudaStream_t cuda_stream = reinterpret_cast(stream); + nccl_check(ncclRecv(buf, count, to_nccl_dtype(dtype), peer, nccl_comm, + cuda_stream)); +} + +static const LlaisysCommAPI NCCL_COMM_API = { + &commInit, + &commDestroy, + &commGetRank, + &commGetSize, + &commAllreduce, + &commBroadcast, + &commSend, + &commRecv +}; + +const LlaisysCommAPI *getCommAPI() { + return &NCCL_COMM_API; +} + +} // namespace nccl +} // namespace llaisys::device::nvidia + +namespace llaisys::device::nccl { +const LlaisysCommAPI *getCommAPI() { + return llaisys::device::nvidia::nccl::getCommAPI(); +} +int commGenerateUniqueId(void *id_out, size_t *id_size) { + return llaisys::device::nvidia::nccl::commGenerateUniqueId(id_out, id_size); +} +} diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab928261..2c8bf713e 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,100 @@ #include "../runtime_api.hpp" +#include "cuda_utils.hpp" -#include -#include +#include namespace llaisys::device::nvidia { namespace runtime_api { int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count = 0; + cuda_check(cudaGetDeviceCount(&count)); + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + cuda_check(cudaSetDevice(device_id)); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + cuda_check(cudaDeviceSynchronize()); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t stream{}; + cuda_check(cudaStreamCreate(&stream)); + return reinterpret_cast(stream); } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cuda_check(cudaStreamDestroy(reinterpret_cast(stream))); } void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cuda_check(cudaStreamSynchronize(reinterpret_cast(stream))); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + cuda_check(cudaMalloc(&ptr, size)); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + cuda_check(cudaFree(ptr)); } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + cuda_check(cudaMallocHost(&ptr, size)); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + cuda_check(cudaFreeHost(ptr)); } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + cudaMemcpyKind cuda_kind = cudaMemcpyDefault; + switch (kind) { + case LLAISYS_MEMCPY_H2H: + cuda_kind = cudaMemcpyHostToHost; + break; + case LLAISYS_MEMCPY_H2D: + cuda_kind = cudaMemcpyHostToDevice; + break; + case LLAISYS_MEMCPY_D2H: + cuda_kind = cudaMemcpyDeviceToHost; + break; + case LLAISYS_MEMCPY_D2D: + cuda_kind = cudaMemcpyDeviceToDevice; + break; + default: + cuda_kind = cudaMemcpyDefault; + break; + } + cuda_check(cudaMemcpy(dst, src, size, cuda_kind)); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaMemcpyKind cuda_kind = cudaMemcpyDefault; + switch (kind) { + case LLAISYS_MEMCPY_H2H: + cuda_kind = cudaMemcpyHostToHost; + break; + case LLAISYS_MEMCPY_H2D: + cuda_kind = cudaMemcpyHostToDevice; + break; + case LLAISYS_MEMCPY_D2H: + cuda_kind = cudaMemcpyDeviceToHost; + break; + case LLAISYS_MEMCPY_D2D: + cuda_kind = cudaMemcpyDeviceToDevice; + break; + default: + cuda_kind = cudaMemcpyDefault; + break; + } + cuda_check(cudaMemcpyAsync(dst, src, size, cuda_kind, reinterpret_cast(stream))); } static const LlaisysRuntimeAPI RUNTIME_API = { diff --git a/src/device/runtime_api.cpp b/src/device/runtime_api.cpp index 2de3eca02..1a8dfc6be 100644 --- a/src/device/runtime_api.cpp +++ b/src/device/runtime_api.cpp @@ -80,6 +80,12 @@ const LlaisysRuntimeAPI *getRuntimeAPI(llaisysDeviceType_t device_type) { return llaisys::device::nvidia::getRuntimeAPI(); #else return getUnsupportedRuntimeAPI(); +#endif + case LLAISYS_DEVICE_ILUVATAR: +#ifdef ENABLE_ILUVATAR_API + return llaisys::device::iluvatar::getRuntimeAPI(); +#else + return getUnsupportedRuntimeAPI(); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/device/runtime_api.hpp b/src/device/runtime_api.hpp index e6b9f80d6..12ebdc40f 100644 --- a/src/device/runtime_api.hpp +++ b/src/device/runtime_api.hpp @@ -17,4 +17,10 @@ namespace nvidia { const LlaisysRuntimeAPI *getRuntimeAPI(); } #endif + +#ifdef ENABLE_ILUVATAR_API +namespace iluvatar { +const LlaisysRuntimeAPI *getRuntimeAPI(); +} +#endif } // namespace llaisys::device diff --git a/src/llaisys/comm.cc b/src/llaisys/comm.cc new file mode 100644 index 000000000..d3b0c9c8c --- /dev/null +++ b/src/llaisys/comm.cc @@ -0,0 +1,10 @@ +#include "llaisys/comm.h" +#include "../device/comm_api.hpp" + +__C const LlaisysCommAPI *llaisysGetCommAPI(llaisysCommBackend_t backend) { + return llaisys::device::getCommAPI(backend); +} + +__C int llaisysCommGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size) { + return llaisys::device::commGenerateUniqueId(backend, id_out, id_size); +} diff --git a/src/llaisys/models/qwen2.cpp b/src/llaisys/models/qwen2.cpp new file mode 100644 index 000000000..5151c2537 --- /dev/null +++ b/src/llaisys/models/qwen2.cpp @@ -0,0 +1,479 @@ +// Qwen2 C API implementation (skeleton) +#include "llaisys/models/qwen2.h" +#include "../../models/qwen2/qwen2.hpp" +#include "qwen2_kv_internal.hpp" + +#include +#include +#include +#include +#include + +struct LlaisysQwen2Model { + LlaisysQwen2Meta meta{}; + LlaisysQwen2Weights weights{}; + llaisysDeviceType_t device = LLAISYS_DEVICE_CPU; + std::vector device_ids; + std::unique_ptr impl; + LlaisysQwen2KVContext *kv_ctx = nullptr; // experimental, non-decoder path +}; + +static void init_layer_arrays(LlaisysQwen2Weights &w, size_t nlayer) { + w.attn_norm_w = new llaisysTensor_t[nlayer](); + w.attn_q_w = new llaisysTensor_t[nlayer](); + w.attn_q_b = new llaisysTensor_t[nlayer](); + w.attn_k_w = new llaisysTensor_t[nlayer](); + w.attn_k_b = new llaisysTensor_t[nlayer](); + w.attn_v_w = new llaisysTensor_t[nlayer](); + w.attn_v_b = new llaisysTensor_t[nlayer](); + w.attn_o_w = new llaisysTensor_t[nlayer](); + w.mlp_norm_w = new llaisysTensor_t[nlayer](); + w.mlp_gate_w = new llaisysTensor_t[nlayer](); + w.mlp_up_w = new llaisysTensor_t[nlayer](); + w.mlp_down_w = new llaisysTensor_t[nlayer](); +} + +static void destroy_layer_arrays(LlaisysQwen2Weights &w, size_t nlayer) { + auto destroy_array = [nlayer](llaisysTensor_t *arr) { + if (!arr) return; + for (size_t i = 0; i < nlayer; ++i) { + if (arr[i]) { + tensorDestroy(arr[i]); + arr[i] = nullptr; + } + } + delete[] arr; + }; + + destroy_array(w.attn_norm_w); + destroy_array(w.attn_q_w); + destroy_array(w.attn_q_b); + destroy_array(w.attn_k_w); + destroy_array(w.attn_k_b); + destroy_array(w.attn_v_w); + destroy_array(w.attn_v_b); + destroy_array(w.attn_o_w); + destroy_array(w.mlp_norm_w); + destroy_array(w.mlp_gate_w); + destroy_array(w.mlp_up_w); + destroy_array(w.mlp_down_w); + + w.attn_norm_w = nullptr; + w.attn_q_w = nullptr; + w.attn_q_b = nullptr; + w.attn_k_w = nullptr; + w.attn_k_b = nullptr; + w.attn_v_w = nullptr; + w.attn_v_b = nullptr; + w.attn_o_w = nullptr; + w.mlp_norm_w = nullptr; + w.mlp_gate_w = nullptr; + w.mlp_up_w = nullptr; + w.mlp_down_w = nullptr; +} + +__C { + __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate( + const LlaisysQwen2Meta *meta, + llaisysDeviceType_t device, + int *device_ids, + int ndevice) { + if (!meta || ndevice <= 0) return nullptr; + + auto *model = new LlaisysQwen2Model(); + model->meta = *meta; + model->device = device; + model->device_ids.assign(device_ids, device_ids + ndevice); + + init_layer_arrays(model->weights, model->meta.nlayer); + model->impl = std::make_unique( + model->meta, + model->weights, + model->device, + model->device_ids); + + return model; + } + + //销毁千问2模型实例 + __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + if (!model) return; + + if (model->weights.in_embed) { + tensorDestroy(model->weights.in_embed); + model->weights.in_embed = nullptr; + } + if (model->weights.out_embed) { + tensorDestroy(model->weights.out_embed); + model->weights.out_embed = nullptr; + } + if (model->weights.out_norm_w) { + tensorDestroy(model->weights.out_norm_w); + model->weights.out_norm_w = nullptr; + } + + destroy_layer_arrays(model->weights, model->meta.nlayer); + if (model->kv_ctx) { + llaisysQwen2KVContextRelease(model->kv_ctx); + model->kv_ctx = nullptr; + } + + model->impl.reset(); + delete model; + } + + + //获取千问2模型权重 + __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) { + if (!model) return nullptr; + return &model->weights; + } + + //执行千问2模型推理 + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) { + if (!model || !model->impl) return -1; + try { + return model->impl->infer(token_ids, ntoken); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 infer failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 infer failed: unknown exception" << std::endl; + return -1; + } + } + + __export int64_t llaisysQwen2ModelPrefill(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) { + if (!model || !model->impl) return -1; + try { + return model->impl->prefill(token_ids, ntoken); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 prefill failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 prefill failed: unknown exception" << std::endl; + return -1; + } + } + + __export int64_t llaisysQwen2ModelStep(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) { + if (!model || !model->impl) return -1; + try { + return model->impl->step(token_ids, ntoken); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 step failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 step failed: unknown exception" << std::endl; + return -1; + } + } + + __export int32_t llaisysQwen2ModelPrefillPacked(struct LlaisysQwen2Model *model, + int64_t *token_ids, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens) { + if (!model || !model->impl || !token_ids || !token_offsets || !out_next_tokens || nseq == 0) return -1; + try { + const size_t ntoken = static_cast(token_offsets[nseq]); + if (!model->impl->prefillPacked(token_ids, ntoken, token_offsets, nseq, out_next_tokens)) return -2; + return 0; + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 prefill packed failed: " << e.what() << std::endl; + return -3; + } catch (...) { + std::cerr << "[ERROR] Qwen2 prefill packed failed: unknown exception" << std::endl; + return -4; + } + } + + __export int32_t llaisysQwen2ModelStepPacked(struct LlaisysQwen2Model *model, + int64_t *token_ids, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens) { + if (!model || !model->impl || !token_ids || !token_offsets || !out_next_tokens || nseq == 0) return -1; + try { + const size_t ntoken = static_cast(token_offsets[nseq]); + if (!model->impl->stepPacked(token_ids, ntoken, token_offsets, nseq, out_next_tokens)) return -2; + return 0; + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 step packed failed: " << e.what() << std::endl; + return -3; + } catch (...) { + std::cerr << "[ERROR] Qwen2 step packed failed: unknown exception" << std::endl; + return -4; + } + } + + __export int64_t llaisysQwen2ModelPrefillSampling(struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken, + const LlaisysSamplingParams *params) { + if (!model || !model->impl) return -1; + try { + return model->impl->prefillSampling(token_ids, ntoken, params); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 prefill sampling failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 prefill sampling failed: unknown exception" << std::endl; + return -1; + } + } + + __export int64_t llaisysQwen2ModelStepSampling(struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken, + const LlaisysSamplingParams *params) { + if (!model || !model->impl) return -1; + try { + return model->impl->stepSampling(token_ids, ntoken, params); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 step sampling failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 step sampling failed: unknown exception" << std::endl; + return -1; + } + } + + __export int64_t llaisysQwen2ModelInferSampling(struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken, + const LlaisysSamplingParams *params) { + if (!model || !model->impl) return -1; + try { + return model->impl->prefillSampling(token_ids, ntoken, params); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2 infer sampling failed: " << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "[ERROR] Qwen2 infer sampling failed: unknown exception" << std::endl; + return -1; + } + } + + __export int64_t llaisysQwen2ModelInferSamplingEx(struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken, + int32_t top_k, + float top_p, + float temperature, + uint32_t seed) { + if (!model || !model->impl) return -1; + LlaisysSamplingParams params{}; + params.top_k = top_k; + params.top_p = top_p; + params.temperature = temperature; + params.seed = seed; + return llaisysQwen2ModelInferSampling(model, token_ids, ntoken, ¶ms); + } + + __export void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model *model) { + if (!model || !model->impl) return; + model->impl->resetKVCache(); + } + + __export void llaisysQwen2ModelSetKVCacheEnabled(struct LlaisysQwen2Model *model, uint8_t enabled) { + if (!model || !model->impl) return; + model->impl->setKVCacheEnabled(enabled != 0); + } + + __export int32_t llaisysQwen2ModelSetTensorParallel(struct LlaisysQwen2Model *model, + llaisysComm_t comm, + llaisysStream_t stream, + int tp_size) { + if (!model || !model->impl) return -1; + model->impl->setTensorParallel(comm, stream, tp_size); + return 0; + } + + __export struct LlaisysQwen2KVBlock *llaisysQwen2KVBlockCreate( + const struct LlaisysQwen2KVBlockMeta *meta, + llaisysDeviceType_t device, + int device_id) { + if (!meta || meta->nlayer == 0 || meta->max_tokens == 0) return nullptr; + auto *block = new LlaisysQwen2KVBlock(); + block->meta = *meta; + block->device = device; + block->device_id = device_id; + block->k_layers.assign(meta->nlayer, nullptr); + block->v_layers.assign(meta->nlayer, nullptr); + size_t kv_shape[3] = {meta->max_tokens, meta->nkvh, meta->dh}; + for (size_t layer = 0; layer < meta->nlayer; ++layer) { + block->k_layers[layer] = tensorCreate(kv_shape, 3, meta->dtype, device, device_id); + block->v_layers[layer] = tensorCreate(kv_shape, 3, meta->dtype, device, device_id); + if (!block->k_layers[layer] || !block->v_layers[layer]) { + for (auto *t : block->k_layers) { + if (t) tensorDestroy(t); + } + for (auto *t : block->v_layers) { + if (t) tensorDestroy(t); + } + delete block; + return nullptr; + } + } + return block; + } + + __export void llaisysQwen2KVBlockRetain(struct LlaisysQwen2KVBlock *block) { + if (!block) return; + block->ref_count.fetch_add(1, std::memory_order_relaxed); + } + + __export void llaisysQwen2KVBlockRelease(struct LlaisysQwen2KVBlock *block) { + if (!block) return; + if (block->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + for (auto *t : block->k_layers) { + if (t) tensorDestroy(t); + } + for (auto *t : block->v_layers) { + if (t) tensorDestroy(t); + } + block->k_layers.clear(); + block->v_layers.clear(); + delete block; + } + } + + __export int32_t llaisysQwen2KVBlockSetTokenCount(struct LlaisysQwen2KVBlock *block, size_t used_tokens) { + if (!block) return -1; + if (used_tokens > block->meta.max_tokens) return -2; + block->used_tokens = used_tokens; + return 0; + } + + __export size_t llaisysQwen2KVBlockTokenCount(const struct LlaisysQwen2KVBlock *block) { + if (!block) return 0; + return block->used_tokens; + } + + __export llaisysTensor_t llaisysQwen2KVBlockKeyTensor(struct LlaisysQwen2KVBlock *block, size_t layer) { + if (!block || layer >= block->k_layers.size()) return nullptr; + return block->k_layers[layer]; + } + + __export llaisysTensor_t llaisysQwen2KVBlockValueTensor(struct LlaisysQwen2KVBlock *block, size_t layer) { + if (!block || layer >= block->v_layers.size()) return nullptr; + return block->v_layers[layer]; + } + + __export struct LlaisysQwen2KVContext *llaisysQwen2KVContextCreate( + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id, + size_t nlayer, + size_t nh, + size_t nkvh, + size_t dh) { + if (nlayer == 0 || dh == 0) return nullptr; + auto *ctx = new LlaisysQwen2KVContext(); + ctx->dtype = dtype; + ctx->device = device; + ctx->device_id = device_id; + ctx->nlayer = nlayer; + ctx->nh = nh; + ctx->nkvh = nkvh; + ctx->dh = dh; + return ctx; + } + + __export void llaisysQwen2KVContextRetain(struct LlaisysQwen2KVContext *ctx) { + if (!ctx) return; + ctx->ref_count.fetch_add(1, std::memory_order_relaxed); + } + + __export void llaisysQwen2KVContextRelease(struct LlaisysQwen2KVContext *ctx) { + if (!ctx) return; + if (ctx->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + for (auto *blk : ctx->chain) { + llaisysQwen2KVBlockRelease(blk); + } + ctx->chain.clear(); + delete ctx; + } + } + + __export int32_t llaisysQwen2KVContextAttachBlock( + struct LlaisysQwen2KVContext *ctx, + struct LlaisysQwen2KVBlock *block) { + if (!ctx || !block) return -1; + if (ctx->device != block->device || ctx->device_id != block->device_id) return -2; + if (ctx->dtype != block->meta.dtype) return -3; + if (ctx->nlayer != block->meta.nlayer || ctx->dh != block->meta.dh) return -4; + if (ctx->nkvh != block->meta.nkvh || ctx->nh != block->meta.nh) return -5; + llaisysQwen2KVBlockRetain(block); + ctx->chain.push_back(block); + return 0; + } + + __export void llaisysQwen2KVContextDetachAll(struct LlaisysQwen2KVContext *ctx) { + if (!ctx) return; + for (auto *blk : ctx->chain) { + llaisysQwen2KVBlockRelease(blk); + } + ctx->chain.clear(); + } + + __export size_t llaisysQwen2KVContextBlockCount(const struct LlaisysQwen2KVContext *ctx) { + if (!ctx) return 0; + return ctx->chain.size(); + } + + __export size_t llaisysQwen2KVContextTokenCount(const struct LlaisysQwen2KVContext *ctx) { + if (!ctx) return 0; + size_t total = 0; + for (auto *blk : ctx->chain) { + if (!blk) continue; + total += std::min(blk->used_tokens, blk->meta.max_tokens); + } + return total; + } + + __export int32_t llaisysQwen2ModelSetKVContext( + struct LlaisysQwen2Model *model, + struct LlaisysQwen2KVContext *ctx) { + if (!model) return -1; + if (ctx) { + if (model->device != ctx->device) return -2; + const int model_device_id = model->device_ids.empty() ? 0 : model->device_ids[0]; + if (model_device_id != ctx->device_id) return -3; + llaisysQwen2KVContextRetain(ctx); + } + if (model->kv_ctx) { + llaisysQwen2KVContextRelease(model->kv_ctx); + } + model->kv_ctx = ctx; + if (model->impl) { + const size_t past_len_tokens = llaisysQwen2KVContextTokenCount(ctx); + model->impl->setKVContext(ctx, past_len_tokens); + } + return 0; + } + + __export struct LlaisysQwen2KVContext *llaisysQwen2ModelGetKVContext( + struct LlaisysQwen2Model *model) { + if (!model) return nullptr; + auto *ctx = model->kv_ctx; + if (model->impl) { + ctx = reinterpret_cast(model->impl->getKVContext()); + } + if (!ctx) return nullptr; + llaisysQwen2KVContextRetain(ctx); + return ctx; + } + + __export int32_t llaisysQwen2ModelExportKVContext( + struct LlaisysQwen2Model *model, + struct LlaisysQwen2KVContext *ctx, + size_t block_tokens) { + if (!model || !model->impl || !ctx) return -1; + if (model->device != ctx->device) return -2; + const int model_device_id = model->device_ids.empty() ? 0 : model->device_ids[0]; + if (model_device_id != ctx->device_id) return -3; + return static_cast(model->impl->exportKVContext(ctx, block_tokens)); + } +} diff --git a/src/llaisys/models/qwen2_kv_internal.hpp b/src/llaisys/models/qwen2_kv_internal.hpp new file mode 100644 index 000000000..7c83bd071 --- /dev/null +++ b/src/llaisys/models/qwen2_kv_internal.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "llaisys/models/qwen2.h" + +#include +#include + +struct LlaisysQwen2KVBlock { + LlaisysQwen2KVBlockMeta meta{}; + llaisysDeviceType_t device = LLAISYS_DEVICE_CPU; + int device_id = 0; + size_t used_tokens = 0; + std::vector k_layers; + std::vector v_layers; + std::atomic ref_count{1}; +}; + +struct LlaisysQwen2KVContext { + llaisysDataType_t dtype = LLAISYS_DTYPE_F32; + llaisysDeviceType_t device = LLAISYS_DEVICE_CPU; + int device_id = 0; + size_t nlayer = 0; + size_t nh = 0; + size_t nkvh = 0; + size_t dh = 0; + std::vector chain; + std::atomic ref_count{1}; +}; diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc index c99fbc32f..625887cac 100644 --- a/src/llaisys/ops.cc +++ b/src/llaisys/ops.cc @@ -23,7 +23,10 @@ __C { llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor); } void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) { - llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor); + llaisys::ops::linear(out->tensor, + in->tensor, + weight->tensor, + bias ? bias->tensor : nullptr); } void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) { llaisys::ops::rearrange(out->tensor, in->tensor); @@ -37,6 +40,24 @@ __C { void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale) { llaisys::ops::self_attention(attn_val->tensor, q->tensor, k->tensor, v->tensor, scale); } + void llaisysSelfAttentionSegmented(llaisysTensor_t attn_val, + llaisysTensor_t q, + llaisysTensor_t k, + llaisysTensor_t v, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg) { + llaisys::ops::self_attention_segmented( + attn_val->tensor, + q->tensor, + k->tensor, + v->tensor, + scale, + q_offsets, + kv_offsets, + nseg); + } void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up) { llaisys::ops::swiglu(out->tensor, gate->tensor, up->tensor); } diff --git a/src/llaisys/tokenizer.cc b/src/llaisys/tokenizer.cc new file mode 100644 index 000000000..c15abd11e --- /dev/null +++ b/src/llaisys/tokenizer.cc @@ -0,0 +1,84 @@ +#include "llaisys/tokenizer.h" + +#include "../tokenizer/sentencepiece/sentencepiece.hpp" + +#include +#include +#include +#include + +struct LlaisysTokenizer { + std::unique_ptr impl; +}; + +__C { +__export struct LlaisysTokenizer *llaisysTokenizerCreateSentencePiece(const char *model_path) { + if (!model_path || model_path[0] == '\0') return nullptr; + auto tokenizer = std::make_unique(); + tokenizer->impl = std::make_unique(model_path); + if (!tokenizer->impl || !tokenizer->impl->isLoaded()) { + return nullptr; + } + return tokenizer.release(); +} + +__export void llaisysTokenizerDestroy(struct LlaisysTokenizer *tokenizer) { + delete tokenizer; +} + +__export int llaisysTokenizerEncode(struct LlaisysTokenizer *tokenizer, + const char *text, + int64_t *out_ids, + size_t max_ids) { + if (!tokenizer || !tokenizer->impl || !text) return -1; + std::vector ids; + if (!tokenizer->impl->encode(text, ids)) return -1; + if (!out_ids || max_ids == 0) { + return static_cast(ids.size()); + } + const size_t n = ids.size() < max_ids ? ids.size() : max_ids; + for (size_t i = 0; i < n; ++i) out_ids[i] = ids[i]; + return static_cast(n); +} + +__export int llaisysTokenizerDecode(struct LlaisysTokenizer *tokenizer, + const int64_t *ids, + size_t len, + char *out_text, + size_t max_len) { + if (!tokenizer || !tokenizer->impl) return -1; + std::string text; + if (!tokenizer->impl->decode(ids, len, text)) return -1; + if (!out_text || max_len == 0) { + return static_cast(text.size() + 1); + } + const size_t n = text.size() < (max_len - 1) ? text.size() : (max_len - 1); + std::memcpy(out_text, text.data(), n); + out_text[n] = '\0'; + return static_cast(n); +} + +__export int64_t llaisysTokenizerTokenToId(struct LlaisysTokenizer *tokenizer, const char *token) { + if (!tokenizer || !tokenizer->impl || !token) return -1; + int64_t id = -1; + if (!tokenizer->impl->pieceToId(token, id)) return -1; + return id; +} + +__export int llaisysTokenizerIdToToken(struct LlaisysTokenizer *tokenizer, + int64_t id, + char *out_token, + size_t max_len) { + if (!tokenizer || !tokenizer->impl) return -1; + std::string piece; + if (!tokenizer->impl->idToPiece(id, piece)) return -1; + if (!out_token || max_len == 0) { + return static_cast(piece.size() + 1); + } + const size_t n = piece.size() < (max_len - 1) ? piece.size() : (max_len - 1); + std::memcpy(out_token, piece.data(), n); + out_token[n] = '\0'; + return static_cast(n); +} + +} diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp new file mode 100644 index 000000000..0082aba8f --- /dev/null +++ b/src/models/qwen2/qwen2.cpp @@ -0,0 +1,500 @@ +#include "qwen2.hpp" + +#include "llaisys/ops.h" + +#include "../../utils.hpp" +#include "../../core/context/context.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace llaisys::models { +Qwen2::Qwen2(const LlaisysQwen2Meta &meta, + const LlaisysQwen2Weights &weights, + llaisysDeviceType_t device, + const std::vector &device_ids) + : _meta(meta), + _weights(&weights), + _device(device), + _device_ids(device_ids), + _decoder(transformer::DecoderConfig{ + meta.dtype, + meta.nlayer, + meta.hs, + meta.nh, + meta.nkvh, + meta.dh, + meta.di, + meta.maxseq, + meta.voc, + meta.epsilon, + meta.theta}, + &weights, + device, + device_ids) {} + +Qwen2::~Qwen2() { + clearPackedState(); +} + +void Qwen2::resetKVCache() { + clearPackedState(); + _decoder.resetKVCache(); +} + +void Qwen2::setKVCacheEnabled(bool enabled) { + _decoder.setKVCacheEnabled(enabled); +} + +void Qwen2::setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size) { + _decoder.setTensorParallel(comm, stream, tp_size); +} + +void Qwen2::setKVContext(void *ctx, size_t past_len_tokens) { + clearPackedState(); + _kv_ctx = ctx; + if (ctx) { + _decoder.bindExternalKVContext(ctx, past_len_tokens); + } else { + _decoder.clearExternalKVContext(); + } +} + +void *Qwen2::getKVContext() const { + return _kv_ctx; +} + +int Qwen2::exportKVContext(void *ctx, size_t block_tokens) { + return _decoder.exportKVContext(ctx, block_tokens); +} + +void Qwen2::clearPackedState() { + for (auto *ctx : _packed_kv_contexts) { + if (ctx) { + ::llaisysQwen2KVContextRelease(ctx); + } + } + _packed_kv_contexts.clear(); + _packed_prompts.clear(); +} + +//执行千问2模型推理 +static int64_t argmax_from_logits(llaisysTensor_t logits, + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id) { + int64_t next_token = -1; + size_t one_shape[1] = {1}; + llaisysTensor_t max_idx = tensorCreate(one_shape, 1, LLAISYS_DTYPE_I64, device, device_id); + llaisysTensor_t max_val = tensorCreate(one_shape, 1, dtype, device, device_id); + if (!max_idx || !max_val) { + if (max_idx) tensorDestroy(max_idx); + if (max_val) tensorDestroy(max_val); + return -1; + } + ::llaisysArgmax(max_idx, max_val, logits); + if (tensorGetDeviceType(max_idx) == LLAISYS_DEVICE_CPU) { + next_token = *reinterpret_cast(tensorGetData(max_idx)); + } else { + int64_t host_val = -1; + llaisys::core::context().setDevice(device, device_id); + llaisys::core::context().runtime().api()->memcpy_sync( + &host_val, + tensorGetData(max_idx), + sizeof(int64_t), + LLAISYS_MEMCPY_D2H); + next_token = host_val; + } + tensorDestroy(max_idx); + tensorDestroy(max_val); + return next_token; +} + +static std::vector logits_to_host(llaisysTensor_t logits, + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id, + size_t vocab) { + std::vector host(vocab, 0.0f); + const size_t bytes = vocab * utils::dsize(dtype); + if (device == LLAISYS_DEVICE_CPU) { + const std::byte *src = reinterpret_cast(tensorGetData(logits)); + if (dtype == LLAISYS_DTYPE_F32) { + const float *vals = reinterpret_cast(src); + for (size_t i = 0; i < vocab; ++i) { + host[i] = vals[i]; + } + } else if (dtype == LLAISYS_DTYPE_F16) { + const fp16_t *vals = reinterpret_cast(src); + for (size_t i = 0; i < vocab; ++i) { + host[i] = utils::cast(vals[i]); + } + } else if (dtype == LLAISYS_DTYPE_BF16) { + const bf16_t *vals = reinterpret_cast(src); + for (size_t i = 0; i < vocab; ++i) { + host[i] = utils::cast(vals[i]); + } + } else { + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } + return host; + } + + std::vector tmp(bytes); + llaisys::core::context().setDevice(device, device_id); + llaisys::core::context().runtime().api()->memcpy_sync( + tmp.data(), tensorGetData(logits), bytes, LLAISYS_MEMCPY_D2H); + + if (dtype == LLAISYS_DTYPE_F32) { + const float *vals = reinterpret_cast(tmp.data()); + for (size_t i = 0; i < vocab; ++i) { + host[i] = vals[i]; + } + } else if (dtype == LLAISYS_DTYPE_F16) { + const fp16_t *vals = reinterpret_cast(tmp.data()); + for (size_t i = 0; i < vocab; ++i) { + host[i] = utils::cast(vals[i]); + } + } else if (dtype == LLAISYS_DTYPE_BF16) { + const bf16_t *vals = reinterpret_cast(tmp.data()); + for (size_t i = 0; i < vocab; ++i) { + host[i] = utils::cast(vals[i]); + } + } else { + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } + return host; +} + +static int64_t sample_from_logits(const std::vector &logits, + const LlaisysSamplingParams *params) { + const size_t vocab = logits.size(); + if (vocab == 0) { + return -1; + } + + int top_k = params ? params->top_k : 1; + float top_p = params ? params->top_p : 0.0f; + float temperature = params ? params->temperature : 0.0f; + uint32_t seed = params ? params->seed : 0u; + + if (temperature <= 0.0f && top_k <= 1 && top_p <= 0.0f) { + return static_cast(std::distance(logits.begin(), + std::max_element(logits.begin(), logits.end()))); + } + + std::vector indices(vocab); + std::iota(indices.begin(), indices.end(), 0); + + if (top_k > 0 && static_cast(top_k) < vocab) { + std::partial_sort(indices.begin(), indices.begin() + top_k, indices.end(), + [&](int a, int b) { return logits[a] > logits[b]; }); + indices.resize(top_k); + } + + const float temp = temperature > 0.0f ? temperature : 1.0f; + std::vector filtered_logits; + filtered_logits.reserve(indices.size()); + for (int idx : indices) { + filtered_logits.push_back(logits[idx] / std::max(temp, 1e-6f)); + } + + float max_logit = *std::max_element(filtered_logits.begin(), filtered_logits.end()); + std::vector probs(filtered_logits.size()); + float sum = 0.0f; + for (size_t i = 0; i < filtered_logits.size(); ++i) { + probs[i] = std::exp(filtered_logits[i] - max_logit); + sum += probs[i]; + } + if (sum <= 0.0f) { + return indices.front(); + } + for (float &p : probs) { + p /= sum; + } + + if (top_p > 0.0f && top_p < 1.0f) { + std::vector order(probs.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](size_t a, size_t b) { return probs[a] > probs[b]; }); + float cumulative = 0.0f; + size_t keep = 0; + for (size_t idx : order) { + cumulative += probs[idx]; + keep++; + if (cumulative >= top_p) { + break; + } + } + std::vector new_indices; + std::vector new_probs; + new_indices.reserve(keep); + new_probs.reserve(keep); + for (size_t i = 0; i < keep; ++i) { + size_t idx = order[i]; + new_indices.push_back(indices[idx]); + new_probs.push_back(probs[idx]); + } + indices.swap(new_indices); + probs.swap(new_probs); + float new_sum = std::accumulate(probs.begin(), probs.end(), 0.0f); + if (new_sum > 0.0f) { + for (float &p : probs) { + p /= new_sum; + } + } + } + + std::mt19937 rng(seed == 0 ? std::random_device{}() : seed); + std::uniform_real_distribution dist(0.0f, 1.0f); + float r = dist(rng); + float cumulative = 0.0f; + for (size_t i = 0; i < probs.size(); ++i) { + cumulative += probs[i]; + if (r <= cumulative) { + return indices[i]; + } + } + return indices.back(); +} + +static int64_t next_token_from_logits(llaisysTensor_t logits, + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id, + size_t vocab, + const LlaisysSamplingParams *params) { + if (!params) { + return argmax_from_logits(logits, dtype, device, device_id); + } + auto host_logits = logits_to_host(logits, dtype, device, device_id, vocab); + return sample_from_logits(host_logits, params); +} + +int64_t Qwen2::infer(const int64_t *token_ids, size_t ntoken) { + return prefill(token_ids, ntoken); +} + +int64_t Qwen2::prefill(const int64_t *token_ids, size_t ntoken) { + if (!token_ids || ntoken == 0) return -1; + clearPackedState(); + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + size_t logits_shape[2] = {1, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return -1; + if (!_decoder.prefill(token_ids, ntoken, logits)) { + tensorDestroy(logits); + return -1; + } + + int64_t next_token = argmax_from_logits(logits, _meta.dtype, _device, device_id); + tensorDestroy(logits); + + return next_token; +} + +int64_t Qwen2::step(const int64_t *token_ids, size_t ntoken) { + if (!token_ids || ntoken == 0) return -1; + clearPackedState(); + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + size_t logits_shape[2] = {1, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return -1; + if (!_decoder.decodeStep(token_ids, ntoken, logits)) { + tensorDestroy(logits); + return -1; + } + + int64_t next_token = argmax_from_logits(logits, _meta.dtype, _device, device_id); + tensorDestroy(logits); + return next_token; +} + +bool Qwen2::prefillPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens) { + if (!token_ids || !token_offsets || nseq == 0 || ntoken == 0 || !out_next_tokens) return false; + clearPackedState(); + if (token_offsets[0] != 0 || static_cast(token_offsets[nseq]) != ntoken) return false; + for (size_t i = 0; i < nseq; ++i) { + if (token_offsets[i] >= token_offsets[i + 1]) return false; + } + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + size_t logits_shape[2] = {nseq, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return false; + if (!_decoder.prefillPacked(token_ids, ntoken, token_offsets, nseq, logits)) { + tensorDestroy(logits); + return false; + } + for (size_t i = 0; i < nseq; ++i) { + llaisysTensor_t row = tensorSlice(logits, 0, i, i + 1); + if (!row) { + tensorDestroy(logits); + return false; + } + out_next_tokens[i] = argmax_from_logits(row, _meta.dtype, _device, device_id); + tensorDestroy(row); + } + tensorDestroy(logits); + + _packed_prompts.resize(nseq); + for (size_t i = 0; i < nseq; ++i) { + const size_t begin = static_cast(token_offsets[i]); + const size_t end = static_cast(token_offsets[i + 1]); + _packed_prompts[i].assign(token_ids + begin, token_ids + end); + } + + // Build per-sequence KV snapshots once after packed prefill. + constexpr size_t kPackedBlockTokens = 64; + _packed_kv_contexts.assign(nseq, nullptr); + size_t single_logits_shape[2] = {1, _meta.voc}; + llaisysTensor_t single_logits = tensorCreate(single_logits_shape, 2, _meta.dtype, _device, device_id); + if (!single_logits) { + clearPackedState(); + return false; + } + for (size_t i = 0; i < nseq; ++i) { + _decoder.resetKVCache(); + _decoder.clearExternalKVContext(); + const auto &prompt = _packed_prompts[i]; + if (prompt.empty()) { + tensorDestroy(single_logits); + clearPackedState(); + return false; + } + if (!_decoder.prefill(prompt.data(), prompt.size(), single_logits)) { + tensorDestroy(single_logits); + clearPackedState(); + return false; + } + auto *ctx = ::llaisysQwen2KVContextCreate( + _meta.dtype, + _device, + device_id, + _meta.nlayer, + _meta.nh, + _meta.nkvh, + _meta.dh); + if (!ctx) { + tensorDestroy(single_logits); + clearPackedState(); + return false; + } + if (_decoder.exportKVContext(ctx, kPackedBlockTokens) != 0) { + ::llaisysQwen2KVContextRelease(ctx); + tensorDestroy(single_logits); + clearPackedState(); + return false; + } + _packed_kv_contexts[i] = ctx; + } + _decoder.clearExternalKVContext(); + _decoder.resetKVCache(); + tensorDestroy(single_logits); + return true; +} + +bool Qwen2::stepPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens) { + if (!token_ids || !token_offsets || nseq == 0 || !out_next_tokens) return false; + if (token_offsets[0] != 0 || static_cast(token_offsets[nseq]) != ntoken) return false; + for (size_t i = 0; i < nseq; ++i) { + if (token_offsets[i] >= token_offsets[i + 1]) return false; + } + if (_packed_prompts.size() != nseq || _packed_kv_contexts.size() != nseq) return false; + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + constexpr size_t kPackedBlockTokens = 64; + std::vector step_tokens(nseq, 0); + for (size_t i = 0; i < nseq; ++i) { + const size_t begin = static_cast(token_offsets[i]); + const size_t end = static_cast(token_offsets[i + 1]); + const size_t step_len = end - begin; + if (step_len != 1) return false; + step_tokens[i] = token_ids[begin]; + } + std::vector contexts(nseq, nullptr); + for (size_t i = 0; i < nseq; ++i) { + contexts[i] = _packed_kv_contexts[i]; + if (!contexts[i]) return false; + } + + size_t logits_shape[2] = {nseq, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return false; + if (!_decoder.decodePacked(step_tokens.data(), nseq, contexts, logits, kPackedBlockTokens)) { + tensorDestroy(logits); + clearPackedState(); + return false; + } + for (size_t i = 0; i < nseq; ++i) { + llaisysTensor_t row = tensorSlice(logits, 0, i, i + 1); + if (!row) { + tensorDestroy(logits); + clearPackedState(); + return false; + } + out_next_tokens[i] = argmax_from_logits(row, _meta.dtype, _device, device_id); + tensorDestroy(row); + if (out_next_tokens[i] < 0) { + tensorDestroy(logits); + clearPackedState(); + return false; + } + _packed_prompts[i].push_back(step_tokens[i]); + _packed_prompts[i].push_back(out_next_tokens[i]); + } + tensorDestroy(logits); + return true; +} + +int64_t Qwen2::prefillSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params) { + if (!token_ids || ntoken == 0) return -1; + clearPackedState(); + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + size_t logits_shape[2] = {1, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return -1; + if (!_decoder.prefill(token_ids, ntoken, logits)) { + tensorDestroy(logits); + return -1; + } + + int64_t next_token = next_token_from_logits(logits, _meta.dtype, _device, device_id, _meta.voc, params); + tensorDestroy(logits); + return next_token; +} + +int64_t Qwen2::stepSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params) { + if (!token_ids || ntoken == 0) return -1; + clearPackedState(); + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + size_t logits_shape[2] = {1, _meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id); + if (!logits) return -1; + if (!_decoder.decodeStep(token_ids, ntoken, logits)) { + tensorDestroy(logits); + return -1; + } + + int64_t next_token = next_token_from_logits(logits, _meta.dtype, _device, device_id, _meta.voc, params); + tensorDestroy(logits); + return next_token; +} +} // namespace llaisys::models diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp new file mode 100644 index 000000000..5f437356c --- /dev/null +++ b/src/models/qwen2/qwen2.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "llaisys/models/qwen2.h" +#include "llaisys/tensor.h" +#include "../transformer/decoder/decoder.hpp" + +#include +#include + +namespace llaisys::models { +class Qwen2 { +public: + Qwen2(const LlaisysQwen2Meta &meta, + const LlaisysQwen2Weights &weights, + llaisysDeviceType_t device, + const std::vector &device_ids); + ~Qwen2(); + + // Compatibility entrypoint; prefer prefill/step for streaming. + int64_t infer(const int64_t *token_ids, size_t ntoken); + int64_t prefill(const int64_t *token_ids, size_t ntoken); + int64_t step(const int64_t *token_ids, size_t ntoken); + bool prefillPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens); + bool stepPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + int64_t *out_next_tokens); + int64_t prefillSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params); + int64_t stepSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params); + void resetKVCache(); + void setKVCacheEnabled(bool enabled); + void setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size); + void setKVContext(void *ctx, size_t past_len_tokens = 0); + void *getKVContext() const; + int exportKVContext(void *ctx, size_t block_tokens); + +private: + void clearPackedState(); + + LlaisysQwen2Meta _meta{}; + const LlaisysQwen2Weights *_weights{nullptr}; + llaisysDeviceType_t _device{LLAISYS_DEVICE_CPU}; + std::vector _device_ids; + transformer::Decoder _decoder; + void *_kv_ctx{nullptr}; + std::vector _packed_kv_contexts; + std::vector> _packed_prompts; +}; +} // namespace llaisys::models diff --git a/src/models/transformer/decoder/decoder.cpp b/src/models/transformer/decoder/decoder.cpp new file mode 100644 index 000000000..ce97a9dd1 --- /dev/null +++ b/src/models/transformer/decoder/decoder.cpp @@ -0,0 +1,1235 @@ +#include "decoder.hpp" +#include "../../../llaisys/models/qwen2_kv_internal.hpp" +#include "../../../device/comm_api.hpp" + +#include "llaisys/ops.h" + +#include +#include +#include + +namespace llaisys::models::transformer { +namespace { +bool trace_enabled() { + static bool enabled = false; + static bool inited = false; + if (!inited) { +#if defined(_WIN32) + char *value = nullptr; + size_t len = 0; + if (_dupenv_s(&value, &len, "LLAISYS_QWEN2_TRACE") == 0 && value) { + if (value[0] != '\0' && value[0] != '0') enabled = true; + free(value); + } +#else + const char *value = std::getenv("LLAISYS_QWEN2_TRACE"); + if (value && value[0] != '\0' && value[0] != '0') enabled = true; +#endif + inited = true; + } + return enabled; +} + +void trace(const char *stage) { + if (trace_enabled()) { + std::cerr << "[TRACE] Decoder forward: " << stage << std::endl; + } +} + +bool require_tensor(llaisysTensor_t t, const char *stage) { + if (t) return true; + std::cerr << "[ERROR] Decoder: tensorCreate failed at " << stage << std::endl; + return false; +} + +bool ensure_data(llaisysTensor_t t, const char *stage) { + if (!t) { + std::cerr << "[ERROR] Decoder: null tensor at " << stage << std::endl; + return false; + } + if (!tensorGetData(t)) { + std::cerr << "[ERROR] Decoder: null data at " << stage << std::endl; + return false; + } + return true; +} + +void destroy_if_not_null(llaisysTensor_t t) { + if (t) tensorDestroy(t); +} +} // namespace + +Decoder::Decoder(const DecoderConfig &config, + const LlaisysQwen2Weights *weights, + llaisysDeviceType_t device, + const std::vector &device_ids) + : _config(config), + _weights(weights), + _device(device), + _device_ids(device_ids) {} + +void Decoder::setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size) { + _comm = comm; + _comm_stream = stream; + _tp_size = tp_size > 0 ? tp_size : 1; +} + +Decoder::~Decoder() { + releaseCache(); +} + +void Decoder::ensureCache() { + if (!_kv_cache_enabled || _cache_inited || _config.maxseq == 0 || _config.nlayer == 0) return; + _k_cache.assign(_config.nlayer, nullptr); + _v_cache.assign(_config.nlayer, nullptr); + + size_t kv_shape[3] = {_config.maxseq, _config.nkvh, _config.dh}; + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + for (size_t i = 0; i < _config.nlayer; ++i) { + _k_cache[i] = tensorCreate(kv_shape, 3, _config.dtype, _device, device_id); + _v_cache[i] = tensorCreate(kv_shape, 3, _config.dtype, _device, device_id); + } + _past_len = 0; + _cache_inited = true; +} + +void Decoder::releaseCache() { + for (auto &t : _k_cache) { + if (t) tensorDestroy(t); + t = nullptr; + } + for (auto &t : _v_cache) { + if (t) tensorDestroy(t); + t = nullptr; + } + _k_cache.clear(); + _v_cache.clear(); + _past_len = 0; + _cache_inited = false; +} + +void Decoder::resetKVCache() { + if (!_cache_inited) return; + _past_len = 0; +} + +void Decoder::setKVCacheEnabled(bool enabled) { + if (_kv_cache_enabled == enabled) return; + _kv_cache_enabled = enabled; + if (!enabled) { + releaseCache(); + } +} + +void Decoder::bindExternalKVContext(void *ctx, size_t past_len_tokens) { + _external_kv_ctx = ctx; + _external_past_len = past_len_tokens; + _external_cache_ready = false; + if (ctx) { + releaseCache(); + } +} + +void Decoder::clearExternalKVContext() { + _external_kv_ctx = nullptr; + _external_past_len = 0; + _external_cache_ready = false; +} + +bool Decoder::hasExternalKVContext() const { + return _external_kv_ctx != nullptr; +} + +int Decoder::exportKVContext(void *ctx_ptr, size_t block_tokens) { + if (!ctx_ptr) return -1; + if (!_kv_cache_enabled) return -2; + ensureCache(); + if (!_cache_inited) return -3; + + auto *ctx = reinterpret_cast(ctx_ptr); + if (!ctx) return -4; + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return -5; + if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return -6; + + llaisysQwen2KVContextDetachAll(ctx); + if (_past_len == 0) return 0; + + const size_t chunk_size = block_tokens > 0 ? block_tokens : _past_len; + size_t offset = 0; + while (offset < _past_len) { + const size_t used = std::min(chunk_size, _past_len - offset); + LlaisysQwen2KVBlockMeta meta{}; + meta.dtype = _config.dtype; + meta.nlayer = _config.nlayer; + meta.nh = _config.nh; + meta.nkvh = _config.nkvh; + meta.dh = _config.dh; + meta.max_tokens = used; + auto *block = llaisysQwen2KVBlockCreate(&meta, _device, device_id); + if (!block) { + llaisysQwen2KVContextDetachAll(ctx); + return -7; + } + if (llaisysQwen2KVBlockSetTokenCount(block, used) != 0) { + llaisysQwen2KVBlockRelease(block); + llaisysQwen2KVContextDetachAll(ctx); + return -8; + } + + bool copy_ok = true; + for (size_t layer = 0; layer < _config.nlayer && copy_ok; ++layer) { + llaisysTensor_t src_k = tensorSlice(_k_cache[layer], 0, offset, offset + used); + llaisysTensor_t src_v = tensorSlice(_v_cache[layer], 0, offset, offset + used); + llaisysTensor_t dst_k_full = llaisysQwen2KVBlockKeyTensor(block, layer); + llaisysTensor_t dst_v_full = llaisysQwen2KVBlockValueTensor(block, layer); + llaisysTensor_t dst_k = dst_k_full ? tensorSlice(dst_k_full, 0, 0, used) : nullptr; + llaisysTensor_t dst_v = dst_v_full ? tensorSlice(dst_v_full, 0, 0, used) : nullptr; + if (!src_k || !src_v || !dst_k || !dst_v) { + copy_ok = false; + } else { + ::llaisysRearrange(dst_k, src_k); + ::llaisysRearrange(dst_v, src_v); + } + destroy_if_not_null(src_k); + destroy_if_not_null(src_v); + destroy_if_not_null(dst_k); + destroy_if_not_null(dst_v); + } + + if (!copy_ok || llaisysQwen2KVContextAttachBlock(ctx, block) != 0) { + llaisysQwen2KVBlockRelease(block); + llaisysQwen2KVContextDetachAll(ctx); + return -9; + } + llaisysQwen2KVBlockRelease(block); + offset += used; + } + return 0; +} + +bool Decoder::recoverExternalCache() { + if (!_external_kv_ctx || _external_cache_ready) return true; + if (!_kv_cache_enabled) return false; + ensureCache(); + if (!_cache_inited) return false; + + auto *ctx = reinterpret_cast(_external_kv_ctx); + if (!ctx) return false; + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return false; + if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return false; + + size_t total_tokens = 0; + for (auto *blk : ctx->chain) { + if (!blk) return false; + if (blk->meta.dtype != _config.dtype || blk->device != _device || blk->device_id != device_id) return false; + if (blk->meta.nlayer != _config.nlayer || blk->meta.nkvh != _config.nkvh || blk->meta.dh != _config.dh) return false; + if (blk->used_tokens > blk->meta.max_tokens) return false; + total_tokens += blk->used_tokens; + if (total_tokens > _config.maxseq) return false; + } + + _past_len = 0; + for (size_t layer = 0; layer < _config.nlayer; ++layer) { + size_t offset = 0; + for (auto *blk : ctx->chain) { + const size_t used = blk->used_tokens; + if (used == 0) continue; + if (layer >= blk->k_layers.size() || layer >= blk->v_layers.size()) return false; + auto *k_block = blk->k_layers[layer]; + auto *v_block = blk->v_layers[layer]; + if (!k_block || !v_block) return false; + + llaisysTensor_t src_k = tensorSlice(k_block, 0, 0, used); + llaisysTensor_t src_v = tensorSlice(v_block, 0, 0, used); + llaisysTensor_t dst_k = tensorSlice(_k_cache[layer], 0, offset, offset + used); + llaisysTensor_t dst_v = tensorSlice(_v_cache[layer], 0, offset, offset + used); + if (!src_k || !src_v || !dst_k || !dst_v) { + destroy_if_not_null(src_k); + destroy_if_not_null(src_v); + destroy_if_not_null(dst_k); + destroy_if_not_null(dst_v); + return false; + } + ::llaisysRearrange(dst_k, src_k); + ::llaisysRearrange(dst_v, src_v); + tensorDestroy(src_k); + tensorDestroy(src_v); + tensorDestroy(dst_k); + tensorDestroy(dst_v); + offset += used; + } + } + + _past_len = total_tokens; + _external_past_len = total_tokens; + _external_cache_ready = true; + return true; +} + +bool Decoder::runHidden(const int64_t *token_ids, + size_t ntoken, + bool append_only, + const int64_t *segment_offsets, + size_t nseg, + size_t &past_len, + size_t &cur_len, + llaisysTensor_t &idx, + llaisysTensor_t &pos_ids, + llaisysTensor_t &hidden) { + idx = nullptr; + pos_ids = nullptr; + hidden = nullptr; + if (!token_ids || ntoken == 0) return false; + if (!_weights || !_weights->in_embed) return false; + const bool segmented = (segment_offsets != nullptr && nseg > 0); + if (segmented && append_only) return false; + + if (!segmented) { + ensureCache(); + if (_external_kv_ctx && !_external_cache_ready) { + if (!recoverExternalCache()) { + clearExternalKVContext(); + _past_len = 0; + } + } + } + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + // Segmented packed prefill treats each call as an independent packed forward. + // Reusing decoder KV cache here breaks offset domains (past_len vs packed offsets). + const bool can_cache = (!segmented) && _cache_inited && _config.maxseq > 0; + if (can_cache && ntoken > _config.maxseq) return false; + past_len = can_cache ? _past_len : 0; + if (append_only && !can_cache) { + return false; + } + if (!append_only) { + if (!can_cache || ntoken <= past_len) { + past_len = 0; + if (can_cache) _past_len = 0; + } + cur_len = ntoken - past_len; + } else { + cur_len = ntoken; + } + if (cur_len == 0) return false; + if (trace_enabled()) { + std::cerr << "[TRACE] Decoder cache: enabled=" << (_kv_cache_enabled ? 1 : 0) + << " inited=" << (_cache_inited ? 1 : 0) + << " can_cache=" << (can_cache ? 1 : 0) + << " past_len=" << past_len + << " cur_len=" << cur_len + << " ntoken=" << ntoken << std::endl; + } + const int64_t *new_tokens = append_only ? token_ids : (token_ids + past_len); + if (can_cache) { + if (_k_cache.size() != _config.nlayer || _v_cache.size() != _config.nlayer) return false; + if (past_len + cur_len > _config.maxseq) return false; + } + + trace("begin"); + // 1) token ids -> embedding + size_t idx_shape[1] = {cur_len}; + idx = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id); + if (!require_tensor(idx, "idx")) return false; + tensorLoad(idx, new_tokens); + + size_t hidden_shape[2] = {cur_len, _config.hs}; + hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(hidden, "hidden")) { + tensorDestroy(idx); + idx = nullptr; + return false; + } + + trace("embedding"); + ::llaisysEmbedding(hidden, idx, _weights->in_embed); + + // 2) position ids for RoPE + std::vector pos_buf(cur_len); + for (size_t i = 0; i < cur_len; ++i) pos_buf[i] = static_cast(past_len + i); + trace("pos_ids"); + pos_ids = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id); + if (!require_tensor(pos_ids, "pos_ids")) { + tensorDestroy(hidden); + tensorDestroy(idx); + hidden = nullptr; + idx = nullptr; + return false; + } + tensorLoad(pos_ids, pos_buf.data()); + + // 3) Attention + MLP blocks + const float scale = 1.0f / std::sqrt(static_cast(_config.dh)); + for (size_t layer = 0; layer < _config.nlayer; ++layer) { + trace("attn.weights.check"); + if (!_weights->attn_norm_w || !_weights->attn_q_w || !_weights->attn_k_w || !_weights->attn_v_w || + !_weights->attn_o_w || !_weights->mlp_norm_w || !_weights->mlp_gate_w || !_weights->mlp_up_w || + !_weights->mlp_down_w) { + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + if (!_weights->attn_norm_w[layer] || !_weights->attn_q_w[layer] || !_weights->attn_k_w[layer] || + !_weights->attn_v_w[layer] || !_weights->attn_o_w[layer] || !_weights->mlp_norm_w[layer] || + !_weights->mlp_gate_w[layer] || !_weights->mlp_up_w[layer] || !_weights->mlp_down_w[layer]) { + std::cerr << "[ERROR] Decoder: missing weights at layer " << layer << std::endl; + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + + trace("attn.norm"); + llaisysTensor_t norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(norm, "attn.norm")) { + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysRmsNorm(norm, hidden, _weights->attn_norm_w[layer], _config.epsilon); + + trace("attn.qkv"); + size_t q2d_shape[2] = {cur_len, _config.nh * _config.dh}; + size_t kv2d_shape[2] = {cur_len, _config.nkvh * _config.dh}; + llaisysTensor_t q2d = tensorCreate(q2d_shape, 2, _config.dtype, _device, device_id); + llaisysTensor_t k2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id); + llaisysTensor_t v2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(q2d, "attn.q2d") || !require_tensor(k2d, "attn.k2d") || + !require_tensor(v2d, "attn.v2d")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + if (q2d) tensorDestroy(q2d); + if (k2d) tensorDestroy(k2d); + if (v2d) tensorDestroy(v2d); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + + llaisysTensor_t q_bias = (_weights->attn_q_b && _weights->attn_q_b[layer]) ? _weights->attn_q_b[layer] : nullptr; + llaisysTensor_t k_bias = (_weights->attn_k_b && _weights->attn_k_b[layer]) ? _weights->attn_k_b[layer] : nullptr; + llaisysTensor_t v_bias = (_weights->attn_v_b && _weights->attn_v_b[layer]) ? _weights->attn_v_b[layer] : nullptr; + + ::llaisysLinear(q2d, norm, _weights->attn_q_w[layer], q_bias); + ::llaisysLinear(k2d, norm, _weights->attn_k_w[layer], k_bias); + ::llaisysLinear(v2d, norm, _weights->attn_v_w[layer], v_bias); + + trace("attn.view"); + size_t q3d_shape[3] = {cur_len, _config.nh, _config.dh}; + size_t k3d_shape[3] = {cur_len, _config.nkvh, _config.dh}; + llaisysTensor_t q3d = tensorView(q2d, q3d_shape, 3); + llaisysTensor_t k3d = tensorView(k2d, k3d_shape, 3); + llaisysTensor_t v3d = tensorView(v2d, k3d_shape, 3); + if (!require_tensor(q3d, "attn.q3d") || !require_tensor(k3d, "attn.k3d") || + !require_tensor(v3d, "attn.v3d")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + if (q3d) tensorDestroy(q3d); + if (k3d) tensorDestroy(k3d); + if (v3d) tensorDestroy(v3d); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + + trace("attn.rope"); + llaisysTensor_t q_rope = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id); + llaisysTensor_t k_rope = tensorCreate(k3d_shape, 3, _config.dtype, _device, device_id); + if (!require_tensor(q_rope, "attn.q_rope") || !require_tensor(k_rope, "attn.k_rope")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + if (q_rope) tensorDestroy(q_rope); + if (k_rope) tensorDestroy(k_rope); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysROPE(q_rope, q3d, pos_ids, _config.theta); + ::llaisysROPE(k_rope, k3d, pos_ids, _config.theta); + + if (can_cache) { + trace("attn.cache.write"); + llaisysTensor_t k_slot = tensorSlice(_k_cache[layer], 0, past_len, past_len + cur_len); + llaisysTensor_t v_slot = tensorSlice(_v_cache[layer], 0, past_len, past_len + cur_len); + ::llaisysRearrange(k_slot, k_rope); + ::llaisysRearrange(v_slot, v3d); + tensorDestroy(k_slot); + tensorDestroy(v_slot); + } + + llaisysTensor_t k_attn = k_rope; + llaisysTensor_t v_attn = v3d; + llaisysTensor_t k_cache_view = nullptr; + llaisysTensor_t v_cache_view = nullptr; + if (can_cache) { + trace("attn.cache.read"); + size_t total_len = past_len + cur_len; + k_cache_view = tensorSlice(_k_cache[layer], 0, 0, total_len); + v_cache_view = tensorSlice(_v_cache[layer], 0, 0, total_len); + k_attn = k_cache_view; + v_attn = v_cache_view; + } + + trace("attn.softmax"); + llaisysTensor_t attn_out3d = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id); + if (!require_tensor(attn_out3d, "attn.out3d")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + tensorDestroy(q_rope); + tensorDestroy(k_rope); + if (k_cache_view) tensorDestroy(k_cache_view); + if (v_cache_view) tensorDestroy(v_cache_view); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + if (segmented) { + ::llaisysSelfAttentionSegmented( + attn_out3d, + q_rope, + k_attn, + v_attn, + scale, + segment_offsets, + segment_offsets, + nseg); + } else { + ::llaisysSelfAttention(attn_out3d, q_rope, k_attn, v_attn, scale); + } + if (k_cache_view) tensorDestroy(k_cache_view); + if (v_cache_view) tensorDestroy(v_cache_view); + + trace("attn.proj"); + llaisysTensor_t attn_out2d = tensorView(attn_out3d, q2d_shape, 2); + llaisysTensor_t proj_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(attn_out2d, "attn.out2d") || !require_tensor(proj_out, "attn.proj_out")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + tensorDestroy(q_rope); + tensorDestroy(k_rope); + tensorDestroy(attn_out3d); + if (attn_out2d) tensorDestroy(attn_out2d); + if (proj_out) tensorDestroy(proj_out); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + if (!ensure_data(attn_out2d, "attn.proj.in") || !ensure_data(proj_out, "attn.proj.out") || + !ensure_data(_weights->attn_o_w[layer], "attn.proj.w")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + tensorDestroy(q_rope); + tensorDestroy(k_rope); + tensorDestroy(attn_out3d); + tensorDestroy(attn_out2d); + tensorDestroy(proj_out); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysLinear(proj_out, attn_out2d, _weights->attn_o_w[layer], nullptr); + + // Tensor parallel: allreduce after attn_o projection + if (_tp_size > 1 && _comm) { + size_t ndim = tensorGetNdim(proj_out); + size_t shape[4]; + tensorGetShape(proj_out, shape); + size_t count = 1; + for (size_t d = 0; d < ndim; ++d) count *= shape[d]; + auto backend = (_device == LLAISYS_DEVICE_ILUVATAR) ? LLAISYS_COMM_IXCCL : LLAISYS_COMM_NCCL; + auto *api = llaisys::device::getCommAPI(backend); + api->allreduce(tensorGetData(proj_out), tensorGetData(proj_out), + count, _config.dtype, LLAISYS_REDUCE_SUM, _comm, _comm_stream); + } + + trace("attn.residual"); + llaisysTensor_t new_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(new_hidden, "attn.residual")) { + tensorDestroy(norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + tensorDestroy(q_rope); + tensorDestroy(k_rope); + tensorDestroy(attn_out3d); + tensorDestroy(attn_out2d); + tensorDestroy(proj_out); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysAdd(new_hidden, hidden, proj_out); + + tensorDestroy(hidden); + hidden = new_hidden; + + tensorDestroy(norm); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q3d); + tensorDestroy(k3d); + tensorDestroy(v3d); + tensorDestroy(q_rope); + tensorDestroy(k_rope); + tensorDestroy(attn_out3d); + tensorDestroy(attn_out2d); + tensorDestroy(proj_out); + + // 4) MLP + trace("mlp.norm"); + llaisysTensor_t mlp_norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(mlp_norm, "mlp.norm")) { + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysRmsNorm(mlp_norm, hidden, _weights->mlp_norm_w[layer], _config.epsilon); + + trace("mlp.gate_up"); + size_t mlp_shape[2] = {cur_len, _config.di}; + llaisysTensor_t gate = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + llaisysTensor_t up = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(gate, "mlp.gate") || !require_tensor(up, "mlp.up")) { + tensorDestroy(mlp_norm); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + if (gate) tensorDestroy(gate); + if (up) tensorDestroy(up); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysLinear(gate, mlp_norm, _weights->mlp_gate_w[layer], nullptr); + ::llaisysLinear(up, mlp_norm, _weights->mlp_up_w[layer], nullptr); + + trace("mlp.swiglu"); + llaisysTensor_t swiglu = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(swiglu, "mlp.swiglu")) { + tensorDestroy(mlp_norm); + tensorDestroy(gate); + tensorDestroy(up); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysSwiGLU(swiglu, gate, up); + + trace("mlp.down"); + llaisysTensor_t mlp_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(mlp_out, "mlp.down")) { + tensorDestroy(mlp_norm); + tensorDestroy(gate); + tensorDestroy(up); + tensorDestroy(swiglu); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysLinear(mlp_out, swiglu, _weights->mlp_down_w[layer], nullptr); + + // Tensor parallel: allreduce after mlp_down projection + if (_tp_size > 1 && _comm) { + size_t ndim = tensorGetNdim(mlp_out); + size_t shape[4]; + tensorGetShape(mlp_out, shape); + size_t count = 1; + for (size_t d = 0; d < ndim; ++d) count *= shape[d]; + auto backend = (_device == LLAISYS_DEVICE_ILUVATAR) ? LLAISYS_COMM_IXCCL : LLAISYS_COMM_NCCL; + auto *api = llaisys::device::getCommAPI(backend); + api->allreduce(tensorGetData(mlp_out), tensorGetData(mlp_out), + count, _config.dtype, LLAISYS_REDUCE_SUM, _comm, _comm_stream); + } + + trace("mlp.residual"); + llaisysTensor_t mlp_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!require_tensor(mlp_hidden, "mlp.residual")) { + tensorDestroy(mlp_norm); + tensorDestroy(gate); + tensorDestroy(up); + tensorDestroy(swiglu); + tensorDestroy(mlp_out); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + tensorDestroy(idx); + pos_ids = nullptr; + hidden = nullptr; + idx = nullptr; + return false; + } + ::llaisysAdd(mlp_hidden, hidden, mlp_out); + + tensorDestroy(hidden); + hidden = mlp_hidden; + + tensorDestroy(mlp_norm); + tensorDestroy(gate); + tensorDestroy(up); + tensorDestroy(swiglu); + tensorDestroy(mlp_out); + } + + if (can_cache) { + _past_len = past_len + cur_len; + } + + return true; +} + +bool Decoder::prefill(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits) { + if (!out_last_logits) return false; + if (!ensure_data(out_last_logits, "head.logits.out")) return false; + + size_t past_len = 0; + size_t cur_len = 0; + llaisysTensor_t idx = nullptr; + llaisysTensor_t pos_ids = nullptr; + llaisysTensor_t hidden = nullptr; + if (!runHidden(token_ids, ntoken, false, nullptr, 0, past_len, cur_len, idx, pos_ids, hidden)) return false; + + if (!_weights || !_weights->out_norm_w || !_weights->out_embed) { + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + + trace("head.slice"); + llaisysTensor_t last_hidden = tensorSlice(hidden, 0, cur_len - 1, cur_len); + if (!require_tensor(last_hidden, "head.last_hidden")) { + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + + size_t last_shape[2] = {1, _config.hs}; + trace("head.norm"); + llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, _device_ids.empty() ? 0 : _device_ids[0]); + if (!require_tensor(final_norm, "head.norm")) { + tensorDestroy(last_hidden); + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon); + + trace("head.logits"); + ::llaisysLinear(out_last_logits, final_norm, _weights->out_embed, nullptr); + + tensorDestroy(last_hidden); + tensorDestroy(final_norm); + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return true; +} + +bool Decoder::prefillPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + llaisysTensor_t out_last_logits) { + if (!out_last_logits || !token_ids || !token_offsets || nseq == 0 || ntoken == 0) return false; + if (!ensure_data(out_last_logits, "head.packed.logits.out")) return false; + if (tensorGetNdim(out_last_logits) != 2) return false; + size_t out_shape[2] = {0, 0}; + tensorGetShape(out_last_logits, out_shape); + if (out_shape[0] != nseq || out_shape[1] != _config.voc) return false; + if (token_offsets[0] != 0 || static_cast(token_offsets[nseq]) != ntoken) return false; + for (size_t i = 0; i < nseq; ++i) { + if (token_offsets[i] > token_offsets[i + 1]) return false; + if (token_offsets[i] == token_offsets[i + 1]) return false; + } + + size_t past_len = 0; + size_t cur_len = 0; + llaisysTensor_t idx = nullptr; + llaisysTensor_t pos_ids = nullptr; + llaisysTensor_t hidden = nullptr; + if (!runHidden(token_ids, ntoken, false, token_offsets, nseq, past_len, cur_len, idx, pos_ids, hidden)) return false; + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + if (!_weights || !_weights->out_norm_w || !_weights->out_embed) { + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + + bool ok = true; + for (size_t i = 0; i < nseq && ok; ++i) { + const size_t seg_end = static_cast(token_offsets[i + 1]); + const size_t last_pos = seg_end - 1; + llaisysTensor_t last_hidden = tensorSlice(hidden, 0, last_pos, last_pos + 1); + llaisysTensor_t row_logits = tensorSlice(out_last_logits, 0, i, i + 1); + size_t last_shape[2] = {1, _config.hs}; + llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, device_id); + if (!last_hidden || !row_logits || !final_norm) { + ok = false; + } else { + ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon); + ::llaisysLinear(row_logits, final_norm, _weights->out_embed, nullptr); + } + destroy_if_not_null(last_hidden); + destroy_if_not_null(row_logits); + destroy_if_not_null(final_norm); + } + + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return ok; +} + +bool Decoder::decodePacked(const int64_t *token_ids, + size_t nseq, + const std::vector &contexts, + llaisysTensor_t out_last_logits, + size_t block_tokens_hint) { + if (!token_ids || nseq == 0 || contexts.size() != nseq || !out_last_logits) return false; + if (!ensure_data(out_last_logits, "head.decode_packed.logits.out")) return false; + if (tensorGetNdim(out_last_logits) != 2) return false; + size_t out_shape[2] = {0, 0}; + tensorGetShape(out_last_logits, out_shape); + if (out_shape[0] != nseq || out_shape[1] != _config.voc) return false; + + const int device_id = _device_ids.empty() ? 0 : _device_ids[0]; + std::vector past_lens(nseq, 0); + std::vector q_offsets(nseq + 1, 0); + std::vector kv_offsets(nseq + 1, 0); + std::vector append_blocks(nseq, nullptr); + std::vector append_pos(nseq, 0); + size_t kv_total = 0; + + for (size_t i = 0; i < nseq; ++i) { + auto *ctx = contexts[i]; + if (!ctx) return false; + if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return false; + if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return false; + const size_t past = llaisysQwen2KVContextTokenCount(ctx); + if (past + 1 > _config.maxseq) return false; + past_lens[i] = past; + q_offsets[i + 1] = static_cast(i + 1); + kv_total += past + 1; + kv_offsets[i + 1] = static_cast(kv_total); + + LlaisysQwen2KVBlock *target = nullptr; + size_t pos = 0; + if (!ctx->chain.empty()) { + auto *last = ctx->chain.back(); + if (last && last->used_tokens < last->meta.max_tokens) { + target = last; + pos = last->used_tokens; + } + } + if (!target) { + const size_t max_tokens = block_tokens_hint > 0 ? block_tokens_hint : 64; + LlaisysQwen2KVBlockMeta meta{}; + meta.dtype = _config.dtype; + meta.nlayer = _config.nlayer; + meta.nh = _config.nh; + meta.nkvh = _config.nkvh; + meta.dh = _config.dh; + meta.max_tokens = max_tokens; + auto *blk = llaisysQwen2KVBlockCreate(&meta, _device, device_id); + if (!blk) return false; + if (llaisysQwen2KVContextAttachBlock(ctx, blk) != 0) { + llaisysQwen2KVBlockRelease(blk); + return false; + } + llaisysQwen2KVBlockRelease(blk); + if (ctx->chain.empty() || !ctx->chain.back()) return false; + target = ctx->chain.back(); + pos = target->used_tokens; + } + append_blocks[i] = target; + append_pos[i] = pos; + } + + size_t idx_shape[1] = {nseq}; + size_t hidden_shape[2] = {nseq, _config.hs}; + llaisysTensor_t idx = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id); + llaisysTensor_t pos_ids = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id); + llaisysTensor_t hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!idx || !pos_ids || !hidden) { + destroy_if_not_null(idx); + destroy_if_not_null(pos_ids); + destroy_if_not_null(hidden); + return false; + } + tensorLoad(idx, token_ids); + std::vector pos_buf(nseq, 0); + for (size_t i = 0; i < nseq; ++i) pos_buf[i] = static_cast(past_lens[i]); + tensorLoad(pos_ids, pos_buf.data()); + ::llaisysEmbedding(hidden, idx, _weights->in_embed); + + const float scale = 1.0f / std::sqrt(static_cast(_config.dh)); + bool ok = true; + for (size_t layer = 0; layer < _config.nlayer && ok; ++layer) { + llaisysTensor_t norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + size_t q2d_shape[2] = {nseq, _config.nh * _config.dh}; + size_t kv2d_shape[2] = {nseq, _config.nkvh * _config.dh}; + llaisysTensor_t q2d = tensorCreate(q2d_shape, 2, _config.dtype, _device, device_id); + llaisysTensor_t k2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id); + llaisysTensor_t v2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id); + if (!norm || !q2d || !k2d || !v2d) { + destroy_if_not_null(norm); + destroy_if_not_null(q2d); + destroy_if_not_null(k2d); + destroy_if_not_null(v2d); + ok = false; + break; + } + ::llaisysRmsNorm(norm, hidden, _weights->attn_norm_w[layer], _config.epsilon); + llaisysTensor_t q_bias = (_weights->attn_q_b && _weights->attn_q_b[layer]) ? _weights->attn_q_b[layer] : nullptr; + llaisysTensor_t k_bias = (_weights->attn_k_b && _weights->attn_k_b[layer]) ? _weights->attn_k_b[layer] : nullptr; + llaisysTensor_t v_bias = (_weights->attn_v_b && _weights->attn_v_b[layer]) ? _weights->attn_v_b[layer] : nullptr; + ::llaisysLinear(q2d, norm, _weights->attn_q_w[layer], q_bias); + ::llaisysLinear(k2d, norm, _weights->attn_k_w[layer], k_bias); + ::llaisysLinear(v2d, norm, _weights->attn_v_w[layer], v_bias); + + size_t q3d_shape[3] = {nseq, _config.nh, _config.dh}; + size_t k3d_shape[3] = {nseq, _config.nkvh, _config.dh}; + llaisysTensor_t q3d = tensorView(q2d, q3d_shape, 3); + llaisysTensor_t k3d = tensorView(k2d, k3d_shape, 3); + llaisysTensor_t v3d = tensorView(v2d, k3d_shape, 3); + llaisysTensor_t q_rope = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id); + llaisysTensor_t k_rope = tensorCreate(k3d_shape, 3, _config.dtype, _device, device_id); + if (!q3d || !k3d || !v3d || !q_rope || !k_rope) { + destroy_if_not_null(norm); + destroy_if_not_null(q2d); + destroy_if_not_null(k2d); + destroy_if_not_null(v2d); + destroy_if_not_null(q3d); + destroy_if_not_null(k3d); + destroy_if_not_null(v3d); + destroy_if_not_null(q_rope); + destroy_if_not_null(k_rope); + ok = false; + break; + } + ::llaisysROPE(q_rope, q3d, pos_ids, _config.theta); + ::llaisysROPE(k_rope, k3d, pos_ids, _config.theta); + + size_t kv_all_shape[3] = {kv_total, _config.nkvh, _config.dh}; + llaisysTensor_t k_all = tensorCreate(kv_all_shape, 3, _config.dtype, _device, device_id); + llaisysTensor_t v_all = tensorCreate(kv_all_shape, 3, _config.dtype, _device, device_id); + if (!k_all || !v_all) { + destroy_if_not_null(norm); + destroy_if_not_null(q2d); + destroy_if_not_null(k2d); + destroy_if_not_null(v2d); + destroy_if_not_null(q3d); + destroy_if_not_null(k3d); + destroy_if_not_null(v3d); + destroy_if_not_null(q_rope); + destroy_if_not_null(k_rope); + destroy_if_not_null(k_all); + destroy_if_not_null(v_all); + ok = false; + break; + } + for (size_t i = 0; i < nseq && ok; ++i) { + auto *ctx = contexts[i]; + const size_t kv_begin = static_cast(kv_offsets[i]); + const size_t past = past_lens[i]; + size_t copied = 0; + for (auto *blk : ctx->chain) { + if (!blk) { + ok = false; + break; + } + const size_t used = blk->used_tokens; + if (used == 0) continue; + llaisysTensor_t src_k = tensorSlice(blk->k_layers[layer], 0, 0, used); + llaisysTensor_t src_v = tensorSlice(blk->v_layers[layer], 0, 0, used); + llaisysTensor_t dst_k = tensorSlice(k_all, 0, kv_begin + copied, kv_begin + copied + used); + llaisysTensor_t dst_v = tensorSlice(v_all, 0, kv_begin + copied, kv_begin + copied + used); + if (!src_k || !src_v || !dst_k || !dst_v) { + destroy_if_not_null(src_k); + destroy_if_not_null(src_v); + destroy_if_not_null(dst_k); + destroy_if_not_null(dst_v); + ok = false; + break; + } + ::llaisysRearrange(dst_k, src_k); + ::llaisysRearrange(dst_v, src_v); + tensorDestroy(src_k); + tensorDestroy(src_v); + tensorDestroy(dst_k); + tensorDestroy(dst_v); + copied += used; + } + if (!ok || copied != past) { + ok = false; + break; + } + + const size_t kv_new_pos = kv_begin + past; + llaisysTensor_t src_new_k = tensorSlice(k_rope, 0, i, i + 1); + llaisysTensor_t src_new_v = tensorSlice(v3d, 0, i, i + 1); + llaisysTensor_t dst_new_k = tensorSlice(k_all, 0, kv_new_pos, kv_new_pos + 1); + llaisysTensor_t dst_new_v = tensorSlice(v_all, 0, kv_new_pos, kv_new_pos + 1); + llaisysTensor_t dst_ctx_k = tensorSlice(append_blocks[i]->k_layers[layer], 0, append_pos[i], append_pos[i] + 1); + llaisysTensor_t dst_ctx_v = tensorSlice(append_blocks[i]->v_layers[layer], 0, append_pos[i], append_pos[i] + 1); + if (!src_new_k || !src_new_v || !dst_new_k || !dst_new_v || !dst_ctx_k || !dst_ctx_v) { + destroy_if_not_null(src_new_k); + destroy_if_not_null(src_new_v); + destroy_if_not_null(dst_new_k); + destroy_if_not_null(dst_new_v); + destroy_if_not_null(dst_ctx_k); + destroy_if_not_null(dst_ctx_v); + ok = false; + break; + } + ::llaisysRearrange(dst_new_k, src_new_k); + ::llaisysRearrange(dst_new_v, src_new_v); + ::llaisysRearrange(dst_ctx_k, src_new_k); + ::llaisysRearrange(dst_ctx_v, src_new_v); + tensorDestroy(src_new_k); + tensorDestroy(src_new_v); + tensorDestroy(dst_new_k); + tensorDestroy(dst_new_v); + tensorDestroy(dst_ctx_k); + tensorDestroy(dst_ctx_v); + } + + llaisysTensor_t attn_out3d = nullptr; + llaisysTensor_t attn_out2d = nullptr; + llaisysTensor_t proj_out = nullptr; + llaisysTensor_t attn_hidden = nullptr; + llaisysTensor_t mlp_norm = nullptr; + llaisysTensor_t gate = nullptr; + llaisysTensor_t up = nullptr; + llaisysTensor_t swiglu = nullptr; + llaisysTensor_t mlp_out = nullptr; + llaisysTensor_t mlp_hidden = nullptr; + + if (ok) { + attn_out3d = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id); + if (!attn_out3d) ok = false; + } + if (ok) { + ::llaisysSelfAttentionSegmented( + attn_out3d, q_rope, k_all, v_all, scale, q_offsets.data(), kv_offsets.data(), nseq); + attn_out2d = tensorView(attn_out3d, q2d_shape, 2); + proj_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); attn_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!attn_out2d || !proj_out || !attn_hidden) ok = false; + } + if (ok) { + ::llaisysLinear(proj_out, attn_out2d, _weights->attn_o_w[layer], nullptr); + ::llaisysAdd(attn_hidden, hidden, proj_out); + } + + if (ok) { + mlp_norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + size_t mlp_shape[2] = {nseq, _config.di}; + gate = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + up = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + swiglu = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id); + mlp_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + mlp_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id); + if (!mlp_norm || !gate || !up || !swiglu || !mlp_out || !mlp_hidden) { + ok = false; + } + } + if (ok) { + ::llaisysRmsNorm(mlp_norm, attn_hidden, _weights->mlp_norm_w[layer], _config.epsilon); + ::llaisysLinear(gate, mlp_norm, _weights->mlp_gate_w[layer], nullptr); + ::llaisysLinear(up, mlp_norm, _weights->mlp_up_w[layer], nullptr); + ::llaisysSwiGLU(swiglu, gate, up); + ::llaisysLinear(mlp_out, swiglu, _weights->mlp_down_w[layer], nullptr); + ::llaisysAdd(mlp_hidden, attn_hidden, mlp_out); + } + + if (ok) { + tensorDestroy(hidden); + hidden = mlp_hidden; + mlp_hidden = nullptr; + } + + destroy_if_not_null(norm); + destroy_if_not_null(q2d); + destroy_if_not_null(k2d); + destroy_if_not_null(v2d); + destroy_if_not_null(q3d); + destroy_if_not_null(k3d); + destroy_if_not_null(v3d); + destroy_if_not_null(q_rope); + destroy_if_not_null(k_rope); + destroy_if_not_null(k_all); + destroy_if_not_null(v_all); + destroy_if_not_null(attn_out3d); + destroy_if_not_null(attn_out2d); + destroy_if_not_null(proj_out); + destroy_if_not_null(attn_hidden); + destroy_if_not_null(mlp_norm); + destroy_if_not_null(gate); + destroy_if_not_null(up); + destroy_if_not_null(swiglu); + destroy_if_not_null(mlp_out); + destroy_if_not_null(mlp_hidden); + } + + if (ok) { + for (size_t i = 0; i < nseq; ++i) { + if (append_blocks[i] && append_blocks[i]->used_tokens < append_pos[i] + 1) { + append_blocks[i]->used_tokens = append_pos[i] + 1; + } + } + for (size_t i = 0; i < nseq && ok; ++i) { + llaisysTensor_t last_hidden = tensorSlice(hidden, 0, i, i + 1); + llaisysTensor_t row_logits = tensorSlice(out_last_logits, 0, i, i + 1); + size_t last_shape[2] = {1, _config.hs}; + llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, device_id); + if (!last_hidden || !row_logits || !final_norm) { + ok = false; + } else { + ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon); + ::llaisysLinear(row_logits, final_norm, _weights->out_embed, nullptr); + } + destroy_if_not_null(last_hidden); + destroy_if_not_null(row_logits); + destroy_if_not_null(final_norm); + } + } + + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return ok; +} + +bool Decoder::decodeStep(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits) { + if (!out_last_logits) return false; + if (!ensure_data(out_last_logits, "head.logits.out")) return false; + + size_t past_len = 0; + size_t cur_len = 0; + llaisysTensor_t idx = nullptr; + llaisysTensor_t pos_ids = nullptr; + llaisysTensor_t hidden = nullptr; + if (!runHidden(token_ids, ntoken, true, nullptr, 0, past_len, cur_len, idx, pos_ids, hidden)) return false; + + if (!_weights || !_weights->out_norm_w || !_weights->out_embed) { + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + + trace("head.slice"); + llaisysTensor_t last_hidden = tensorSlice(hidden, 0, cur_len - 1, cur_len); + if (!require_tensor(last_hidden, "head.last_hidden")) { + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + + size_t last_shape[2] = {1, _config.hs}; + trace("head.norm"); + llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, _device_ids.empty() ? 0 : _device_ids[0]); + if (!require_tensor(final_norm, "head.norm")) { + tensorDestroy(last_hidden); + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return false; + } + ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon); + + trace("head.logits"); + ::llaisysLinear(out_last_logits, final_norm, _weights->out_embed, nullptr); + + tensorDestroy(last_hidden); + tensorDestroy(final_norm); + tensorDestroy(idx); + tensorDestroy(pos_ids); + tensorDestroy(hidden); + return true; +} + +} // namespace llaisys::models::transformer diff --git a/src/models/transformer/decoder/decoder.hpp b/src/models/transformer/decoder/decoder.hpp new file mode 100644 index 000000000..5784a849e --- /dev/null +++ b/src/models/transformer/decoder/decoder.hpp @@ -0,0 +1,95 @@ +#pragma once + +#include "llaisys/models/qwen2.h" +#include "llaisys/comm.h" +#include "llaisys/tensor.h" + +#include +#include +#include + +namespace llaisys::models::transformer { + +struct DecoderConfig { + llaisysDataType_t dtype{}; + size_t nlayer{}; + size_t hs{}; + size_t nh{}; + size_t nkvh{}; + size_t dh{}; + size_t di{}; + size_t maxseq{}; + size_t voc{}; + float epsilon{}; + float theta{}; +}; + +class Decoder { +public: + Decoder(const DecoderConfig &config, + const LlaisysQwen2Weights *weights, + llaisysDeviceType_t device, + const std::vector &device_ids); + ~Decoder(); + + // Prefill with a full sequence, returns last-step logits. + bool prefill(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits); + // Prefill packed independent sequences, outputs one logits row per sequence. + bool prefillPacked(const int64_t *token_ids, + size_t ntoken, + const int64_t *token_offsets, + size_t nseq, + llaisysTensor_t out_last_logits); + + // Decode with only new tokens (append-only), returns last-step logits. + bool decodeStep(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits); + // Decode one token per sequence in a packed batch with per-sequence KV contexts. + bool decodePacked(const int64_t *token_ids, + size_t nseq, + const std::vector &contexts, + llaisysTensor_t out_last_logits, + size_t block_tokens_hint); + + void resetKVCache(); + + void setKVCacheEnabled(bool enabled); + void bindExternalKVContext(void *ctx, size_t past_len_tokens); + void clearExternalKVContext(); + bool hasExternalKVContext() const; + int exportKVContext(void *ctx, size_t block_tokens); + + void setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size); + +private: + bool recoverExternalCache(); + bool runHidden(const int64_t *token_ids, + size_t ntoken, + bool append_only, + const int64_t *segment_offsets, + size_t nseg, + size_t &past_len, + size_t &cur_len, + llaisysTensor_t &idx, + llaisysTensor_t &pos_ids, + llaisysTensor_t &hidden); + void ensureCache(); + void releaseCache(); + + DecoderConfig _config{}; + const LlaisysQwen2Weights *_weights{nullptr}; + llaisysDeviceType_t _device{}; + std::vector _device_ids; + std::vector _k_cache; + std::vector _v_cache; + size_t _past_len{0}; + bool _cache_inited{false}; + bool _kv_cache_enabled{true}; + void *_external_kv_ctx{nullptr}; + size_t _external_past_len{0}; + bool _external_cache_ready{false}; + llaisysComm_t _comm{nullptr}; + llaisysStream_t _comm_stream{nullptr}; + int _tp_size{1}; +}; + +} // namespace llaisys::models::transformer diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp index 47f6a3d49..04d499d7b 100644 --- a/src/ops/add/cpu/add_cpu.cpp +++ b/src/ops/add/cpu/add_cpu.cpp @@ -5,29 +5,29 @@ #include template -void add_(T *c, const T *a, const T *b, size_t numel) { - for (size_t i = 0; i < numel; i++) { - if constexpr (std::is_same_v || std::is_same_v) { - c[i] = llaisys::utils::cast(llaisys::utils::cast(a[i]) + llaisys::utils::cast(b[i])); - } else { - c[i] = a[i] + b[i]; + void add_(T *c, const T *a, const T *b, size_t numel) { + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + c[i] = llaisys::utils::cast(llaisys::utils::cast(a[i]) + llaisys::utils::cast(b[i])); + } else { + c[i] = a[i] + b[i]; + } } } -} namespace llaisys::ops::cpu { -void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { - switch (type) { - case LLAISYS_DTYPE_F32: - return add_(reinterpret_cast(c), reinterpret_cast(a), reinterpret_cast(b), numel); - case LLAISYS_DTYPE_BF16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); - case LLAISYS_DTYPE_F16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); - default: - EXCEPTION_UNSUPPORTED_DATATYPE(type); + void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return add_(reinterpret_cast(c), reinterpret_cast(a), reinterpret_cast(b), numel); + case LLAISYS_DTYPE_BF16: + return add_(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + case LLAISYS_DTYPE_F16: + return add_(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } } -} } // namespace llaisys::ops::cpu diff --git a/src/ops/add/cpu/add_cpu.hpp b/src/ops/add/cpu/add_cpu.hpp index 34d809a11..20f5396ef 100644 --- a/src/ops/add/cpu/add_cpu.hpp +++ b/src/ops/add/cpu/add_cpu.hpp @@ -4,5 +4,5 @@ #include namespace llaisys::ops::cpu { -void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size); + void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size); } \ No newline at end of file diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu new file mode 100644 index 000000000..49b7208cf --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cu @@ -0,0 +1,41 @@ +#include "add_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void add_kernel(T *c, const T *a, const T *b, size_t numel) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + float av = llaisys::device::nvidia::ScalarOps::load(a + idx); + float bv = llaisys::device::nvidia::ScalarOps::load(b + idx); + llaisys::device::nvidia::ScalarOps::store(c + idx, av + bv); + } +} + +template +void launch_add(T *c, const T *a, const T *b, size_t numel) { + const int threads = 256; + const int blocks = static_cast((numel + threads - 1) / threads); + add_kernel<<>>(c, a, b, numel); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_add(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + case LLAISYS_DTYPE_BF16: + return launch_add(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + case LLAISYS_DTYPE_F16: + return launch_add(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/add/nvidia/add_nvidia.hpp b/src/ops/add/nvidia/add_nvidia.hpp new file mode 100644 index 000000000..8424ad596 --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d7..fea297cf7 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -4,9 +4,16 @@ #include "../../utils.hpp" #include "cpu/add_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/add_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/add_nvidia.hpp" +#endif namespace llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { + //确保所有张量都在同一设备上 CHECK_SAME_DEVICE(c, a, b); // Only support contiguous inputs with same shape for now. CHECK_SAME_SHAPE(c->shape(), a->shape(), b->shape()); @@ -25,8 +32,11 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 000000000..ab96b2b2f --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,45 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace { + template + void argmax_impl(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) { + // Work in float for fp16/bf16 comparisons to avoid precision issues. + using value_t = T; + const value_t *v = reinterpret_cast(vals); + int64_t *out_idx = reinterpret_cast(max_idx); + value_t *out_val = reinterpret_cast(max_val); + + float best = llaisys::utils::cast(v[0]); + int64_t best_idx = 0; + for (size_t i = 1; i < numel; ++i) { + float cur = llaisys::utils::cast(v[i]); + if (cur > best) { + best = cur; + best_idx = static_cast(i); + } + } + + *out_idx = best_idx; + *out_val = llaisys::utils::cast(best); + } +} + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_impl(max_idx, max_val, vals, numel); + case LLAISYS_DTYPE_BF16: + return argmax_impl(max_idx, max_val, vals, numel); + case LLAISYS_DTYPE_F16: + return argmax_impl(max_idx, max_val, vals, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 000000000..26ae3ef03 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu new file mode 100644 index 000000000..8e6b0abf7 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cu @@ -0,0 +1,42 @@ +#include "argmax_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void argmax_kernel(int64_t *out_idx, T *out_val, const T *vals, size_t numel) { + float best = llaisys::device::nvidia::ScalarOps::load(vals); + int64_t best_idx = 0; + for (size_t i = 1; i < numel; ++i) { + float v = llaisys::device::nvidia::ScalarOps::load(vals + i); + if (v > best) { + best = v; + best_idx = static_cast(i); + } + } + *out_idx = best_idx; + llaisys::device::nvidia::ScalarOps::store(out_val, best); +} + +template +void launch_argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) { + argmax_kernel<<<1, 1>>>(reinterpret_cast(max_idx), reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_argmax(max_idx, max_val, vals, numel); + case LLAISYS_DTYPE_BF16: + return launch_argmax(max_idx, max_val, vals, numel); + case LLAISYS_DTYPE_F16: + return launch_argmax(max_idx, max_val, vals, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/argmax/nvidia/argmax_nvidia.hpp b/src/ops/argmax/nvidia/argmax_nvidia.hpp new file mode 100644 index 000000000..e51bc30d5 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d426..f565d0a83 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,46 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/argmax_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/argmax_nvidia.hpp" +#endif + + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype()); + ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax: max_idx must be int64."); + // 当前实现按扁平化处理多维输入,相当于对全部元素取全局最大 + ASSERT(vals->numel() > 0, "Argmax: input must be non-empty."); + ASSERT(max_idx->numel() == 1 && max_val->numel() == 1, "Argmax: outputs must have a single element."); + ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(), + "Argmax: all tensors must be contiguous."); + + if (vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 000000000..6839372d3 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,33 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, + size_t index_numel, size_t embd_dim, size_t weight_rows) { + size_t elem_size = 0; + switch (type) { + case LLAISYS_DTYPE_F32: + case LLAISYS_DTYPE_F16: + case LLAISYS_DTYPE_BF16: + elem_size = llaisys::utils::dsize(type); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + const int64_t *idx_ptr = reinterpret_cast(index); + size_t row_bytes = embd_dim * elem_size; + + for (size_t i = 0; i < index_numel; ++i) { + int64_t idx = idx_ptr[i]; + ASSERT(idx >= 0 && static_cast(idx) < weight_rows, "Embedding: index out of range."); + const std::byte *src = weight + static_cast(idx) * row_bytes; + std::byte *dst = out + i * row_bytes; + std::memcpy(dst, src, row_bytes); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 000000000..1b1626278 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, + size_t index_numel, size_t embd_dim, size_t weight_rows); +} diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 000000000..7595c5cc8 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,51 @@ +#include "embedding_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, size_t index_numel, size_t dim, + size_t vocab) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = index_numel * dim; + if (idx >= total) { + return; + } + size_t row = idx / dim; + size_t col = idx % dim; + int64_t token = index[row]; + if (token < 0 || static_cast(token) >= vocab) { + return; + } + size_t w_idx = static_cast(token) * dim + col; + float v = llaisys::device::nvidia::ScalarOps::load(weight + w_idx); + llaisys::device::nvidia::ScalarOps::store(out + idx, v); +} + +template +void launch_embedding(std::byte *out, const std::byte *index, const std::byte *weight, size_t index_numel, + size_t dim, size_t vocab) { + size_t total = index_numel * dim; + const int threads = 256; + const int blocks = static_cast((total + threads - 1) / threads); + embedding_kernel<<>>(reinterpret_cast(out), reinterpret_cast(index), + reinterpret_cast(weight), index_numel, dim, vocab); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, + size_t index_numel, size_t embd_dim, size_t weight_rows) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_embedding(out, index, weight, index_numel, embd_dim, weight_rows); + case LLAISYS_DTYPE_BF16: + return launch_embedding(out, index, weight, index_numel, embd_dim, weight_rows); + case LLAISYS_DTYPE_F16: + return launch_embedding(out, index, weight, index_numel, embd_dim, weight_rows); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/embedding/nvidia/embedding_nvidia.hpp b/src/ops/embedding/nvidia/embedding_nvidia.hpp new file mode 100644 index 000000000..225d23fec --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, + size_t index_numel, size_t embd_dim, size_t weight_rows); +} diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d06..23251c207 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,52 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/embedding_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/embedding_nvidia.hpp" +#endif + namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + CHECK_SAME_DTYPE(out->dtype(), weight->dtype()); + ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index must be int64."); + ASSERT(index->ndim() == 1, "Embedding: index must be 1D."); + ASSERT(weight->ndim() == 2, "Embedding: weight must be 2D."); + ASSERT(out->ndim() == 2, "Embedding: out must be 2D."); + + const auto &w_shape = weight->shape(); + size_t vocab = w_shape[0]; + size_t dim = w_shape[1]; + size_t index_numel = index->numel(); + ASSERT(out->shape()[0] == index_numel && out->shape()[1] == dim, "Embedding: output shape mismatch."); + + ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(), "Embedding: tensors must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_numel, dim, vocab); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_numel, dim, vocab); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_numel, dim, vocab); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_numel, dim, vocab); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 000000000..8a10398e0 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,48 @@ +#include "linear_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +namespace { + template + void linear_impl(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + size_t m, size_t n, size_t k) { + const T *in_ptr = reinterpret_cast(in); + const T *w_ptr = reinterpret_cast(weight); + const T *bias_ptr = bias ? reinterpret_cast(bias) : nullptr; + T *out_ptr = reinterpret_cast(out); + + for (size_t i = 0; i < m; ++i) { + for (size_t o = 0; o < n; ++o) { + //计算第i行第o列 + float acc = bias_ptr ? llaisys::utils::cast(bias_ptr[o]) : 0.f; + //weight的第o行 + const T *w_row = w_ptr + o * k; // weight shape [n, k] + //in的第i行 + const T *in_row = in_ptr + i * k; + //点积计算 + for (size_t j = 0; j < k; ++j) { + acc += llaisys::utils::cast(in_row[j]) * llaisys::utils::cast(w_row[j]); + } + out_ptr[i * n + o] = llaisys::utils::cast(acc); + } + } + } +} + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t m, size_t n, size_t k) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_impl(out, in, weight, bias, m, n, k); + case LLAISYS_DTYPE_BF16: + return linear_impl(out, in, weight, bias, m, n, k); + case LLAISYS_DTYPE_F16: + return linear_impl(out, in, weight, bias, m, n, k); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 000000000..32a51c2bc --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t m, size_t n, size_t k); +} diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu new file mode 100644 index 000000000..c5012f4f3 --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.cu @@ -0,0 +1,53 @@ +#include "linear_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void linear_kernel(T *out, const T *in, const T *weight, const T *bias, size_t m, size_t n, size_t k) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = m * n; + if (idx >= total) { + return; + } + size_t row = idx / n; + size_t col = idx % n; + float acc = bias ? llaisys::device::nvidia::ScalarOps::load(bias + col) : 0.f; + const T *w_row = weight + col * k; + const T *in_row = in + row * k; + for (size_t j = 0; j < k; ++j) { + float a = llaisys::device::nvidia::ScalarOps::load(in_row + j); + float b = llaisys::device::nvidia::ScalarOps::load(w_row + j); + acc += a * b; + } + llaisys::device::nvidia::ScalarOps::store(out + idx, acc); +} + +template +void launch_linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, size_t m, + size_t n, size_t k) { + const int threads = 256; + const size_t total = m * n; + const int blocks = static_cast((total + threads - 1) / threads); + linear_kernel<<>>(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), + bias ? reinterpret_cast(bias) : nullptr, m, n, k); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t m, size_t n, size_t k) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_linear(out, in, weight, bias, m, n, k); + case LLAISYS_DTYPE_BF16: + return launch_linear(out, in, weight, bias, m, n, k); + case LLAISYS_DTYPE_F16: + return launch_linear(out, in, weight, bias, m, n, k); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/linear/nvidia/linear_nvidia.hpp b/src/ops/linear/nvidia/linear_nvidia.hpp new file mode 100644 index 000000000..31f1d8ebb --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t m, size_t n, size_t k); +} diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f8655..25d79e323 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,66 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/linear_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/linear_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/linear_nvidia.hpp" +#endif + namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + if (bias) { + CHECK_SAME_DEVICE(out, bias); + CHECK_SAME_DTYPE(out->dtype(), bias->dtype()); + } + CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype()); + + ASSERT(out->ndim() == 2, "Linear: out must be 2D."); + ASSERT(in->ndim() == 2, "Linear: input must be 2D."); + ASSERT(weight->ndim() == 2, "Linear: weight must be 2D."); + + size_t m = in->shape()[0]; + size_t k = in->shape()[1]; + size_t n = weight->shape()[0]; // weight shape [out_features, in_features] + + ASSERT(weight->shape()[1] == k, "Linear: weight in_features mismatch."); + ASSERT(out->shape()[0] == m && out->shape()[1] == n, "Linear: output shape mismatch."); + if (bias) { + ASSERT(bias->ndim() == 1 && bias->shape()[0] == n, "Linear: bias must be 1D with length out_features."); + } + + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous() + && (!bias || bias->isContiguous()), + "Linear: all tensors must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, + out->dtype(), m, n, k); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, + out->dtype(), m, n, k); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), + m, n, k); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), + m, n, k); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp new file mode 100644 index 000000000..0ccaf634f --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp @@ -0,0 +1,47 @@ +#include "rearrange_cpu.hpp" + +#include + +namespace { +void rearrange_recursive(std::byte *out, + const std::byte *in, + const std::vector &shape, + const std::vector &out_strides, + const std::vector &in_strides, + size_t elem_size, + size_t dim, + ptrdiff_t out_off, + ptrdiff_t in_off) { + if (dim == shape.size()) { + std::memcpy(out + out_off * elem_size, in + in_off * elem_size, elem_size); + return; + } + + const size_t len = shape[dim]; + const ptrdiff_t os = out_strides[dim]; + const ptrdiff_t is = in_strides[dim]; + + for (size_t i = 0; i < len; ++i) { + rearrange_recursive(out, + in, + shape, + out_strides, + in_strides, + elem_size, + dim + 1, + out_off + static_cast(i) * os, + in_off + static_cast(i) * is); + } +} +} // namespace + +namespace llaisys::ops::cpu { +void rearrange(std::byte *out, + const std::byte *in, + const std::vector &shape, + const std::vector &out_strides, + const std::vector &in_strides, + size_t elem_size) { + rearrange_recursive(out, in, shape, out_strides, in_strides, elem_size, 0, 0, 0); +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp new file mode 100644 index 000000000..c78be3e6b --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::cpu { +void rearrange(std::byte *out, + const std::byte *in, + const std::vector &shape, + const std::vector &out_strides, + const std::vector &in_strides, + size_t elem_size); +} diff --git a/src/ops/rearrange/nvidia/rearrange_nvidia.cu b/src/ops/rearrange/nvidia/rearrange_nvidia.cu new file mode 100644 index 000000000..bcc1d9c4b --- /dev/null +++ b/src/ops/rearrange/nvidia/rearrange_nvidia.cu @@ -0,0 +1,40 @@ +#include "rearrange_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +namespace llaisys::ops::nvidia { +namespace { +__global__ void rearrange_kernel(std::byte *out, const std::byte *in, const size_t *shape, + const ptrdiff_t *out_strides, const ptrdiff_t *in_strides, size_t ndim, + size_t elem_size, size_t numel) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= numel) { + return; + } + size_t tmp = idx; + ptrdiff_t out_off = 0; + ptrdiff_t in_off = 0; + for (size_t d = 0; d < ndim; ++d) { + size_t dim = ndim - 1 - d; + size_t size = shape[dim]; + size_t coord = tmp % size; + tmp /= size; + out_off += static_cast(coord) * out_strides[dim]; + in_off += static_cast(coord) * in_strides[dim]; + } + std::byte *dst = out + out_off * static_cast(elem_size); + const std::byte *src = in + in_off * static_cast(elem_size); + for (size_t i = 0; i < elem_size; ++i) { + dst[i] = src[i]; + } +} +} // namespace + +void rearrange(std::byte *out, const std::byte *in, const size_t *shape, const ptrdiff_t *out_strides, + const ptrdiff_t *in_strides, size_t ndim, size_t elem_size, size_t numel) { + const int threads = 256; + const int blocks = static_cast((numel + threads - 1) / threads); + rearrange_kernel<<>>(out, in, shape, out_strides, in_strides, ndim, elem_size, numel); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rearrange/nvidia/rearrange_nvidia.hpp b/src/ops/rearrange/nvidia/rearrange_nvidia.hpp new file mode 100644 index 000000000..9053f4611 --- /dev/null +++ b/src/ops/rearrange/nvidia/rearrange_nvidia.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void rearrange(std::byte *out, const std::byte *in, const size_t *shape, const ptrdiff_t *out_strides, + const ptrdiff_t *in_strides, size_t ndim, size_t elem_size, size_t numel); +} diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae59..013d82308 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -1,7 +1,86 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../device/runtime_api.hpp" + +#include "cpu/rearrange_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rearrange_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/rearrange_nvidia.hpp" +#endif + namespace llaisys::ops { void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + ASSERT(out->shape() == in->shape(), "Rearrange: shapes must match."); + + const auto elem_size = out->elementSize(); + const auto &shape = out->shape(); + const auto &out_strides = out->strides(); + const auto &in_strides = in->strides(); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rearrange(out->data(), in->data(), shape, out_strides, in_strides, elem_size); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rearrange(out->data(), in->data(), shape, out_strides, in_strides, elem_size); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + { + const auto runtime = llaisys::device::getRuntimeAPI(out->deviceType()); + const size_t ndim = shape.size(); + const size_t shape_bytes = ndim * sizeof(size_t); + const size_t stride_bytes = ndim * sizeof(ptrdiff_t); + void *shape_dev = runtime->malloc_device(shape_bytes); + void *out_strides_dev = runtime->malloc_device(stride_bytes); + void *in_strides_dev = runtime->malloc_device(stride_bytes); + runtime->memcpy_sync(shape_dev, shape.data(), shape_bytes, LLAISYS_MEMCPY_H2D); + runtime->memcpy_sync(out_strides_dev, out_strides.data(), stride_bytes, LLAISYS_MEMCPY_H2D); + runtime->memcpy_sync(in_strides_dev, in_strides.data(), stride_bytes, LLAISYS_MEMCPY_H2D); + nvidia::rearrange(out->data(), in->data(), + reinterpret_cast(shape_dev), + reinterpret_cast(out_strides_dev), + reinterpret_cast(in_strides_dev), + ndim, elem_size, out->numel()); + runtime->free_device(shape_dev); + runtime->free_device(out_strides_dev); + runtime->free_device(in_strides_dev); + return; + } +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + { + const auto runtime = llaisys::device::getRuntimeAPI(out->deviceType()); + const size_t ndim = shape.size(); + const size_t shape_bytes = ndim * sizeof(size_t); + const size_t stride_bytes = ndim * sizeof(ptrdiff_t); + void *shape_dev = runtime->malloc_device(shape_bytes); + void *out_strides_dev = runtime->malloc_device(stride_bytes); + void *in_strides_dev = runtime->malloc_device(stride_bytes); + runtime->memcpy_sync(shape_dev, shape.data(), shape_bytes, LLAISYS_MEMCPY_H2D); + runtime->memcpy_sync(out_strides_dev, out_strides.data(), stride_bytes, LLAISYS_MEMCPY_H2D); + runtime->memcpy_sync(in_strides_dev, in_strides.data(), stride_bytes, LLAISYS_MEMCPY_H2D); + nvidia::rearrange(out->data(), in->data(), + reinterpret_cast(shape_dev), + reinterpret_cast(out_strides_dev), + reinterpret_cast(in_strides_dev), + ndim, elem_size, out->numel()); + runtime->free_device(shape_dev); + runtime->free_device(out_strides_dev); + runtime->free_device(in_strides_dev); + return; + } +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 000000000..35e2d96ec --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,50 @@ +#include "rms_norm_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +namespace { + template + void rms_norm_impl(std::byte *out, const std::byte *in, const std::byte *weight, size_t rows, size_t cols, + float eps) { + const T *in_ptr = reinterpret_cast(in); + const T *w_ptr = reinterpret_cast(weight); + T *out_ptr = reinterpret_cast(out); + + for (size_t i = 0; i < rows; ++i) { + const T *row_in = in_ptr + i * cols; + T *row_out = out_ptr + i * cols; + + float sum_sq = 0.f; + for (size_t j = 0; j < cols; ++j) { + float v = llaisys::utils::cast(row_in[j]); + sum_sq += v * v; + } + float mean = sum_sq / static_cast(cols); + float inv_rms = 1.0f / std::sqrt(mean + eps); + + for (size_t j = 0; j < cols; ++j) { + float v = llaisys::utils::cast(row_in[j]); + float w = llaisys::utils::cast(w_ptr[j]); + row_out[j] = llaisys::utils::cast(v * inv_rms * w); + } + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t rows, size_t cols, float eps) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_impl(out, in, weight, rows, cols, eps); + case LLAISYS_DTYPE_BF16: + return rms_norm_impl(out, in, weight, rows, cols, eps); + case LLAISYS_DTYPE_F16: + return rms_norm_impl(out, in, weight, rows, cols, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 000000000..b3cc8d21b --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t rows, size_t cols, float eps); +} diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu new file mode 100644 index 000000000..ddc40e923 --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -0,0 +1,54 @@ +#include "rms_norm_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void rms_norm_kernel(T *out, const T *in, const T *weight, size_t rows, size_t cols, float eps) { + size_t row = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + const T *row_in = in + row * cols; + T *row_out = out + row * cols; + float sum_sq = 0.f; + for (size_t j = 0; j < cols; ++j) { + float v = llaisys::device::nvidia::ScalarOps::load(row_in + j); + sum_sq += v * v; + } + float mean = sum_sq / static_cast(cols); + float inv_rms = rsqrtf(mean + eps); + for (size_t j = 0; j < cols; ++j) { + float v = llaisys::device::nvidia::ScalarOps::load(row_in + j); + float w = llaisys::device::nvidia::ScalarOps::load(weight + j); + llaisys::device::nvidia::ScalarOps::store(row_out + j, v * inv_rms * w); + } +} + +template +void launch_rms(std::byte *out, const std::byte *in, const std::byte *weight, size_t rows, size_t cols, float eps) { + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + rms_norm_kernel<<>>(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), rows, cols, eps); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t rows, size_t cols, float eps) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_rms(out, in, weight, rows, cols, eps); + case LLAISYS_DTYPE_BF16: + return launch_rms(out, in, weight, rows, cols, eps); + case LLAISYS_DTYPE_F16: + return launch_rms(out, in, weight, rows, cols, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.hpp b/src/ops/rms_norm/nvidia/rms_norm_nvidia.hpp new file mode 100644 index 000000000..25a0d28c4 --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t rows, size_t cols, float eps); +} diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9d..50609db07 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,52 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rms_norm_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rms_norm_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/rms_norm_nvidia.hpp" +#endif + namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype()); + + ASSERT(out->ndim() == 2, "RMSNorm: out must be 2D."); + ASSERT(in->ndim() == 2, "RMSNorm: input must be 2D."); + ASSERT(weight->ndim() == 1, "RMSNorm: weight must be 1D."); + + size_t rows = in->shape()[0]; + size_t cols = in->shape()[1]; + ASSERT(out->shape()[0] == rows && out->shape()[1] == cols, "RMSNorm: output shape mismatch."); + ASSERT(weight->shape()[0] == cols, "RMSNorm: weight length must match input last dim."); + + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), + "RMSNorm: tensors must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), rows, cols, eps); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), rows, cols, eps); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), rows, cols, eps); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), rows, cols, eps); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 000000000..02fdcddb1 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,56 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +namespace { + template + void rope_impl(std::byte *out, const std::byte *in, const std::byte *pos_ids, + size_t seqlen, size_t nhead, size_t dim, float theta) { + const T *in_ptr = reinterpret_cast(in); + const int64_t *pos_ptr = reinterpret_cast(pos_ids); + T *out_ptr = reinterpret_cast(out); + + size_t head_stride = dim; + size_t seq_stride = nhead * dim; + size_t half = dim / 2; + + for (size_t s = 0; s < seqlen; ++s) { + float p = static_cast(pos_ptr[s]); + for (size_t h = 0; h < nhead; ++h) { + const T *x = in_ptr + s * seq_stride + h * head_stride; + T *y = out_ptr + s * seq_stride + h * head_stride; + + for (size_t j = 0; j < half; ++j) { + float exponent = static_cast(2.0f * static_cast(j) / static_cast(dim)); + float angle = p / std::pow(theta, exponent); + float sinv = std::sin(angle); + float cosv = std::cos(angle); + + float a = llaisys::utils::cast(x[j]); + float b = llaisys::utils::cast(x[half + j]); + + y[j] = llaisys::utils::cast(a * cosv - b * sinv); + y[half + j] = llaisys::utils::cast(b * cosv + a * sinv); + } + } + } + } +} + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, + size_t seqlen, size_t nhead, size_t dim, float theta) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_impl(out, in, pos_ids, seqlen, nhead, dim, theta); + case LLAISYS_DTYPE_BF16: + return rope_impl(out, in, pos_ids, seqlen, nhead, dim, theta); + case LLAISYS_DTYPE_F16: + return rope_impl(out, in, pos_ids, seqlen, nhead, dim, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 000000000..352418a14 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, + size_t seqlen, size_t nhead, size_t dim, float theta); +} diff --git a/src/ops/rope/nvidia/rope_nvidia.cu b/src/ops/rope/nvidia/rope_nvidia.cu new file mode 100644 index 000000000..9aeeb9321 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cu @@ -0,0 +1,61 @@ +#include "rope_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void rope_kernel(T *out, const T *in, const int64_t *pos_ids, size_t seqlen, size_t nhead, size_t dim, + float theta) { + size_t half = dim / 2; + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = seqlen * nhead * half; + if (idx >= total) { + return; + } + size_t j = idx % half; + size_t tmp = idx / half; + size_t h = tmp % nhead; + size_t s = tmp / nhead; + float p = static_cast(pos_ids[s]); + float exponent = 2.0f * static_cast(j) / static_cast(dim); + float angle = p / powf(theta, exponent); + float sinv = sinf(angle); + float cosv = cosf(angle); + + size_t base = (s * nhead + h) * dim; + float a = llaisys::device::nvidia::ScalarOps::load(in + base + j); + float b = llaisys::device::nvidia::ScalarOps::load(in + base + half + j); + llaisys::device::nvidia::ScalarOps::store(out + base + j, a * cosv - b * sinv); + llaisys::device::nvidia::ScalarOps::store(out + base + half + j, b * cosv + a * sinv); +} + +template +void launch_rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, size_t seqlen, size_t nhead, + size_t dim, float theta) { + size_t half = dim / 2; + size_t total = seqlen * nhead * half; + const int threads = 256; + const int blocks = static_cast((total + threads - 1) / threads); + rope_kernel<<>>(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(pos_ids), seqlen, nhead, dim, theta); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seqlen, + size_t nhead, size_t dim, float theta) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_rope(out, in, pos_ids, seqlen, nhead, dim, theta); + case LLAISYS_DTYPE_BF16: + return launch_rope(out, in, pos_ids, seqlen, nhead, dim, theta); + case LLAISYS_DTYPE_F16: + return launch_rope(out, in, pos_ids, seqlen, nhead, dim, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/nvidia/rope_nvidia.hpp b/src/ops/rope/nvidia/rope_nvidia.hpp new file mode 100644 index 000000000..ffa58f412 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seqlen, + size_t nhead, size_t dim, float theta); +} diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..c9dedf8ab 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,57 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rope_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rope_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/rope_nvidia.hpp" +#endif + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + ASSERT(pos_ids->deviceType() == out->deviceType() && pos_ids->deviceId() == out->deviceId(), + "ROPE: pos_ids must be on the same device."); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "ROPE: pos_ids must be int64."); + + ASSERT(out->ndim() == 3 && in->ndim() == 3, "ROPE: out and in must be 3D [seqlen, nhead, dim]."); + ASSERT(pos_ids->ndim() == 1, "ROPE: pos_ids must be 1D [seqlen]."); + + size_t seqlen = in->shape()[0]; + size_t nhead = in->shape()[1]; + size_t dim = in->shape()[2]; + ASSERT(dim % 2 == 0, "ROPE: head dim must be even."); + + ASSERT(out->shape()[0] == seqlen && out->shape()[1] == nhead && out->shape()[2] == dim, + "ROPE: output shape mismatch."); + ASSERT(pos_ids->shape()[0] == seqlen, "ROPE: pos_ids length must equal seqlen."); + + ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), "ROPE: tensors must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, dim, theta); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, dim, theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, dim, theta); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, dim, theta); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 000000000..3fb31b9cb --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,224 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +namespace { + template + void self_attn_impl(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + size_t qlen, size_t kvlen, size_t nhead, size_t nkvh, size_t dim, size_t dv, float scale) { + const T *q_ptr = reinterpret_cast(q); + const T *k_ptr = reinterpret_cast(k); + const T *v_ptr = reinterpret_cast(v); + T *out_ptr = reinterpret_cast(out); + + const size_t q_head_stride = dim; + const size_t k_head_stride = dim; + const size_t v_head_stride = dv; + const size_t q_seq_stride = nhead * dim; + const size_t k_seq_stride = nkvh * dim; + const size_t v_seq_stride = nkvh * dv; + const size_t out_head_stride = dv; + const size_t out_seq_stride = nhead * dv; + + const int head_factor = static_cast(nhead / nkvh); + + std::vector logits(kvlen); + std::vector probs(kvlen); + + for (size_t s = 0; s < qlen; ++s) { + for (size_t h = 0; h < nhead; ++h) { + const T *q_vec = q_ptr + s * q_seq_stride + h * q_head_stride; + int kh = static_cast(h / head_factor); + const T *k_base = k_ptr + kh * k_head_stride; + const T *v_base = v_ptr + kh * v_head_stride; + float max_logit = -std::numeric_limits::infinity(); + + int allow_upto = static_cast(s + kvlen - qlen); + for (size_t t = 0; t < kvlen; ++t) { + float logit; + if (static_cast(t) > allow_upto) { + logit = -1e20f; + } else { + const T *k_vec = k_base + t * k_seq_stride; + float dot = 0.f; + for (size_t j = 0; j < dim; ++j) { + dot += llaisys::utils::cast(q_vec[j]) * llaisys::utils::cast(k_vec[j]); + } + logit = dot * scale; + } + logits[t] = logit; + max_logit = std::max(max_logit, logit); + } + + float sum_exp = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + float e = std::exp(logits[t] - max_logit); + probs[t] = e; + sum_exp += e; + } + float inv_sum = 1.0f / sum_exp; + + T *y = out_ptr + s * out_seq_stride + h * out_head_stride; + for (size_t d = 0; d < dv; ++d) { + float acc = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + const T *v_vec = v_base + t * v_seq_stride; + acc += (probs[t] * inv_sum) * llaisys::utils::cast(v_vec[d]); + } + y[d] = llaisys::utils::cast(acc); + } + } + } + } + + template + void self_attn_segmented_impl(std::byte *out, + const std::byte *q, + const std::byte *k, + const std::byte *v, + size_t qlen, + size_t kvlen, + size_t nhead, + size_t nkvh, + size_t dim, + size_t dv, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg) { + const T *q_ptr = reinterpret_cast(q); + const T *k_ptr = reinterpret_cast(k); + const T *v_ptr = reinterpret_cast(v); + T *out_ptr = reinterpret_cast(out); + + const size_t q_head_stride = dim; + const size_t k_head_stride = dim; + const size_t v_head_stride = dv; + const size_t q_seq_stride = nhead * dim; + const size_t k_seq_stride = nkvh * dim; + const size_t v_seq_stride = nkvh * dv; + const size_t out_head_stride = dv; + const size_t out_seq_stride = nhead * dv; + const int head_factor = static_cast(nhead / nkvh); + + std::vector logits(kvlen); + std::vector probs(kvlen); + + // Build query->segment lookup. + std::vector q2seg(qlen, 0); + for (size_t seg = 0; seg < nseg; ++seg) { + const size_t qb = static_cast(q_offsets[seg]); + const size_t qe = static_cast(q_offsets[seg + 1]); + for (size_t s = qb; s < qe; ++s) q2seg[s] = seg; + } + + for (size_t s = 0; s < qlen; ++s) { + const size_t seg = q2seg[s]; + const size_t q_begin = static_cast(q_offsets[seg]); + const size_t q_end = static_cast(q_offsets[seg + 1]); + const size_t kv_begin = static_cast(kv_offsets[seg]); + const size_t kv_end = static_cast(kv_offsets[seg + 1]); + const size_t local_q = s - q_begin; + const size_t seg_qlen = q_end - q_begin; + const size_t seg_kvlen = kv_end - kv_begin; + const size_t local_allow = local_q + (seg_kvlen - seg_qlen); + const size_t global_allow = kv_begin + local_allow; + + for (size_t h = 0; h < nhead; ++h) { + const T *q_vec = q_ptr + s * q_seq_stride + h * q_head_stride; + int kh = static_cast(h / head_factor); + const T *k_base = k_ptr + kh * k_head_stride; + const T *v_base = v_ptr + kh * v_head_stride; + float max_logit = -std::numeric_limits::infinity(); + + for (size_t t = 0; t < kvlen; ++t) { + float logit; + const bool in_seg = (t >= kv_begin && t < kv_end); + const bool causal_ok = (t <= global_allow); + if (!in_seg || !causal_ok) { + logit = -1e20f; + } else { + const T *k_vec = k_base + t * k_seq_stride; + float dot = 0.f; + for (size_t j = 0; j < dim; ++j) { + dot += llaisys::utils::cast(q_vec[j]) * llaisys::utils::cast(k_vec[j]); + } + logit = dot * scale; + } + logits[t] = logit; + max_logit = std::max(max_logit, logit); + } + + float sum_exp = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + float e = std::exp(logits[t] - max_logit); + probs[t] = e; + sum_exp += e; + } + float inv_sum = 1.0f / sum_exp; + + T *y = out_ptr + s * out_seq_stride + h * out_head_stride; + for (size_t d = 0; d < dv; ++d) { + float acc = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + const T *v_vec = v_base + t * v_seq_stride; + acc += (probs[t] * inv_sum) * llaisys::utils::cast(v_vec[d]); + } + y[d] = llaisys::utils::cast(acc); + } + } + } + } +} + +namespace llaisys::ops::cpu { +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t qlen, size_t kvlen, size_t nhead, size_t nkvh, + size_t dim, size_t dv, float scale) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attn_impl(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + case LLAISYS_DTYPE_BF16: + return self_attn_impl(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + case LLAISYS_DTYPE_F16: + return self_attn_impl(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +void self_attention_segmented(std::byte *out, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t type, + size_t qlen, + size_t kvlen, + size_t nhead, + size_t nkvh, + size_t dim, + size_t dv, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attn_segmented_impl( + out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale, q_offsets, kv_offsets, nseg); + case LLAISYS_DTYPE_BF16: + return self_attn_segmented_impl( + out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale, q_offsets, kv_offsets, nseg); + case LLAISYS_DTYPE_F16: + return self_attn_segmented_impl( + out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale, q_offsets, kv_offsets, nseg); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 000000000..db9719e8a --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,25 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t qlen, size_t kvlen, size_t nhead, size_t nkvh, + size_t dim, size_t dv, float scale); +void self_attention_segmented(std::byte *out, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t type, + size_t qlen, + size_t kvlen, + size_t nhead, + size_t nkvh, + size_t dim, + size_t dv, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg); +} diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cu b/src/ops/self_attention/nvidia/self_attention_nvidia.cu new file mode 100644 index 000000000..637c1122a --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cu @@ -0,0 +1,117 @@ +#include "self_attention_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void self_attention_kernel(T *out, const T *q, const T *k, const T *v, size_t qlen, size_t kvlen, + size_t nhead, size_t nkvh, size_t dim, size_t dv, float scale) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = qlen * nhead; + if (idx >= total) { + return; + } + size_t s = idx / nhead; + size_t h = idx % nhead; + size_t head_factor = nhead / nkvh; + size_t kh = h / head_factor; + + const T *q_vec = q + (s * nhead + h) * dim; + const T *k_base = k + kh * dim; + const T *v_base = v + kh * dv; + + int allow_upto = static_cast(s + kvlen - qlen); + float max_logit = -1e20f; + for (size_t t = 0; t < kvlen; ++t) { + float logit; + if (static_cast(t) > allow_upto) { + logit = -1e20f; + } else { + const T *k_vec = k_base + t * nkvh * dim; + float dot = 0.f; + for (size_t j = 0; j < dim; ++j) { + float qv = llaisys::device::nvidia::ScalarOps::load(q_vec + j); + float kv = llaisys::device::nvidia::ScalarOps::load(k_vec + j); + dot += qv * kv; + } + logit = dot * scale; + } + max_logit = fmaxf(max_logit, logit); + } + + float sum_exp = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + float logit; + if (static_cast(t) > allow_upto) { + logit = -1e20f; + } else { + const T *k_vec = k_base + t * nkvh * dim; + float dot = 0.f; + for (size_t j = 0; j < dim; ++j) { + float qv = llaisys::device::nvidia::ScalarOps::load(q_vec + j); + float kv = llaisys::device::nvidia::ScalarOps::load(k_vec + j); + dot += qv * kv; + } + logit = dot * scale; + } + sum_exp += expf(logit - max_logit); + } + float inv_sum = 1.0f / sum_exp; + + T *y = out + (s * nhead + h) * dv; + for (size_t d = 0; d < dv; ++d) { + float acc = 0.f; + for (size_t t = 0; t < kvlen; ++t) { + float logit; + if (static_cast(t) > allow_upto) { + logit = -1e20f; + } else { + const T *k_vec = k_base + t * nkvh * dim; + float dot = 0.f; + for (size_t j = 0; j < dim; ++j) { + float qv = llaisys::device::nvidia::ScalarOps::load(q_vec + j); + float kv = llaisys::device::nvidia::ScalarOps::load(k_vec + j); + dot += qv * kv; + } + logit = dot * scale; + } + float prob = expf(logit - max_logit) * inv_sum; + const T *v_vec = v_base + t * nkvh * dv; + float vv = llaisys::device::nvidia::ScalarOps::load(v_vec + d); + acc += prob * vv; + } + llaisys::device::nvidia::ScalarOps::store(y + d, acc); + } +} + +template +void launch_self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, size_t qlen, + size_t kvlen, size_t nhead, size_t nkvh, size_t dim, size_t dv, float scale) { + size_t total = qlen * nhead; + const int threads = 64; + const int blocks = static_cast((total + threads - 1) / threads); + self_attention_kernel<<>>(reinterpret_cast(out), reinterpret_cast(q), + reinterpret_cast(k), reinterpret_cast(v), qlen, + kvlen, nhead, nkvh, dim, dv, scale); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t qlen, size_t kvlen, size_t nhead, size_t nkvh, size_t dim, + size_t dv, float scale) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_self_attention(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + case LLAISYS_DTYPE_BF16: + return launch_self_attention(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + case LLAISYS_DTYPE_F16: + return launch_self_attention(out, q, k, v, qlen, kvlen, nhead, nkvh, dim, dv, scale); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.hpp b/src/ops/self_attention/nvidia/self_attention_nvidia.hpp new file mode 100644 index 000000000..abac2419e --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t qlen, size_t kvlen, size_t nhead, size_t nkvh, size_t dim, + size_t dv, float scale); +} diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..c6a756572 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,124 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/self_attention_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/self_attention_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/self_attention_nvidia.hpp" +#endif + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + + ASSERT(attn_val->ndim() == 3 && q->ndim() == 3 && k->ndim() == 3 && v->ndim() == 3, + "SelfAttention: all tensors must be 3D."); + + size_t qlen = q->shape()[0]; + size_t nhead = q->shape()[1]; + size_t dim = q->shape()[2]; + + size_t kvlen = k->shape()[0]; + size_t nkvh = k->shape()[1]; + size_t kdim = k->shape()[2]; + size_t vdim = v->shape()[2]; + + ASSERT(dim == kdim, "SelfAttention: q and k head dim mismatch."); + ASSERT(v->shape()[0] == kvlen && v->shape()[1] == nkvh, "SelfAttention: v shape mismatch with k."); + ASSERT(attn_val->shape()[0] == qlen && attn_val->shape()[1] == nhead && attn_val->shape()[2] == vdim, + "SelfAttention: output shape mismatch."); + ASSERT(nhead % nkvh == 0, "SelfAttention: nhead must be divisible by nkvh."); + + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "SelfAttention: tensors must be contiguous."); + + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), qlen, + kvlen, nhead, nkvh, dim, vdim, scale); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), qlen, + kvlen, nhead, nkvh, dim, vdim, scale); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), qlen, + kvlen, nhead, nkvh, dim, vdim, scale); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), qlen, + kvlen, nhead, nkvh, dim, vdim, scale); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } +} + +void self_attention_segmented(tensor_t attn_val, + tensor_t q, + tensor_t k, + tensor_t v, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg) { + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + ASSERT(nseg > 0, "SelfAttentionSegmented: nseg must be > 0."); + ASSERT(q_offsets && kv_offsets, "SelfAttentionSegmented: offsets must not be null."); + + ASSERT(attn_val->ndim() == 3 && q->ndim() == 3 && k->ndim() == 3 && v->ndim() == 3, + "SelfAttentionSegmented: all tensors must be 3D."); + + size_t qlen = q->shape()[0]; + size_t nhead = q->shape()[1]; + size_t dim = q->shape()[2]; + size_t kvlen = k->shape()[0]; + size_t nkvh = k->shape()[1]; + size_t kdim = k->shape()[2]; + size_t vdim = v->shape()[2]; + + ASSERT(dim == kdim, "SelfAttentionSegmented: q and k head dim mismatch."); + ASSERT(v->shape()[0] == kvlen && v->shape()[1] == nkvh, "SelfAttentionSegmented: v shape mismatch with k."); + ASSERT(attn_val->shape()[0] == qlen && attn_val->shape()[1] == nhead && attn_val->shape()[2] == vdim, + "SelfAttentionSegmented: output shape mismatch."); + ASSERT(nhead % nkvh == 0, "SelfAttentionSegmented: nhead must be divisible by nkvh."); + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "SelfAttentionSegmented: tensors must be contiguous."); + + ASSERT(q_offsets[0] == 0 && kv_offsets[0] == 0, "SelfAttentionSegmented: offsets must start at 0."); + ASSERT(static_cast(q_offsets[nseg]) == qlen, "SelfAttentionSegmented: q_offsets end mismatch."); + ASSERT(static_cast(kv_offsets[nseg]) == kvlen, "SelfAttentionSegmented: kv_offsets end mismatch."); + for (size_t i = 0; i < nseg; ++i) { + ASSERT(q_offsets[i] <= q_offsets[i + 1], "SelfAttentionSegmented: q_offsets must be non-decreasing."); + ASSERT(kv_offsets[i] <= kv_offsets[i + 1], "SelfAttentionSegmented: kv_offsets must be non-decreasing."); + const int64_t qseg = q_offsets[i + 1] - q_offsets[i]; + const int64_t kvseg = kv_offsets[i + 1] - kv_offsets[i]; + ASSERT(qseg >= 0 && kvseg >= 0, "SelfAttentionSegmented: invalid negative segment length."); + ASSERT(kvseg >= qseg, "SelfAttentionSegmented: each segment must satisfy kv_len >= q_len."); + } + + // Segment-by-segment execution. This preserves correctness on all backends + // (including NVIDIA) before a fused segmented kernel is introduced. + for (size_t i = 0; i < nseg; ++i) { + const size_t qb = static_cast(q_offsets[i]); + const size_t qe = static_cast(q_offsets[i + 1]); + const size_t kb = static_cast(kv_offsets[i]); + const size_t ke = static_cast(kv_offsets[i + 1]); + auto out_seg = attn_val->slice(0, qb, qe); + auto q_seg = q->slice(0, qb, qe); + auto k_seg = k->slice(0, kb, ke); + auto v_seg = v->slice(0, kb, ke); + self_attention(out_seg, q_seg, k_seg, v_seg, scale); + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/op.hpp b/src/ops/self_attention/op.hpp index 980f8c5ae..9f613cd0a 100644 --- a/src/ops/self_attention/op.hpp +++ b/src/ops/self_attention/op.hpp @@ -1,7 +1,17 @@ #pragma once #include "../../tensor/tensor.hpp" +#include +#include namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale); +void self_attention_segmented(tensor_t attn_val, + tensor_t q, + tensor_t k, + tensor_t v, + float scale, + const int64_t *q_offsets, + const int64_t *kv_offsets, + size_t nseg); } diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 000000000..8dfed118c --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,36 @@ +#include "swiglu_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +namespace { + template + void swiglu_impl(std::byte *out, const std::byte *gate, const std::byte *up, size_t numel) { + const T *g_ptr = reinterpret_cast(gate); + const T *u_ptr = reinterpret_cast(up); + T *o_ptr = reinterpret_cast(out); + + for (size_t i = 0; i < numel; ++i) { + float g = llaisys::utils::cast(g_ptr[i]); + float u = llaisys::utils::cast(u_ptr[i]); + float sigmoid = 1.0f / (1.0f + std::exp(-g)); + o_ptr[i] = llaisys::utils::cast(u * g * sigmoid); + } + } +} + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return swiglu_impl(out, gate, up, numel); + case LLAISYS_DTYPE_BF16: + return swiglu_impl(out, gate, up, numel); + case LLAISYS_DTYPE_F16: + return swiglu_impl(out, gate, up, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 000000000..9bc2fd2d9 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cu b/src/ops/swiglu/nvidia/swiglu_nvidia.cu new file mode 100644 index 000000000..cc112195f --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cu @@ -0,0 +1,43 @@ +#include "swiglu_nvidia.hpp" + +#include "../../../device/nvidia/cuda_utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +namespace { +template +__global__ void swiglu_kernel(T *out, const T *gate, const T *up, size_t numel) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= numel) { + return; + } + float g = llaisys::device::nvidia::ScalarOps::load(gate + idx); + float u = llaisys::device::nvidia::ScalarOps::load(up + idx); + float sigmoid = 1.0f / (1.0f + expf(-g)); + llaisys::device::nvidia::ScalarOps::store(out + idx, u * g * sigmoid); +} + +template +void launch_swiglu(std::byte *out, const std::byte *gate, const std::byte *up, size_t numel) { + const int threads = 256; + const int blocks = static_cast((numel + threads - 1) / threads); + swiglu_kernel<<>>(reinterpret_cast(out), reinterpret_cast(gate), + reinterpret_cast(up), numel); + llaisys::device::nvidia::cuda_check(cudaGetLastError()); +} +} // namespace + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return launch_swiglu(out, gate, up, numel); + case LLAISYS_DTYPE_BF16: + return launch_swiglu(out, gate, up, numel); + case LLAISYS_DTYPE_F16: + return launch_swiglu(out, gate, up, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.hpp b/src/ops/swiglu/nvidia/swiglu_nvidia.hpp new file mode 100644 index 000000000..a65d26ac5 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc97..a641c5f18 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,46 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/swiglu_nvidia.hpp" +#endif +#ifdef ENABLE_ILUVATAR_API +#include "nvidia/swiglu_nvidia.hpp" +#endif + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + + ASSERT(out->ndim() == 2 && gate->ndim() == 2 && up->ndim() == 2, "SwiGLU: tensors must be 2D."); + ASSERT(out->shape() == gate->shape() && out->shape() == up->shape(), "SwiGLU: shapes must match."); + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), "SwiGLU: tensors must be contiguous."); + + size_t numel = out->numel(); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#endif +#ifdef ENABLE_ILUVATAR_API + case LLAISYS_DEVICE_ILUVATAR: + return nvidia::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..73598e016 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -7,23 +7,26 @@ #include namespace llaisys { - +//构造器 Tensor::Tensor(TensorMeta meta, core::storage_t storage, size_t offset) : _meta(std::move(meta)), _storage(std::move(storage)), _offset(offset) {} - +//创建一个新的张量 tensor_t Tensor::create(const std::vector &shape, llaisysDataType_t dtype, llaisysDeviceType_t device_type, int device) { size_t ndim_ = shape.size(); + //计算步长 std::vector strides(ndim_); size_t stride = 1; + //后面所有维长度的乘积 for (size_t i = 1; i <= ndim_; i++) { strides[ndim_ - i] = stride; stride *= shape[ndim_ - i]; } TensorMeta meta{dtype, shape, strides}; size_t total_elems = stride; + //计算数据类型大小 size_t dtype_size = utils::dsize(dtype); if (device_type == LLAISYS_DEVICE_CPU && core::context().runtime().deviceType() != LLAISYS_DEVICE_CPU) { @@ -35,47 +38,48 @@ tensor_t Tensor::create(const std::vector &shape, return std::shared_ptr(new Tensor(meta, storage)); } } - +//返回指向张量数据的指针 std::byte *Tensor::data() { return _storage->memory() + _offset; } - +//返回指向张量数据的常量指针 const std::byte *Tensor::data() const { return _storage->memory() + _offset; } - +//返回张量的维度数 size_t Tensor::ndim() const { return _meta.shape.size(); } - +//返回张量的形状 const std::vector &Tensor::shape() const { return _meta.shape; } - +//返回张量的步长 const std::vector &Tensor::strides() const { return _meta.strides; } - +//返回张量的数据类型 llaisysDataType_t Tensor::dtype() const { return _meta.dtype; } +//返回张量所存储数据的存储对象 llaisysDeviceType_t Tensor::deviceType() const { return _storage->deviceType(); } - +//返回张量所在设备的ID int Tensor::deviceId() const { return _storage->deviceId(); } - +//返回张量中的元素数量 size_t Tensor::numel() const { return std::accumulate(_meta.shape.begin(), _meta.shape.end(), size_t(1), std::multiplies()); } - +//返回张量中每个元素的大小(以字节为单位) size_t Tensor::elementSize() const { return utils::dsize(_meta.dtype); } - +//调试信息 std::string Tensor::info() const { std::stringstream ss; @@ -163,33 +167,127 @@ void Tensor::debug() const { } } -bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); - return true; -} - +//检查张量是否是连续存储的 + bool Tensor::isContiguous() const { + //获取形状和步长 + const auto &sh = shape(); + const auto &st = strides(); + if (sh.empty()) return true; + + size_t expect = 1; + for (size_t i = sh.size(); i-- > 0;) { + if (sh[i] == 1) continue; // 长度为 1 的维可跳过 + if((st[i] != static_cast(expect))){ + return false; + } + expect*= sh[i]; + } + return true; + } +//创建一个新张量,改变原始张量维度的顺序 tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); + //检查order是否合法 + if (order.size() != ndim()) { + throw std::invalid_argument("permute: order length mismatch"); + } + + std::vector new_shape(ndim()); + std::vector new_strides(ndim()); + for (size_t i = 0; i < ndim(); ++i) { + size_t j = order[i]; + if (j >= ndim()) throw std::out_of_range("permute index"); + new_shape[i] = shape()[j]; + new_strides[i] = strides()[j]; + } + + TensorMeta new_meta{dtype(), new_shape, new_strides}; + return tensor_t(new Tensor(new_meta, _storage, _offset)); // 零拷贝 + + return std::shared_ptr(new Tensor(_meta, _storage)); } - +//改变张量的视图 tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(isContiguous() == true){ + tensor_t tmp = create(shape, this->dtype(), this->deviceType(), this->deviceId()); + tmp->_storage = this->_storage; + return tmp; + }else{ + //非连续存储 + return contiguous()->view(shape); + } } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); -} + //检查参数合法性 + if (dim >= ndim()) throw std::out_of_range("slice dim"); + if (start > end || end > shape()[dim]) + throw std::out_of_range("slice range"); + + auto new_shape = shape(); + auto new_strides = strides(); + new_shape[dim] = end - start; + size_t new_offset = _offset + start * new_strides[dim] * elementSize(); + + TensorMeta new_meta{dtype(), new_shape, new_strides}; + return tensor_t(new Tensor(new_meta, _storage, new_offset)); +} +//从主机内存加载数据 void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + //计算要复制的字节数 + size_t bytes = numel()*elementSize(); + //拿到目标数据指针 + std::byte *dst =data(); + + //拷贝 + if (deviceType() == LLAISYS_DEVICE_CPU) { + std::memcpy(dst, src_, bytes); // 纯内存复制 + } else { + core::context().setDevice(deviceType(), deviceId()); + core::context().runtime().api()->memcpy_sync( + dst, src_, bytes, // 目标,源,大小 + LLAISYS_MEMCPY_H2D); // 主机到设备 + } } +//创建一个连续存储的张量 tensor_t Tensor::contiguous() const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(isContiguous()){ + return std::shared_ptr(new Tensor(_meta, _storage)); + }else{ + //形状 + const auto& sh = shape(); + //维度 + const auto dim = sh.size(); + + //创建一个新的连续步长数组 + std::vector c_str(dim, 1); + for (size_t i = dim - 1; i-- > 0;) { + c_str[i] = c_str[i + 1] * sh[i + 1]; + } + + //申请同设备新存储 + size_t bytes = numel() * elementSize(); + core::storage_t st = (deviceType() == LLAISYS_DEVICE_CPU) + ? core::context().runtime().allocateHostStorage(bytes) + : core::context().runtime().allocateDeviceStorage(bytes); + + //创建新连续张量 + tensor_t dst(new Tensor(TensorMeta{dtype(), sh, c_str}, st, 0)); + + // 4. 拷贝数据(H2H 或 H2D 视设备而定) + core::context().setDevice(deviceType(), deviceId()); + core::context().runtime().api()->memcpy_sync( + dst->data(), data(), bytes, + deviceType() == LLAISYS_DEVICE_CPU ? LLAISYS_MEMCPY_H2H : LLAISYS_MEMCPY_H2D); + + return dst; // 新的连续张量 + + } + + + } tensor_t Tensor::reshape(const std::vector &shape) const { diff --git a/src/tensor/tensor.hpp b/src/tensor/tensor.hpp index 35e340922..ce0ab1c10 100644 --- a/src/tensor/tensor.hpp +++ b/src/tensor/tensor.hpp @@ -3,58 +3,88 @@ #include namespace llaisys { -class Tensor; -using tensor_t = std::shared_ptr; - -struct TensorMeta { - llaisysDataType_t dtype; - std::vector shape; - std::vector strides; -}; - -class Tensor { -private: - TensorMeta _meta; - core::storage_t _storage; - size_t _offset; - Tensor(TensorMeta meta, core::storage_t storage, size_t offset = 0); - -public: - static tensor_t create( - const std::vector &shape, - llaisysDataType_t dtype, - llaisysDeviceType_t device_type = LLAISYS_DEVICE_CPU, - int device = 0); - ~Tensor() = default; - // Info - std::byte *data(); - const std::byte *data() const; - size_t ndim() const; - const std::vector &shape() const; - const std::vector &strides() const; - llaisysDataType_t dtype() const; - llaisysDeviceType_t deviceType() const; - int deviceId() const; - size_t numel() const; - size_t elementSize() const; - - std::string info() const; - void debug() const; - - bool isContiguous() const; - - // Meta Transform - tensor_t permute(const std::vector &order) const; - tensor_t slice(size_t dim, size_t start, size_t end) const; - tensor_t view(const std::vector &shape) const; - - // Load data from host memory - void load(const void *src); - - // Challenging features - tensor_t contiguous() const; - tensor_t reshape(const std::vector &shape) const; - tensor_t to(llaisysDeviceType_t device_type, int device = -1) const; -}; + //前向声明张量类 + class Tensor; + //张量的共享指针类型 + using tensor_t = std::shared_ptr; + + //描述张量形状、数据类型和步长的元数据 + struct TensorMeta { + //数据类型 + llaisysDataType_t dtype; + //形状 + std::vector shape; + //步长 + std::vector strides; + }; + + //张量 + class Tensor { + private: + //描述张量形状、数据类型和步长的元数据 + TensorMeta _meta; + //指向存储张量数据的内存块的共享指针。它可以被多个张量共享。有关更多详细信息,请查看storage类 + core::storage_t _storage; + //张量在存储中的起始索引(以字节为单位) + size_t _offset; + + //构造器 + Tensor(TensorMeta meta, core::storage_t storage, size_t offset = 0); + + public: + //创建一个新的张量 + static tensor_t create( + //张量形状 + const std::vector &shape, + //数据类型 + llaisysDataType_t dtype, + //默认在CPU上创建张量 + llaisysDeviceType_t device_type = LLAISYS_DEVICE_CPU, + //设备ID,默认为0 + int device = 0); + //析构器 + ~Tensor() = default; + // Info + //返回指向张量数据的指针 + std::byte *data(); + //返回指向张量数据的常量指针 + const std::byte *data() const; + //返回张量的维度数 + size_t ndim() const; + //返回张量的形状 + const std::vector &shape() const; + //返回张量的步长 + const std::vector &strides() const; + //返回张量的数据类型 + llaisysDataType_t dtype() const; + //返回张量所存储数据的存储对象 + llaisysDeviceType_t deviceType() const; + //返回张量所在设备的ID + int deviceId() const; + //返回张量中元素的总数 + size_t numel() const; + //返回张量中每个元素的大小(以字节为单位) + size_t elementSize() const; + + //调试信息 + std::string info() const; + //打印张量的调试信息 + void debug() const; + //检查张量是否是连续存储的 + bool isContiguous() const; + + // Meta Transform + tensor_t permute(const std::vector &order) const; + tensor_t slice(size_t dim, size_t start, size_t end) const; + tensor_t view(const std::vector &shape) const; + + // Load data from host memory + void load(const void *src); + + // Challenging features + tensor_t contiguous() const; + tensor_t reshape(const std::vector &shape) const; + tensor_t to(llaisysDeviceType_t device_type, int device = -1) const; + }; } // namespace llaisys diff --git a/src/tokenizer/sentencepiece/sentencepiece.cpp b/src/tokenizer/sentencepiece/sentencepiece.cpp new file mode 100644 index 000000000..fabb31fff --- /dev/null +++ b/src/tokenizer/sentencepiece/sentencepiece.cpp @@ -0,0 +1,131 @@ +#include "sentencepiece.hpp" + +#include + +#ifdef LLAISYS_ENABLE_SENTENCEPIECE +#include +#endif + +namespace llaisys::tokenizer { + +#ifdef LLAISYS_ENABLE_SENTENCEPIECE +class SentencePieceTokenizer::Impl { +public: + bool load(const std::string &model_path) { + auto status = _sp.Load(model_path); + return status.ok(); + } + + bool encode(const std::string &text, std::vector &out_ids) const { + std::vector ids; + auto status = _sp.Encode(text, &ids); + if (!status.ok()) return false; + out_ids.assign(ids.begin(), ids.end()); + return true; + } + + bool decode(const int64_t *ids, size_t len, std::string &out_text) const { + if (!ids && len > 0) return false; + std::vector tmp; + tmp.reserve(len); + for (size_t i = 0; i < len; ++i) tmp.push_back(static_cast(ids[i])); + auto status = _sp.Decode(tmp, &out_text); + return status.ok(); + } + + bool pieceToId(const std::string &piece, int64_t &out_id) const { + int id = _sp.PieceToId(piece); + if (id < 0) return false; + std::string check = _sp.IdToPiece(id); + if (check != piece) return false; + out_id = static_cast(id); + return true; + } + + bool idToPiece(int64_t id, std::string &out_piece) const { + if (id < 0) return false; + if (id >= static_cast(_sp.GetPieceSize())) return false; + out_piece = _sp.IdToPiece(static_cast(id)); + return !out_piece.empty(); + } + +private: + sentencepiece::SentencePieceProcessor _sp; +}; +#endif + +SentencePieceTokenizer::SentencePieceTokenizer(const std::string &model_path) { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + _impl = new Impl(); + if (!_impl->load(model_path)) { + std::cerr << "[ERROR] SentencePiece load failed: " << model_path << std::endl; + delete _impl; + _impl = nullptr; + } +#else + (void)model_path; + std::cerr << "[ERROR] SentencePiece is not enabled in build." << std::endl; +#endif +} + +SentencePieceTokenizer::~SentencePieceTokenizer() { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + delete _impl; + _impl = nullptr; +#endif +} + +bool SentencePieceTokenizer::isLoaded() const { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + return _impl != nullptr; +#else + return false; +#endif +} + +bool SentencePieceTokenizer::encode(const std::string &text, std::vector &out_ids) const { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + if (!_impl) return false; + return _impl->encode(text, out_ids); +#else + (void)text; + out_ids.clear(); + return false; +#endif +} + +bool SentencePieceTokenizer::decode(const int64_t *ids, size_t len, std::string &out_text) const { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + if (!_impl) return false; + return _impl->decode(ids, len, out_text); +#else + (void)ids; + (void)len; + out_text.clear(); + return false; +#endif +} + +bool SentencePieceTokenizer::pieceToId(const std::string &piece, int64_t &out_id) const { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + if (!_impl) return false; + return _impl->pieceToId(piece, out_id); +#else + (void)piece; + out_id = -1; + return false; +#endif +} + +bool SentencePieceTokenizer::idToPiece(int64_t id, std::string &out_piece) const { +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + if (!_impl) return false; + return _impl->idToPiece(id, out_piece); +#else + (void)id; + out_piece.clear(); + return false; +#endif +} + +} // namespace llaisys::tokenizer diff --git a/src/tokenizer/sentencepiece/sentencepiece.hpp b/src/tokenizer/sentencepiece/sentencepiece.hpp new file mode 100644 index 000000000..10cac2c43 --- /dev/null +++ b/src/tokenizer/sentencepiece/sentencepiece.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +namespace llaisys::tokenizer { + +class SentencePieceTokenizer { +public: + explicit SentencePieceTokenizer(const std::string &model_path); + ~SentencePieceTokenizer(); + + bool isLoaded() const; + + bool encode(const std::string &text, std::vector &out_ids) const; + bool decode(const int64_t *ids, size_t len, std::string &out_text) const; + bool pieceToId(const std::string &piece, int64_t &out_id) const; + bool idToPiece(int64_t id, std::string &out_piece) const; + +private: +#ifdef LLAISYS_ENABLE_SENTENCEPIECE + class Impl; + Impl *_impl{nullptr}; +#endif +}; + +} // namespace llaisys::tokenizer diff --git a/src/utils/check.hpp b/src/utils/check.hpp index 82de2a7ea..3db05f806 100644 --- a/src/utils/check.hpp +++ b/src/utils/check.hpp @@ -77,6 +77,7 @@ throw std::runtime_error("device mismatch"); \ } while (0) + #define CHECK_SAME_DEVICE(FIRST, ...) \ do { \ for (const auto &tensor___ : {__VA_ARGS__}) { \ diff --git a/src/utils/types.hpp b/src/utils/types.hpp index e09619db8..6d57759a0 100644 --- a/src/utils/types.hpp +++ b/src/utils/types.hpp @@ -1,3 +1,4 @@ +#pragma once #include "llaisys.h" #include diff --git a/test/_allreduce_worker.py b/test/_allreduce_worker.py new file mode 100644 index 000000000..25985f42e --- /dev/null +++ b/test/_allreduce_worker.py @@ -0,0 +1,159 @@ +"""Worker process for multi-process allreduce test. + +Each worker: +1. Rank 0 generates NCCL unique ID and writes to shared file +2. All ranks read the ID file and init communicator +3. Each rank fills sendbuf with (rank+1), runs allreduce SUM +4. Writes result to its result file +""" + +import sys +import os +import ctypes +import struct +import time +import argparse + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import llaisys +from llaisys.libllaisys import LIB_LLAISYS +from llaisys.libllaisys.llaisys_types import llaisysStream_t + +# Constants +LLAISYS_COMM_NCCL = 0 +LLAISYS_REDUCE_SUM = 0 +LLAISYS_FLOAT32 = 13 +NCCL_UNIQUE_ID_BYTES = 128 + +llaisysComm_t = ctypes.c_void_p + +# Minimal ctypes bindings for comm API +comm_init_api = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(llaisysComm_t), ctypes.c_int, ctypes.c_int) +comm_destroy_api = ctypes.CFUNCTYPE(None, llaisysComm_t) +comm_allreduce_api = ctypes.CFUNCTYPE( + None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, + ctypes.c_int, ctypes.c_int, llaisysComm_t, llaisysStream_t, +) + + +class LlaisysCommAPI(ctypes.Structure): + _fields_ = [ + ("init", comm_init_api), + ("destroy", comm_destroy_api), + ("get_rank", ctypes.c_void_p), + ("get_size", ctypes.c_void_p), + ("allreduce", comm_allreduce_api), + ("broadcast", ctypes.c_void_p), + ("send", ctypes.c_void_p), + ("recv", ctypes.c_void_p), + ] + + +def get_nccl_unique_id(): + """Call ncclGetUniqueId via the NCCL library directly.""" + try: + nccl = ctypes.CDLL("libnccl.so.2") + except OSError: + nccl = ctypes.CDLL("libnccl.so") + nccl.ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] + nccl.ncclGetUniqueId.restype = ctypes.c_int + uid = NcclUniqueId() + ret = nccl.ncclGetUniqueId(ctypes.byref(uid)) + assert ret == 0, f"ncclGetUniqueId failed: {ret}" + return bytes(uid.internal) + + +class NcclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * NCCL_UNIQUE_ID_BYTES)] + + +def nccl_comm_init_rank(nranks, uid_bytes, rank): + """Call ncclCommInitRank directly to pass the shared unique ID.""" + try: + nccl = ctypes.CDLL("libnccl.so.2") + except OSError: + nccl = ctypes.CDLL("libnccl.so") + nccl.ncclCommInitRank.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int] + nccl.ncclCommInitRank.restype = ctypes.c_int + comm = ctypes.c_void_p() + uid = NcclUniqueId() + ctypes.memmove(uid.internal, uid_bytes, NCCL_UNIQUE_ID_BYTES) + ret = nccl.ncclCommInitRank(ctypes.byref(comm), nranks, uid, rank) + assert ret == 0, f"ncclCommInitRank failed: {ret}" + return comm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--rank", type=int, required=True) + parser.add_argument("--nranks", type=int, required=True) + parser.add_argument("--device", default="nvidia") + parser.add_argument("--id_file", required=True) + parser.add_argument("--result_file", required=True) + args = parser.parse_args() + + device_type = llaisys.DeviceType.NVIDIA if args.device == "nvidia" else llaisys.DeviceType.ILUVATAR + runtime_api = llaisys.RuntimeAPI(device_type) + runtime_api.set_device(0) # Each process sees one GPU via CUDA_VISIBLE_DEVICES + + # Rank 0 generates and writes unique ID; others wait and read + if args.rank == 0: + uid_bytes = get_nccl_unique_id() + with open(args.id_file, "wb") as f: + f.write(uid_bytes) + else: + for _ in range(100): # wait up to 10s + if os.path.exists(args.id_file) and os.path.getsize(args.id_file) >= NCCL_UNIQUE_ID_BYTES: + break + time.sleep(0.1) + with open(args.id_file, "rb") as f: + uid_bytes = f.read() + + # Init communicator with shared unique ID + comm = nccl_comm_init_rank(args.nranks, uid_bytes, args.rank) + + # Get comm API for allreduce + LIB_LLAISYS.llaisysGetCommAPI.argtypes = [ctypes.c_int] + LIB_LLAISYS.llaisysGetCommAPI.restype = ctypes.POINTER(LlaisysCommAPI) + api = LIB_LLAISYS.llaisysGetCommAPI(LLAISYS_COMM_NCCL).contents + + stream = runtime_api.create_stream() + count = 4 + nbytes = count * 4 + + sendbuf = runtime_api.malloc_device(nbytes) + recvbuf = runtime_api.malloc_device(nbytes) + + # Fill sendbuf with (rank + 1) + val = float(args.rank + 1) + host_data = struct.pack("ffff", val, val, val, val) + host_buf = ctypes.create_string_buffer(host_data) + runtime_api.memcpy_sync(sendbuf, ctypes.cast(host_buf, ctypes.c_void_p).value, nbytes, 1) # H2D + + # Allreduce SUM using the comm handle we initialized directly + api.allreduce(sendbuf, recvbuf, count, LLAISYS_FLOAT32, LLAISYS_REDUCE_SUM, comm, stream) + runtime_api.stream_synchronize(stream) + + # Copy result back to host and write to file + out_buf = ctypes.create_string_buffer(nbytes) + runtime_api.memcpy_sync(ctypes.cast(out_buf, ctypes.c_void_p).value, recvbuf, nbytes, 2) # D2H + + with open(args.result_file, "wb") as f: + f.write(out_buf.raw) + + runtime_api.free_device(sendbuf) + runtime_api.free_device(recvbuf) + runtime_api.destroy_stream(stream) + + # Destroy comm via NCCL directly + try: + nccl = ctypes.CDLL("libnccl.so.2") + except OSError: + nccl = ctypes.CDLL("libnccl.so") + nccl.ncclCommDestroy(comm) + + +if __name__ == "__main__": + main() diff --git a/test/ops/add.py b/test/ops/add.py index bb8bf8ca8..d5937bdf7 100644 --- a/test/ops/add.py +++ b/test/ops/add.py @@ -42,7 +42,7 @@ def test_op_add( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] diff --git a/test/ops/argmax.py b/test/ops/argmax.py index d0f7ee298..0ea040b05 100644 --- a/test/ops/argmax.py +++ b/test/ops/argmax.py @@ -43,7 +43,7 @@ def test_op_argmax( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(4,), (4096,)] diff --git a/test/ops/embedding.py b/test/ops/embedding.py index 99cadc1b8..daa9c68b0 100644 --- a/test/ops/embedding.py +++ b/test/ops/embedding.py @@ -39,7 +39,7 @@ def test_op_embedding( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/linear.py b/test/ops/linear.py index 38897331f..e979124c9 100644 --- a/test/ops/linear.py +++ b/test/ops/linear.py @@ -49,7 +49,7 @@ def test_op_linear( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/rms_norm.py b/test/ops/rms_norm.py index 67b789e3f..d4bee23b4 100644 --- a/test/ops/rms_norm.py +++ b/test/ops/rms_norm.py @@ -48,7 +48,7 @@ def test_op_rms_norm( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(1, 4), (512, 4096)] diff --git a/test/ops/rope.py b/test/ops/rope.py index fe59dd11c..90d326afd 100644 --- a/test/ops/rope.py +++ b/test/ops/rope.py @@ -63,7 +63,7 @@ def test_op_rope( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51be..f494beb23 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -65,7 +65,7 @@ def test_op_self_attention( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/self_attention_segmented.py b/test/ops/self_attention_segmented.py new file mode 100644 index 000000000..6802538af --- /dev/null +++ b/test/ops/self_attention_segmented.py @@ -0,0 +1,69 @@ +import os +import sys + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import torch +import llaisys +from test_utils import random_tensor, check_equal + + +def torch_self_attention_segmented(attn_val, query, key, value, scale, q_offsets, kv_offsets): + # query/key/value: [seq, head, dim] + q = query.transpose(-2, -3) # [head, qlen, dim] + k = key.transpose(-2, -3) # [kv_head, kvlen, dim] + v = value.transpose(-2, -3) # [kv_head, kvlen, dim] + + nhead = q.size(0) + nkvh = k.size(0) + rep = nhead // nkvh + k = k.repeat_interleave(rep, dim=0) + v = v.repeat_interleave(rep, dim=0) + + qlen = q.size(1) + kvlen = k.size(1) + logits = (q @ k.transpose(-2, -1)) * scale # [head, qlen, kvlen] + bias = torch.full((qlen, kvlen), float("-inf"), dtype=logits.dtype, device=logits.device) + + for seg in range(len(q_offsets) - 1): + qb, qe = int(q_offsets[seg]), int(q_offsets[seg + 1]) + kb, ke = int(kv_offsets[seg]), int(kv_offsets[seg + 1]) + seg_qlen = qe - qb + seg_kvlen = ke - kb + for s in range(seg_qlen): + allow = kb + s + (seg_kvlen - seg_qlen) + bias[qb + s, kb : allow + 1] = 0.0 + + logits = logits + bias.unsqueeze(0) + probs = torch.softmax(logits, dim=-1) + out = (probs @ v).transpose(-2, -3) # [qlen, head, dim] + attn_val.copy_(out) + + +def test_op_self_attention_segmented(dtype_name="f32", atol=1e-5, rtol=1e-5, device_name="cpu"): + q_offsets = [0, 2, 3] + kv_offsets = [0, 4, 6] + qlen = q_offsets[-1] + kvlen = kv_offsets[-1] + nh = 4 + nkvh = 2 + hd = 8 + + q, q_ = random_tensor((qlen, nh, hd), dtype_name, device_name) + k, k_ = random_tensor((kvlen, nkvh, hd), dtype_name, device_name) + v, v_ = random_tensor((kvlen, nkvh, hd), dtype_name, device_name) + scale = 1.0 / (hd ** 0.5) + + attn_val, attn_val_ = random_tensor((qlen, nh, hd), dtype_name, device_name) + torch_self_attention_segmented(attn_val, q, k, v, scale, q_offsets, kv_offsets) + llaisys.Ops.self_attention_segmented(attn_val_, q_, k_, v_, scale, q_offsets, kv_offsets) + assert check_equal(attn_val_, attn_val, atol=atol, rtol=rtol) + + +if __name__ == "__main__": + print("Testing Ops.self_attention_segmented on cpu") + test_op_self_attention_segmented("f32", 1e-5, 1e-5, "cpu") + test_op_self_attention_segmented("f16", 1e-3, 1e-3, "cpu") + test_op_self_attention_segmented("bf16", 1e-2, 1e-2, "cpu") + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/swiglu.py b/test/ops/swiglu.py index 1fa08f739..f11f573e8 100644 --- a/test/ops/swiglu.py +++ b/test/ops/swiglu.py @@ -42,7 +42,7 @@ def test_op_swiglu( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] diff --git a/test/ops_gpu/__init__.py b/test/ops_gpu/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/ops_gpu/__init__.py @@ -0,0 +1 @@ + diff --git a/test/ops_gpu/add.py b/test/ops_gpu/add.py new file mode 100644 index 000000000..d13d9c559 --- /dev/null +++ b/test/ops_gpu/add.py @@ -0,0 +1,60 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark + + +def torch_add(ans, a, b): + torch.add(a, b, out=ans) + + +def test_op_add( + shape, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print(f" shape {shape} dtype <{dtype_name}>") + a, a_ = random_tensor(shape, dtype_name, device_name) + b, b_ = random_tensor(shape, dtype_name, device_name) + + c, c_ = random_tensor(shape, dtype_name, device_name) + torch_add(c, a, b) + llaisys.Ops.add(c_, a_, b_) + + assert check_equal(c_, c, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_add(c, a, b), + lambda: llaisys.Ops.add(c_, a_, b_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [(2, 3), (64, 256)] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-3, 1e-3), + ] + print(f"Testing Ops.add on {args.device}") + for shape in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_add(shape, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/argmax.py b/test/ops_gpu/argmax.py new file mode 100644 index 000000000..f436e9623 --- /dev/null +++ b/test/ops_gpu/argmax.py @@ -0,0 +1,55 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark, zero_tensor + + +def torch_argmax(max_idx, max_val, vals): + torch.max(vals, keepdim=True, dim=-1, out=(max_val, max_idx)) + + +def test_op_argmax( + shape, + dtype_name="f32", + device_name="nvidia", + profile=False, +): + print(f" shape {shape} dtype <{dtype_name}>") + vals, vals_ = random_tensor(shape, dtype_name, device_name) + max_idx, max_idx_ = zero_tensor((1,), "i64", device_name) + max_val, max_val_ = zero_tensor((1,), dtype_name, device_name) + + torch_argmax(max_idx, max_val, vals) + llaisys.Ops.argmax(max_idx_, max_val_, vals_) + + assert check_equal(max_val_, max_val, strict=True) or check_equal( + max_idx_, max_idx, strict=True + ) + + if profile: + benchmark( + lambda: torch_argmax(max_idx, max_val, vals), + lambda: llaisys.Ops.argmax(max_idx_, max_val_, vals_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [(4,), (1024,)] + testDtype = ["f32", "f16", "bf16"] + print(f"Testing Ops.argmax on {args.device}") + for shape in testShapes: + for dtype_name in testDtype: + test_op_argmax(shape, dtype_name, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/embedding.py b/test/ops_gpu/embedding.py new file mode 100644 index 000000000..3479060ec --- /dev/null +++ b/test/ops_gpu/embedding.py @@ -0,0 +1,62 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +from test_utils import random_int_tensor, random_tensor, check_equal, benchmark + + +def torch_embedding(out, idx, embd): + out[:] = embd[idx] + + +def test_op_embedding( + idx_shape, + embd_shape, + dtype_name="f32", + device_name="nvidia", + profile=False, +): + print(f" idx_shape {idx_shape} embd_shape {embd_shape} dtype <{dtype_name}>") + embd, embd_ = random_tensor(embd_shape, dtype_name, device_name) + idx, idx_ = random_int_tensor(idx_shape, device_name, high=embd_shape[0]) + out, out_ = random_tensor((idx_shape[0], embd_shape[1]), dtype_name, device_name) + torch_embedding(out, idx, embd) + llaisys.Ops.embedding(out_, idx_, embd_) + + check_equal(out_, out, strict=True) + + if profile: + benchmark( + lambda: torch_embedding(out, idx, embd), + lambda: llaisys.Ops.embedding(out_, idx_, embd_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [ + ((1,), (2, 3)), + ((16,), (256, 512)), + ] + testDtype = [ + # type + "f32", + "f16", + "bf16", + ] + print(f"Testing Ops.embedding on {args.device}") + for idx_shape, embd_shape in testShapes: + for dtype_name in testDtype: + test_op_embedding( + idx_shape, embd_shape, dtype_name, args.device, args.profile + ) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/linear.py b/test/ops_gpu/linear.py new file mode 100644 index 000000000..11d8b5fc4 --- /dev/null +++ b/test/ops_gpu/linear.py @@ -0,0 +1,70 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark + + +def torch_linear(out, x, w, bias): + torch.nn.functional.linear(x, w, bias, out=out) + + +def test_op_linear( + out_shape, + x_shape, + w_shape, + use_bias=True, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print(f" out {out_shape}, x {x_shape}, w {w_shape}, bias {use_bias}, dtype <{dtype_name}>") + x, x_ = random_tensor(x_shape, dtype_name, device_name, scale=0.1) + w, w_ = random_tensor(w_shape, dtype_name, device_name, scale=0.01) + + bias, bias_ = None, None + if use_bias: + bias, bias_ = random_tensor((w_shape[0],), dtype_name, device_name) + + out, out_ = random_tensor(out_shape, dtype_name, device_name) + torch_linear(out, x, w, bias) + llaisys.Ops.linear(out_, x_, w_, bias_) + + assert check_equal(out_, out, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_linear(out, x, w, bias), + lambda: llaisys.Ops.linear(out_, x_, w_, bias_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [ + ((2, 3), (2, 4), (3, 4), True), + ((32, 128), (32, 128), (128, 128), True), + ] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-2, 1e-2), + ] + print(f"Testing Ops.linear on {args.device}") + for shapes in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_linear(*shapes, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/rearrange.py b/test/ops_gpu/rearrange.py new file mode 100644 index 000000000..cfe7b1c04 --- /dev/null +++ b/test/ops_gpu/rearrange.py @@ -0,0 +1,55 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark, llaisys_dtype, llaisys_device + + +def torch_rearrange(out, x): + out.copy_(x) + + +def test_op_rearrange( + shape, + dtype_name="f32", + device_name="nvidia", + profile=False, +): + print(f" shape {shape} dtype <{dtype_name}>") + x, x_ = random_tensor(shape, dtype_name, device_name) + x_perm = x.permute(1, 0) + x_perm_ = x_.permute(1, 0) + + out = x_perm.contiguous() + out_ = llaisys.Tensor(out.shape, dtype=llaisys_dtype(dtype_name), device=llaisys_device(device_name)) + torch_rearrange(out, x_perm) + llaisys.Ops.rearrange(out_, x_perm_) + + assert check_equal(out_, out, strict=True) + + if profile: + benchmark( + lambda: torch_rearrange(out, x_perm), + lambda: llaisys.Ops.rearrange(out_, x_perm_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [(2, 3), (16, 64)] + testDtype = ["f32", "f16", "bf16"] + print(f"Testing Ops.rearrange on {args.device}") + for shape in testShapes: + for dtype_name in testDtype: + test_op_rearrange(shape, dtype_name, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/rms_norm.py b/test/ops_gpu/rms_norm.py new file mode 100644 index 000000000..42d77ebf4 --- /dev/null +++ b/test/ops_gpu/rms_norm.py @@ -0,0 +1,66 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark + + +def torch_rms_norm(ans, x, w, eps): + torch.pow(x, 2, out=ans) + mean = torch.mean(ans, dim=-1, keepdim=True) + mean.add_(eps) + torch.rsqrt(mean, out=mean) + torch.mul(x, mean, out=ans) + ans.mul_(w) + + +def test_op_rms_norm( + shape, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print(f" shape {shape} dtype <{dtype_name}>") + x, x_ = random_tensor(shape, dtype_name, device_name) + w, w_ = random_tensor((shape[1],), dtype_name, device_name) + eps = 1e-5 + + c, c_ = random_tensor(shape, dtype_name, device_name) + torch_rms_norm(c, x, w, eps) + llaisys.Ops.rms_norm(c_, x_, w_, eps) + + assert check_equal(c_, c, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_rms_norm(c, x, w, eps), + lambda: llaisys.Ops.rms_norm(c_, x_, w_, eps), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [(1, 4), (64, 256)] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-2, 1e-2), + ] + print(f"Testing Ops.rms_norm on {args.device}") + for shape in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_rms_norm(shape, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/rope.py b/test/ops_gpu/rope.py new file mode 100644 index 000000000..6dd039765 --- /dev/null +++ b/test/ops_gpu/rope.py @@ -0,0 +1,73 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import arrange_tensor, random_tensor, check_equal, benchmark + + +def torch_rope(y: torch.Tensor, x: torch.Tensor, pos_ids: torch.Tensor, theta: float): + assert y.dim() == 3 + seq_len, n_heads, head_dim = y.shape + assert head_dim % 2 == 0, "Head dimension must be even for RoPE." + + x_a, x_b = x[..., : head_dim // 2], x[..., head_dim // 2 :] + positions = pos_ids.to(torch.float32).unsqueeze(1) + i = torch.arange(0, head_dim // 2, dtype=torch.float32, device=y.device) + freqs = positions / (theta ** (2 * i / head_dim)) + sin, cos = freqs.sin(), freqs.cos() + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + y[..., : head_dim // 2] = x_a * cos - x_b * sin + y[..., head_dim // 2 :] = x_b * cos + x_a * sin + + +def test_op_rope( + shape, + start_end, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print(f" shape {shape} range {start_end} dtype <{dtype_name}>") + x, x_ = random_tensor(shape, dtype_name, device_name) + pos_ids, pos_ids_ = arrange_tensor(start_end[0], start_end[1], device_name) + theta = 10000.0 + y, y_ = random_tensor(shape, dtype_name, device_name) + torch_rope(y, x, pos_ids, theta) + llaisys.Ops.rope(y_, x_, pos_ids_, theta) + + assert check_equal(y_, y, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_rope(y, x, pos_ids, theta), + lambda: llaisys.Ops.rope(y_, x_, pos_ids_, theta), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [((2, 1, 4), (0, 2)), ((8, 2, 32), (0, 8))] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-4, 1e-4), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-2, 1e-2), + ] + print(f"Testing Ops.rope on {args.device}") + for shape, start_end in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_rope(shape, start_end, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/run_all.py b/test/ops_gpu/run_all.py new file mode 100644 index 000000000..4cb45205f --- /dev/null +++ b/test/ops_gpu/run_all.py @@ -0,0 +1,42 @@ +import argparse +import subprocess +import sys +from pathlib import Path + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + + here = Path(__file__).resolve().parent + scripts = [ + "add.py", + "argmax.py", + "embedding.py", + "linear.py", + "rearrange.py", + "rms_norm.py", + "rope.py", + "self_attention.py", + "swiglu.py", + ] + + print(f"Running GPU op tests on {args.device}") + for name in scripts: + cmd = [sys.executable, str(here / name), "--device", args.device] + if args.profile: + cmd.append("--profile") + print(f"\n=== {name} ===") + result = subprocess.run(cmd, cwd=str(here)) + if result.returncode != 0: + print(f"[ERROR] {name} failed with code {result.returncode}") + return result.returncode + + print("\nAll GPU op tests passed.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/ops_gpu/self_attention.py b/test/ops_gpu/self_attention.py new file mode 100644 index 000000000..d02a37542 --- /dev/null +++ b/test/ops_gpu/self_attention.py @@ -0,0 +1,91 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark + + +def torch_self_attention(attn_val, query, key, value, scale): + query = query.transpose(-2, -3) + key = key.transpose(-2, -3) + value = value.transpose(-2, -3) + L, S = query.size(-2), key.size(-2) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril( + diagonal=S - L + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias = attn_bias.to(query.dtype) + + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_val.copy_((attn_weight @ value).transpose(-2, -3)) + + +def test_op_self_attention( + qlen, + kvlen, + nh, + nkvh, + hd, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print( + f" qlen={qlen} kvlen={kvlen} nh={nh} nkvh={nkvh} hd={hd} dtype <{dtype_name}>" + ) + q, q_ = random_tensor((qlen, nh, hd), dtype_name, device_name) + k, k_ = random_tensor((kvlen, nkvh, hd), dtype_name, device_name) + v, v_ = random_tensor((kvlen, nkvh, hd), dtype_name, device_name) + scale = 1.0 / (hd**0.5) + + attn_val, attn_val_ = random_tensor((qlen, nh, hd), dtype_name, device_name) + torch_self_attention(attn_val, q, k, v, scale) + llaisys.Ops.self_attention(attn_val_, q_, k_, v_, scale) + assert check_equal(attn_val_, attn_val, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_self_attention(attn_val, q, k, v, scale), + lambda: llaisys.Ops.self_attention(attn_val_, q_, k_, v_, scale), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [ + # qlen, kvlen, nh, nkvh, hd + (2, 2, 1, 1, 4), + (4, 4, 2, 1, 8), + ] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-2, 1e-2), + ] + print(f"Testing Ops.self_attention on {args.device}") + for shape in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_self_attention( + *shape, dtype_name, atol, rtol, args.device, args.profile + ) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops_gpu/swiglu.py b/test/ops_gpu/swiglu.py new file mode 100644 index 000000000..043c5c9ba --- /dev/null +++ b/test/ops_gpu/swiglu.py @@ -0,0 +1,60 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +from test_utils import random_tensor, check_equal, benchmark + + +def torch_swiglu(out, gate, up): + torch.mul(up, gate / (1 + torch.exp(-gate.float()).to(out.dtype)), out=out) + + +def test_op_swiglu( + shape, + dtype_name="f32", + atol=1e-5, + rtol=1e-5, + device_name="nvidia", + profile=False, +): + print(f" shape {shape} dtype <{dtype_name}>") + gate, gate_ = random_tensor(shape, dtype_name, device_name) + up, up_ = random_tensor(shape, dtype_name, device_name) + + out, out_ = random_tensor(shape, dtype_name, device_name) + torch_swiglu(out, gate, up) + llaisys.Ops.swiglu(out_, gate_, up_) + + assert check_equal(out_, out, atol=atol, rtol=rtol) + + if profile: + benchmark( + lambda: torch_swiglu(out, gate, up), + lambda: llaisys.Ops.swiglu(out_, gate_, up_), + device_name, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "iluvatar"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + testShapes = [(2, 3), (64, 256)] + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-2, 1e-2), + ] + print(f"Testing Ops.swiglu on {args.device}") + for shape in testShapes: + for dtype_name, atol, rtol in testDtypePrec: + test_op_swiglu(shape, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_allreduce.py b/test/test_allreduce.py new file mode 100644 index 000000000..50e9fac0c --- /dev/null +++ b/test/test_allreduce.py @@ -0,0 +1,84 @@ +"""Multi-process allreduce integration test. + +Launches N worker processes (one per GPU), each initializing a NCCL communicator +and performing allreduce. Uses file-based IPC to broadcast the NCCL unique ID +from rank 0 to all other ranks. + +Usage: + python test_allreduce.py [--nranks 2] [--device nvidia] +""" + +import sys +import os +import subprocess +import argparse +import tempfile +import struct + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + + +WORKER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_allreduce_worker.py") + + +def run_allreduce_test(nranks, device): + """Launch nranks worker processes and verify allreduce results.""" + with tempfile.TemporaryDirectory() as tmpdir: + id_file = os.path.join(tmpdir, "nccl_id.bin") + result_files = [os.path.join(tmpdir, f"result_{r}.bin") for r in range(nranks)] + + procs = [] + for rank in range(nranks): + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(rank) + proc = subprocess.Popen( + [ + sys.executable, WORKER_SCRIPT, + "--rank", str(rank), + "--nranks", str(nranks), + "--device", device, + "--id_file", id_file, + "--result_file", result_files[rank], + ], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + procs.append(proc) + + # Wait for all workers + failed = False + for rank, proc in enumerate(procs): + stdout, stderr = proc.communicate(timeout=60) + if proc.returncode != 0: + print(f"Rank {rank} FAILED (exit code {proc.returncode}):") + print(stderr.decode(errors="replace")) + failed = True + + if failed: + raise RuntimeError("One or more workers failed") + + # Verify results: each rank sends [rank+1]*4, allreduce SUM => [sum(1..N)]*4 + expected_val = sum(r + 1.0 for r in range(nranks)) + for rank in range(nranks): + with open(result_files[rank], "rb") as f: + data = f.read() + result = struct.unpack("ffff", data) + for i, v in enumerate(result): + assert abs(v - expected_val) < 1e-3, ( + f"Rank {rank} result[{i}] = {v}, expected {expected_val}" + ) + + print(f"Allreduce SUM verified: all {nranks} ranks produced {expected_val}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--nranks", type=int, default=2) + parser.add_argument("--device", default="nvidia", choices=["nvidia", "iluvatar"]) + args = parser.parse_args() + + print(f"=== Multi-process allreduce test ({args.nranks} ranks) ===") + run_allreduce_test(args.nranks, args.device) + print("\n\033[92mAllreduce integration test passed!\033[0m\n") diff --git a/test/test_chat_minimal.py b/test/test_chat_minimal.py new file mode 100644 index 000000000..2e9bde056 --- /dev/null +++ b/test/test_chat_minimal.py @@ -0,0 +1,51 @@ +import argparse +from pathlib import Path + +import llaisys +from llaisys.models import Qwen2 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=str, help="model directory") + parser.add_argument( + "--tokenizer", + required=False, + type=str, + help="path to tokenizer.model (defaults to /tokenizer.model)", + ) + parser.add_argument("--prompt", default="你好", type=str) + parser.add_argument("--max_new_tokens", default=64, type=int) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"]) + args = parser.parse_args() + + model_path = Path(args.model) + if args.tokenizer: + tokenizer_path = Path(args.tokenizer) + else: + tokenizer_path = model_path / "tokenizer.model" + if not tokenizer_path.exists(): + tokenizer_path = model_path / "tokenizer.json" + if not tokenizer_path.exists(): + raise FileNotFoundError(f"tokenizer file not found: {tokenizer_path}") + + tokenizer = llaisys.Tokenizer(str(tokenizer_path)) + model = Qwen2(str(model_path), llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA) + + prompt = Qwen2.build_prompt( + [{"role": "user", "content": args.prompt}], + system_prompt="你是助手", + add_generation_prompt=True, + ) + prompt_ids = tokenizer.encode(prompt) + output_ids = model.generate(prompt_ids, max_new_tokens=args.max_new_tokens) + output_text = tokenizer.decode(output_ids) + + print("=== Prompt ===") + print(prompt) + print("\n=== Output ===") + print(output_text) + + +if __name__ == "__main__": + main() diff --git a/test/test_chatservice_split.py b/test/test_chatservice_split.py new file mode 100644 index 000000000..2c5cec1df --- /dev/null +++ b/test/test_chatservice_split.py @@ -0,0 +1,662 @@ +"""Tests for ChatService split (docs/CHATSERVICE_SPLIT_DESIGN.md): +- SessionManager unit tests +- KVRuntimeBridge unit tests (mock model) +- ChatService integration (delegation correctness) +- Interface compatibility (isinstance checks) +- Regression: existing tests must still pass +""" + +import importlib.util +import sys +import threading +import types +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# --------------------------------------------------------------------------- +# Module loading +# --------------------------------------------------------------------------- + +def _load_modules(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # interfaces + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is None or iface_spec.loader is None: + raise RuntimeError("failed to load interfaces") + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + # kv_cache_pool + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + # scheduler + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # session_manager (new module) + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # kv_runtime_bridge (new module) + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + # fake llaisys package + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.interfaces = iface_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + sys.modules["llaisys.interfaces"] = iface_mod + + # fake libllaisys with stub LlaisysSamplingParams + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _StubSamplingParams: + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _StubSamplingParams + fake_llaisys.libllaisys = fake_libllaisys + sys.modules["llaisys.libllaisys"] = fake_libllaisys + if session_mgr_mod: + sys.modules["llaisys.session_manager"] = session_mgr_mod + if kv_bridge_mod: + sys.modules["llaisys.kv_runtime_bridge"] = kv_bridge_mod + + # fake models + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + # server + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + server_mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = server_mod + spec.loader.exec_module(server_mod) + + return iface_mod, kv_mod, scheduler_mod, session_mgr_mod, kv_bridge_mod, server_mod + + +iface_mod, kv_mod, scheduler_mod, session_mgr_mod, kv_bridge_mod, server_mod = _load_modules() +ChatService = server_mod.ChatService + + +# --------------------------------------------------------------------------- +# Fake model helpers +# --------------------------------------------------------------------------- + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self): + self.end_token = _EndToken(-1) + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class FakeModel: + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + self._last_kv_context = None + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + self._last_kv_context = ctx + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_service(**kwargs): + model = FakeModel() + tok = FakeTokenizer() + service = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", True), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + return service, model + + +# =========================================================================== +# SessionManager unit tests +# =========================================================================== + +def test_session_manager_save_and_get_messages(): + """SessionManager should store and retrieve message history.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + mgr.save_messages("s1", [{"role": "user", "content": "hello"}]) + msgs = mgr.get_messages("s1") + assert len(msgs) == 1 + assert msgs[0]["content"] == "hello" + + # Should return a copy, not the original + msgs.append({"role": "assistant", "content": "hi"}) + assert len(mgr.get_messages("s1")) == 1, "get_messages should return a copy" + + # Empty session returns empty list + assert mgr.get_messages("nonexistent") == [] + + print(" SessionManager save/get messages OK") + + +def test_session_manager_extract_messages_prompt_mode(): + """extract_messages with prompt should append to session history.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + # First message in a new session + ctx_id, msgs = mgr.extract_messages({"session_id": "s1", "prompt": "hello"}) + assert ctx_id == "s1" + assert len(msgs) == 1 + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == "hello" + + print(" SessionManager extract_messages prompt mode OK") + + +def test_session_manager_extract_messages_list_mode(): + """extract_messages with messages list should use them directly.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + messages = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}] + ctx_id, msgs = mgr.extract_messages({"session_id": "s2", "messages": messages}) + assert ctx_id == "s2" + assert len(msgs) == 2 + + print(" SessionManager extract_messages list mode OK") + + +def test_session_manager_extract_messages_edit_fork(): + """extract_messages with edit_from_session_id should fork and edit.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + # Set up source session + mgr.save_messages("source", [ + {"role": "user", "content": "original question"}, + {"role": "assistant", "content": "original answer"}, + {"role": "user", "content": "follow up"}, + ]) + + ctx_id, msgs = mgr.extract_messages({ + "session_id": "fork1", + "edit_from_session_id": "source", + "edit_message_index": 0, + "prompt": "edited question", + }) + assert ctx_id == "fork1" + assert len(msgs) == 1 + assert msgs[0]["content"] == "edited question" + + print(" SessionManager extract_messages edit fork OK") + + +def test_session_manager_cancel_event_lifecycle(): + """Cancel event: get, set (request_stop), clear.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + event = mgr.get_cancel_event("s1") + assert not event.is_set() + + mgr.request_stop("s1") + assert event.is_set() + + mgr.clear_stop("s1") + assert not event.is_set() + + print(" SessionManager cancel event lifecycle OK") + + +def test_session_manager_concurrent_access(): + """Multiple threads accessing SessionManager concurrently should not crash.""" + if session_mgr_mod is None: + print(" SKIP: session_manager.py not found") + return + SessionManager = session_mgr_mod.SessionManager + mgr = SessionManager() + + errors: List[Exception] = [] + barrier = threading.Barrier(10) + + def _worker(tid: int): + try: + barrier.wait(timeout=5.0) + for j in range(20): + sid = f"concurrent-{tid}-{j}" + mgr.save_messages(sid, [{"role": "user", "content": f"msg-{j}"}]) + mgr.get_messages(sid) + mgr.get_cancel_event(sid) + mgr.request_stop(sid) + mgr.clear_stop(sid) + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=_worker, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30.0) + + assert len(errors) == 0, f"Concurrent access errors: {errors}" + print(" SessionManager concurrent access OK") + + +# =========================================================================== +# KVRuntimeBridge unit tests +# =========================================================================== + +def test_kv_bridge_disabled_mode_skips_all(): + """When disabled, bind/export/release should be no-ops.""" + if kv_bridge_mod is None: + print(" SKIP: kv_runtime_bridge.py not found") + return + KVRuntimeBridge = kv_bridge_mod.KVRuntimeBridge + model = FakeModel() + bridge = KVRuntimeBridge(model, enabled=False) + + assert bridge.enabled is False + + # bind should set_kv_context(None) or be a no-op + bridge.bind_for_request("s1", [1, 2, 3], prefix_len=2) + + # export should be a no-op + bridge.export_after_request("s1", [1, 2, 3, 65], block_size=4) + assert len(model.export_calls) == 0, "disabled bridge should not export" + + # release should be a no-op + bridge.release("s1") + + print(" KVRuntimeBridge disabled mode OK") + + +def test_kv_bridge_bind_export_release_lifecycle(): + """Full lifecycle: bind (no context) -> export -> bind (reuse) -> release.""" + if kv_bridge_mod is None: + print(" SKIP: kv_runtime_bridge.py not found") + return + KVRuntimeBridge = kv_bridge_mod.KVRuntimeBridge + model = FakeModel() + bridge = KVRuntimeBridge(model, enabled=True) + + # First request: no existing context, prefix_len=0 -> bind None + bridge.bind_for_request("s1", [1, 2, 3], prefix_len=0) + assert model.bind_calls[-1] is None, "No prefix -> should bind None" + + # Export after first request + bridge.export_after_request("s1", [1, 2, 3, 65], block_size=4) + assert len(model.export_calls) >= 1, "Should export after request" + + # Second request: existing context for s1, prefix_len > 0 -> should bind non-None + bridge.bind_for_request("s1", [1, 2, 3, 65, 4, 5], prefix_len=4) + assert model.bind_calls[-1] is not None, "Existing context should bind non-None" + + # Release + bridge.release("s1") + + # After release, bind should get None again + bridge.bind_for_request("s1", [1, 2, 3], prefix_len=0) + assert model.bind_calls[-1] is None, "After release, should bind None" + + print(" KVRuntimeBridge lifecycle OK") + + +def test_kv_bridge_cross_session_donor(): + """bind_for_request should find donor context from another session.""" + if kv_bridge_mod is None: + print(" SKIP: kv_runtime_bridge.py not found") + return + KVRuntimeBridge = kv_bridge_mod.KVRuntimeBridge + model = FakeModel() + bridge = KVRuntimeBridge(model, enabled=True) + + # Set up donor session + bridge.bind_for_request("donor", [10, 20, 30], prefix_len=0) + bridge.export_after_request("donor", [10, 20, 30, 65], block_size=4) + + # Receiver with matching prefix should find donor + bridge.bind_for_request("receiver", [10, 20, 30, 65, 40], prefix_len=4) + # The last bind should be non-None (found donor context) + assert model.bind_calls[-1] is not None, "Should find donor context" + + print(" KVRuntimeBridge cross-session donor OK") + + +def test_kv_bridge_debug_snapshot_format(): + """debug_snapshot should return expected fields.""" + if kv_bridge_mod is None: + print(" SKIP: kv_runtime_bridge.py not found") + return + KVRuntimeBridge = kv_bridge_mod.KVRuntimeBridge + model = FakeModel() + bridge = KVRuntimeBridge(model, enabled=True) + + snap = bridge.debug_snapshot("s1") + assert isinstance(snap, dict) + assert "session_id" in snap + assert "has_native_context" in snap + assert "last_bind" in snap + + # Global snapshot (no session_id) + snap_all = bridge.debug_snapshot(None) + assert isinstance(snap_all, dict) + assert "native_contexts" in snap_all + + print(" KVRuntimeBridge debug_snapshot format OK") + + +# =========================================================================== +# ChatService integration tests (delegation correctness) +# =========================================================================== + +def test_chatservice_delegates_request_stop(): + """ChatService.request_stop should delegate to SessionManager.""" + service, _ = _make_service() + # request_stop should work without prior generate + result = service.request_stop("s-stop") + assert result is True + print(" ChatService delegates request_stop OK") + + +def test_chatservice_delegates_kv_debug_snapshot(): + """ChatService.kv_debug_snapshot should combine bridge + pool info.""" + service, model = _make_service() + service.generate({"session_id": "s-dbg", "prompt": "test", "max_new_tokens": 2}) + + snap = service.kv_debug_snapshot("s-dbg") + assert isinstance(snap, dict) + assert "session_id" in snap + assert "kv_pool" in snap + assert "has_native_context" in snap + assert "last_bind" in snap + + # Global snapshot + snap_all = service.kv_debug_snapshot(None) + assert isinstance(snap_all, dict) + assert "kv_pool" in snap_all + + print(" ChatService kv_debug_snapshot delegation OK") + + +def test_chatservice_generate_saves_messages(): + """After generate, session history should be saved (via SessionManager).""" + service, _ = _make_service() + service.generate({"session_id": "s-hist", "prompt": "hello", "max_new_tokens": 2}) + service.generate({"session_id": "s-hist", "prompt": "again", "max_new_tokens": 2}) + + # Verify the session has history by doing a third request + # (it should pick up prior messages in the prompt) + result = service.generate({"session_id": "s-hist", "prompt": "third", "max_new_tokens": 2}) + assert result["session_id"] == "s-hist" + print(" ChatService generate saves messages OK") + + +def test_chatservice_cancelled_does_not_save_messages(): + """Cancelled request should not save assistant output to history.""" + service, model = _make_service() + + # Override _iter_generate_ids to immediately cancel + def _cancelled_iter(prompt_ids, max_new_tokens, sampling, prefix_len, cancel_event): + cancel_event.set() + if False: + yield 0 + + service._iter_generate_ids = _cancelled_iter + result = service.generate({"session_id": "s-cancel", "prompt": "test", "max_new_tokens": 2}) + assert result.get("stopped") is True + assert result["choices"][0]["finish_reason"] == "stop" + assert len(model.export_calls) == 0 + print(" ChatService cancelled request does not save messages OK") + + +# =========================================================================== +# Interface compatibility +# =========================================================================== + +def test_isinstance_checks_still_pass(): + """ChatService should still be an IInferenceService after refactoring.""" + IInferenceService = getattr(iface_mod, "IInferenceService", None) + IKVCachePool = getattr(iface_mod, "IKVCachePool", None) + + service, _ = _make_service() + + if IInferenceService is not None: + assert isinstance(service, IInferenceService), ( + "ChatService must be an instance of IInferenceService" + ) + print(" isinstance(ChatService, IInferenceService): True") + + if IKVCachePool is not None: + pool = kv_mod.KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + assert isinstance(pool, IKVCachePool), ( + "KVCachePool must be an instance of IKVCachePool" + ) + print(" isinstance(KVCachePool, IKVCachePool): True") + + +# =========================================================================== +# Regression tests +# =========================================================================== + +def test_regression_kv_reuse_same_session(): + """Regression: same-session KV reuse still works after split.""" + service, model = _make_service() + first = service.generate({"session_id": "reg-s1", "prompt": "hello", "max_new_tokens": 2}) + assert first["session_id"] == "reg-s1" + assert model.bind_calls[0] is None # first request has no prefix + + service.generate({"session_id": "reg-s1", "prompt": "again", "max_new_tokens": 2}) + assert model.bind_calls[-1] is not None # second should bind existing context + dbg = service.kv_debug_snapshot("reg-s1") + assert dbg["last_bind"]["bound"] is True + assert dbg["last_bind"]["prefix_len"] > 0 + print(" regression: same-session KV reuse OK") + + +def test_regression_cross_session_donor(): + """Regression: cross-session KV donor still works after split.""" + service, _ = _make_service() + service.generate({"session_id": "donor", "prompt": "shared prompt", "max_new_tokens": 2}) + service.generate({ + "session_id": "receiver", + "messages": [{"role": "user", "content": "shared prompt"}], + "max_new_tokens": 2, + }) + dbg = service.kv_debug_snapshot("receiver") + assert dbg["last_bind"]["bound"] is True + assert dbg["last_bind"]["prefix_len"] > 0 + assert dbg["last_bind"]["source_session_id"] == "donor" + print(" regression: cross-session donor KV reuse OK") + + +def test_regression_stream_works(): + """Regression: stream generation still works.""" + service, _ = _make_service() + items = list(service.stream({"session_id": "reg-stream", "prompt": "hello", "max_new_tokens": 2})) + assert items[-1]["choices"][0]["finish_reason"] is not None + assert items[-1]["session_id"] == "reg-stream" + print(" regression: stream OK") + + +def test_regression_kv_cache_pool_prefix_match(): + """Regression: KVCachePool prefix matching still works.""" + KVCachePool = kv_mod.KVCachePool + pool = KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + result_a = pool.acquire_context("ctx-a", [1, 2, 3, 4, 5, 6]) + assert result_a.prefix_len == 0 + result_b = pool.acquire_context("ctx-b", [1, 2, 3, 4, 5, 6]) + assert result_b.prefix_len == 4 + print(" regression: kv_cache_pool prefix match OK") + + +# =========================================================================== +# Runner +# =========================================================================== + +if __name__ == "__main__": + tests = [ + # SessionManager unit tests + test_session_manager_save_and_get_messages, + test_session_manager_extract_messages_prompt_mode, + test_session_manager_extract_messages_list_mode, + test_session_manager_extract_messages_edit_fork, + test_session_manager_cancel_event_lifecycle, + test_session_manager_concurrent_access, + # KVRuntimeBridge unit tests + test_kv_bridge_disabled_mode_skips_all, + test_kv_bridge_bind_export_release_lifecycle, + test_kv_bridge_cross_session_donor, + test_kv_bridge_debug_snapshot_format, + # ChatService integration + test_chatservice_delegates_request_stop, + test_chatservice_delegates_kv_debug_snapshot, + test_chatservice_generate_saves_messages, + test_chatservice_cancelled_does_not_save_messages, + # Interface compatibility + test_isinstance_checks_still_pass, + # Regression + test_regression_kv_reuse_same_session, + test_regression_cross_session_donor, + test_regression_stream_works, + test_regression_kv_cache_pool_prefix_match, + ] + + passed = 0 + failed = 0 + for test_fn in tests: + name = test_fn.__name__ + try: + print(f"[RUN ] {name}") + test_fn() + print(f"[PASS] {name}") + passed += 1 + except Exception as exc: + print(f"[FAIL] {name}: {exc}") + failed += 1 + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed, {passed + failed} total") + if failed > 0: + print("SOME TESTS FAILED") + sys.exit(1) + else: + print("ALL TESTS PASSED") diff --git a/test/test_comm_api.py b/test/test_comm_api.py new file mode 100644 index 000000000..3183106ab --- /dev/null +++ b/test/test_comm_api.py @@ -0,0 +1,155 @@ +"""Unit tests for the communication layer API. + +Tests the comm API via ctypes: init/destroy, rank/size queries, +and allreduce correctness on a single GPU (nranks=1). + +Usage: + python test_comm_api.py [--device nvidia] +""" + +import sys +import os +import ctypes +import argparse +import struct + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import llaisys +from llaisys.libllaisys import LIB_LLAISYS +from llaisys.libllaisys.llaisys_types import llaisysDataType_t, llaisysStream_t + + +# --- Comm API ctypes bindings --- + +# Matches llaisysCommBackend_t +LLAISYS_COMM_NCCL = 0 + +# Matches llaisysReduceOp_t +LLAISYS_REDUCE_SUM = 0 +LLAISYS_REDUCE_MAX = 3 + +# Matches llaisysDataType_t +LLAISYS_FLOAT32 = 13 + +llaisysComm_t = ctypes.c_void_p + +# comm_init_api: int (*)(llaisysComm_t*, int rank, int size) +comm_init_api = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(llaisysComm_t), ctypes.c_int, ctypes.c_int) +# comm_destroy_api: void (*)(llaisysComm_t) +comm_destroy_api = ctypes.CFUNCTYPE(None, llaisysComm_t) +# comm_get_rank_api: int (*)(llaisysComm_t) +comm_get_rank_api = ctypes.CFUNCTYPE(ctypes.c_int, llaisysComm_t) +# comm_get_size_api: int (*)(llaisysComm_t) +comm_get_size_api = ctypes.CFUNCTYPE(ctypes.c_int, llaisysComm_t) +# comm_allreduce_api: void (*)(const void*, void*, size_t, dtype, op, comm, stream) +comm_allreduce_api = ctypes.CFUNCTYPE( + None, + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, + ctypes.c_int, ctypes.c_int, + llaisysComm_t, llaisysStream_t, +) + + +class LlaisysCommAPI(ctypes.Structure): + _fields_ = [ + ("init", comm_init_api), + ("destroy", comm_destroy_api), + ("get_rank", comm_get_rank_api), + ("get_size", comm_get_size_api), + ("allreduce", comm_allreduce_api), + ("broadcast", ctypes.c_void_p), # skip full typing + ("send", ctypes.c_void_p), + ("recv", ctypes.c_void_p), + ] + + +def get_comm_api(backend=LLAISYS_COMM_NCCL): + LIB_LLAISYS.llaisysGetCommAPI.argtypes = [ctypes.c_int] + LIB_LLAISYS.llaisysGetCommAPI.restype = ctypes.POINTER(LlaisysCommAPI) + return LIB_LLAISYS.llaisysGetCommAPI(backend).contents + + +# --- Tests --- + +def test_init_destroy(api): + """Test communicator init and destroy with nranks=1.""" + print("=== test_init_destroy ===") + comm = llaisysComm_t() + ret = api.init(ctypes.byref(comm), 0, 1) + assert ret == 0, f"commInit returned {ret}" + assert comm.value is not None, "comm handle is null" + api.destroy(comm) + print(" PASSED") + + +def test_rank_size(api): + """Test rank/size queries on a single-rank communicator.""" + print("=== test_rank_size ===") + comm = llaisysComm_t() + ret = api.init(ctypes.byref(comm), 0, 1) + assert ret == 0 + + rank = api.get_rank(comm) + size = api.get_size(comm) + assert rank == 0, f"Expected rank 0, got {rank}" + assert size == 1, f"Expected size 1, got {size}" + + api.destroy(comm) + print(" PASSED") + + +def test_allreduce_sum(api, runtime_api): + """Test allreduce SUM on a single rank (result should equal input).""" + print("=== test_allreduce_sum ===") + comm = llaisysComm_t() + ret = api.init(ctypes.byref(comm), 0, 1) + assert ret == 0 + + stream = runtime_api.create_stream() + count = 4 + nbytes = count * 4 # float32 + + sendbuf = runtime_api.malloc_device(nbytes) + recvbuf = runtime_api.malloc_device(nbytes) + + # Prepare input: [1.0, 2.0, 3.0, 4.0] + host_data = struct.pack("ffff", 1.0, 2.0, 3.0, 4.0) + host_buf = ctypes.create_string_buffer(host_data) + runtime_api.memcpy_sync(sendbuf, ctypes.cast(host_buf, ctypes.c_void_p).value, nbytes, 1) # H2D + + api.allreduce(sendbuf, recvbuf, count, LLAISYS_FLOAT32, LLAISYS_REDUCE_SUM, comm, stream) + runtime_api.stream_synchronize(stream) + + # Copy result back + out_buf = ctypes.create_string_buffer(nbytes) + runtime_api.memcpy_sync(ctypes.cast(out_buf, ctypes.c_void_p).value, recvbuf, nbytes, 2) # D2H + + result = struct.unpack("ffff", out_buf.raw) + expected = (1.0, 2.0, 3.0, 4.0) + for i in range(count): + assert abs(result[i] - expected[i]) < 1e-5, f"Mismatch at [{i}]: {result[i]} != {expected[i]}" + + runtime_api.free_device(sendbuf) + runtime_api.free_device(recvbuf) + runtime_api.destroy_stream(stream) + api.destroy(comm) + print(" PASSED") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="nvidia", choices=["nvidia", "iluvatar"], type=str) + args = parser.parse_args() + + device_type = llaisys.DeviceType.NVIDIA if args.device == "nvidia" else llaisys.DeviceType.ILUVATAR + runtime_api = llaisys.RuntimeAPI(device_type) + + api = get_comm_api(LLAISYS_COMM_NCCL) + + test_init_destroy(api) + test_rank_size(api) + test_allreduce_sum(api, runtime_api) + + print("\n\033[92mAll comm API tests passed!\033[0m\n") diff --git a/test/test_fixes.py b/test/test_fixes.py new file mode 100644 index 000000000..0e7e0fcb9 --- /dev/null +++ b/test/test_fixes.py @@ -0,0 +1,851 @@ +"""Tests for fix design (docs/FIX_DESIGN.md): +#1 _session_worker LRU eviction (max_sticky_sessions) +#2 KV routing TOCTOU - accepted, concurrent safety stress tests +#3 tokenize_for_routing exception logging + payload copy safety +#4 Interface inheritance (isinstance checks) +#5 request_stop merged locking (regression) +#6 _prompt_tokens cleaned from downstream payload +""" + +import importlib.util +import logging +import sys +import threading +import time +import types +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + + +# --------------------------------------------------------------------------- +# Module loading helpers (same pattern as existing tests) +# --------------------------------------------------------------------------- + +def _load_modules(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # Load interfaces first + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is None or iface_spec.loader is None: + raise RuntimeError("failed to load interfaces") + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # Load session_manager (server.py imports from it) + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # Load kv_runtime_bridge (server.py imports from it) + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.interfaces = iface_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + fake_llaisys.__path__ = [str(root / "python" / "llaisys")] + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + sys.modules["llaisys.interfaces"] = iface_mod + + # fake libllaisys with stub LlaisysSamplingParams + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _StubSamplingParams: + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _StubSamplingParams + fake_llaisys.libllaisys = fake_libllaisys + sys.modules["llaisys.libllaisys"] = fake_libllaisys + if session_mgr_mod: + sys.modules["llaisys.session_manager"] = session_mgr_mod + if kv_bridge_mod: + sys.modules["llaisys.kv_runtime_bridge"] = kv_bridge_mod + + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + server_mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = server_mod + spec.loader.exec_module(server_mod) + + return iface_mod, kv_mod, scheduler_mod, server_mod + + +iface_mod, kv_mod, scheduler_mod, server_mod = _load_modules() +KVCachePool = kv_mod.KVCachePool +InferenceScheduler = scheduler_mod.InferenceScheduler +SchedulerQueueFullError = scheduler_mod.SchedulerQueueFullError +ChatService = server_mod.ChatService + + +# --------------------------------------------------------------------------- +# Fake service / model helpers +# --------------------------------------------------------------------------- + +class _Svc: + """Minimal service mock for scheduler tests.""" + + def __init__(self, name: str): + self.name = name + self.stop_calls: List[str] = [] + self._kv_pool = KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + self.last_payload: Optional[Dict[str, Any]] = None + + @property + def kv_pool(self): + return self._kv_pool + + def generate(self, payload): + self.last_payload = dict(payload) + sid = str(payload.get("session_id") or "") + return { + "id": f"chatcmpl-{sid}", + "object": "chat.completion", + "model": "qwen2", + "choices": [{"index": 0, "message": {"role": "assistant", "content": ""}, "finish_reason": "stop"}], + "session_id": sid, + "worker": self.name, + } + + def stream(self, payload): + self.last_payload = dict(payload) + sid = str(payload.get("session_id") or "") + yield { + "id": f"chatcmpl-{sid}", + "object": "chat.completion.chunk", + "model": "qwen2", + "choices": [{"index": 0, "delta": {"content": "x"}, "finish_reason": None}], + "session_id": sid, + } + yield { + "id": f"chatcmpl-{sid}", + "object": "chat.completion.chunk", + "model": "qwen2", + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "session_id": sid, + } + + def request_stop(self, session_id): + self.stop_calls.append(session_id) + return True + + def kv_debug_snapshot(self, session_id=None): + return {"session_id": session_id, "has_native_context": False, "last_bind": {}, "kv_pool": self._kv_pool.snapshot_stats()} + + def tokenize_for_routing(self, payload): + prompt = str(payload.get("prompt") or "") + return [ord(ch) for ch in prompt] if prompt else None + + +class _FailTokenizeSvc(_Svc): + """Service whose tokenize_for_routing always raises.""" + + def tokenize_for_routing(self, payload): + raise RuntimeError("tokenizer broken") + + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self): + self.end_token = _EndToken(-1) + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class FakeModel: + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_service(**kwargs): + model = FakeModel() + tok = FakeTokenizer() + service = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", True), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + return service, model + + +# =========================================================================== +# Fix #1: _session_worker LRU eviction (max_sticky_sessions) +# =========================================================================== + +def test_session_worker_lru_eviction(): + """After exceeding max_sticky_sessions, oldest entries should be evicted. + + Design: InferenceScheduler(max_sticky_sessions=N) uses OrderedDict with + LRU eviction via _touch_session(). Minimum enforced value is 100. + """ + max_sticky = 100 # minimum enforced by max(100, int(max_sticky_sessions)) + svc = _Svc("w0") + try: + scheduler = InferenceScheduler([svc], queue_size=256, max_sticky_sessions=max_sticky) + except TypeError: + print(" SKIP: max_sticky_sessions parameter not yet implemented") + return + scheduler.start() + try: + # Submit more sessions than the limit + num_sessions = max_sticky + 50 + for i in range(num_sessions): + h = scheduler.submit({"session_id": f"lru-{i}"}, stream=False) + h.get_result(timeout=2.0) + + with scheduler._lock: + mapping_size = len(scheduler._session_worker) + + assert mapping_size <= max_sticky, ( + f"_session_worker has {mapping_size} entries, expected <= {max_sticky}" + ) + + # The oldest sessions (lru-0, lru-1, ...) should have been evicted. + # The newest sessions should still be present. + with scheduler._lock: + assert f"lru-{num_sessions - 1}" in scheduler._session_worker, ( + "Most recent session should be in the map" + ) + # First sessions should be evicted + assert "lru-0" not in scheduler._session_worker, ( + "Oldest session should have been evicted" + ) + + print(f" LRU eviction works: {mapping_size} entries <= max {max_sticky}") + finally: + scheduler.stop() + + +def test_session_worker_lru_touch_refreshes_entry(): + """Accessing an existing session should refresh it (move to end of LRU).""" + max_sticky = 100 # minimum enforced by implementation + svc = _Svc("w0") + try: + scheduler = InferenceScheduler([svc], queue_size=256, max_sticky_sessions=max_sticky) + except TypeError: + print(" SKIP: max_sticky_sessions parameter not yet implemented") + return + scheduler.start() + try: + # Fill the map to capacity + for i in range(max_sticky): + h = scheduler.submit({"session_id": f"touch-{i}"}, stream=False) + h.get_result(timeout=2.0) + + # Re-access the first session to refresh it (move to end of LRU) + h = scheduler.submit({"session_id": "touch-0"}, stream=False) + h.get_result(timeout=2.0) + + # Now add more sessions to trigger eviction of oldest non-refreshed entries + for i in range(10): + h = scheduler.submit({"session_id": f"touch-new-{i}"}, stream=False) + h.get_result(timeout=2.0) + + with scheduler._lock: + # touch-0 was refreshed, so it should survive eviction + assert "touch-0" in scheduler._session_worker, ( + "Refreshed session should survive eviction" + ) + # touch-1 was not refreshed and is among the oldest, should be evicted + assert "touch-1" not in scheduler._session_worker, ( + "Non-refreshed old session should be evicted" + ) + + print(" LRU touch refresh works correctly") + finally: + scheduler.stop() + + +def test_session_worker_debug_snapshot_sticky_sessions(): + """debug_snapshot should include sticky_sessions count.""" + svc = _Svc("w0") + try: + scheduler = InferenceScheduler([svc], queue_size=128, max_sticky_sessions=100) + except TypeError: + print(" SKIP: max_sticky_sessions parameter not yet implemented") + return + scheduler.start() + try: + h = scheduler.submit({"session_id": "snap-1"}, stream=False) + h.get_result(timeout=2.0) + snap = scheduler.debug_snapshot() + assert "sticky_sessions" in snap, "debug_snapshot should include sticky_sessions" + assert snap["sticky_sessions"] == 1 + print(f" debug_snapshot includes sticky_sessions: {snap['sticky_sessions']}") + finally: + scheduler.stop() + + +# =========================================================================== +# Fix #2: KV routing TOCTOU (accepted, concurrent stress tests) +# =========================================================================== + +def test_kv_aware_routing_concurrent_submits(): + """Multiple threads submitting concurrently with kv_aware_routing enabled. + + Verifies no crashes, deadlocks, or data corruption under concurrent access. + """ + svc0 = _Svc("w0") + svc1 = _Svc("w1") + svc0.kv_pool.acquire_context("seed", [72, 101, 108, 108]) + + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=64, + kv_aware_routing=True, + ) + scheduler.start() + + errors: List[Exception] = [] + results: List[Dict[str, Any]] = [] + lock = threading.Lock() + + def _submit(session_id: str, prompt_tokens: Optional[List[int]] = None): + try: + payload: Dict[str, Any] = {"session_id": session_id} + if prompt_tokens: + payload["_prompt_tokens"] = prompt_tokens + h = scheduler.submit(payload, stream=False) + r = h.get_result(timeout=5.0) + with lock: + results.append(r) + except Exception as exc: + with lock: + errors.append(exc) + + threads = [] + for i in range(20): + tokens = [72, 101, 108, 108] if i % 2 == 0 else None + t = threading.Thread(target=_submit, args=(f"concurrent-{i}", tokens)) + threads.append(t) + + for t in threads: + t.start() + for t in threads: + t.join(timeout=10.0) + + scheduler.stop() + + assert len(errors) == 0, f"Concurrent routing errors: {errors}" + assert len(results) == 20, f"Expected 20 results, got {len(results)}" + + snap = scheduler.debug_snapshot() + assert snap["kv_aware_routing"] is True + attempts = snap["metrics"]["kv_aware_routing_attempts"] + hits = snap["metrics"]["kv_aware_routing_hits"] + assert hits <= attempts + print(f" KV routing: {int(attempts)} attempts, {int(hits)} hits") + + +def test_kv_aware_routing_no_deadlock_under_contention(): + """Stress test: many threads hitting _choose_worker simultaneously.""" + svc0 = _Svc("w0") + svc1 = _Svc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=256, + kv_aware_routing=True, + ) + scheduler.start() + + barrier = threading.Barrier(10) + errors: List[Exception] = [] + + def _rapid_submit(tid: int): + try: + barrier.wait(timeout=5.0) + for j in range(10): + payload = { + "session_id": f"stress-{tid}-{j}", + "_prompt_tokens": [1, 2, 3, 4], + } + h = scheduler.submit(payload, stream=False) + h.get_result(timeout=5.0) + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=_rapid_submit, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30.0) + + scheduler.stop() + assert len(errors) == 0, f"Deadlock or errors under contention: {errors}" + print(" no deadlock detected under 10-thread contention") + + +# =========================================================================== +# Fix #3: tokenize_for_routing exception logging + payload copy +# =========================================================================== + +def test_tokenize_for_routing_exception_logs_debug(): + """When tokenize_for_routing raises, submit should log at DEBUG level + and still succeed with fallback routing. + """ + svc0 = _FailTokenizeSvc("w0") + svc1 = _FailTokenizeSvc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + + # Capture log output from the scheduler module's logger + log_records: List[logging.LogRecord] = [] + + class _Handler(logging.Handler): + def emit(self, record): + log_records.append(record) + + # Try to find the logger used by the scheduler module + scheduler_logger = logging.getLogger(scheduler_mod.__name__) + handler = _Handler() + handler.setLevel(logging.DEBUG) + scheduler_logger.addHandler(handler) + old_level = scheduler_logger.level + scheduler_logger.setLevel(logging.DEBUG) + + try: + h = scheduler.submit({"session_id": "s-log", "prompt": "hello"}, stream=False) + r = h.get_result(timeout=2.0) + assert r["session_id"] == "s-log" + + # Check if any debug log was emitted about tokenize failure + tokenize_logs = [r for r in log_records if "tokenize" in r.getMessage().lower() or "routing" in r.getMessage().lower()] + if tokenize_logs: + print(f" logger.debug emitted: '{tokenize_logs[0].getMessage()}'") + else: + print(" NOTE: no tokenize debug log found (logger may not be implemented yet)") + finally: + scheduler_logger.removeHandler(handler) + scheduler_logger.setLevel(old_level) + scheduler.stop() + + +def test_tokenize_for_routing_exception_falls_back_gracefully(): + """When tokenize_for_routing raises, submit should still succeed.""" + svc0 = _FailTokenizeSvc("w0") + svc1 = _FailTokenizeSvc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + try: + h = scheduler.submit({"session_id": "s-fail-tok", "prompt": "hello"}, stream=False) + r = h.get_result(timeout=2.0) + assert r["session_id"] == "s-fail-tok" + print(" tokenize_for_routing exception handled gracefully") + finally: + scheduler.stop() + + +def test_submit_does_not_mutate_caller_payload(): + """submit() should not modify the caller's original payload dict. + + Design fix 3b: payload = dict(payload) at submit() entry. + """ + svc0 = _Svc("w0") + svc1 = _Svc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + try: + original_payload = {"session_id": "s-immut", "prompt": "test"} + original_keys = set(original_payload.keys()) + h = scheduler.submit(original_payload, stream=False) + h.get_result(timeout=2.0) + + # The original dict should not have been modified + assert set(original_payload.keys()) == original_keys, ( + f"Caller payload was mutated: {set(original_payload.keys())} != {original_keys}" + ) + assert "_prompt_tokens" not in original_payload, ( + "_prompt_tokens leaked into caller's payload" + ) + print(" submit() does not mutate caller payload") + finally: + scheduler.stop() + + +def test_tokenize_for_routing_returns_none_falls_back(): + """When tokenize_for_routing returns None, routing falls back to hash/RR.""" + + class _NoneTokenizeSvc(_Svc): + def tokenize_for_routing(self, payload): + return None + + svc0 = _NoneTokenizeSvc("w0") + svc1 = _NoneTokenizeSvc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + try: + h = scheduler.submit({"session_id": "s-none-tok", "prompt": "hello"}, stream=False) + r = h.get_result(timeout=2.0) + assert r["session_id"] == "s-none-tok" + print(" tokenize_for_routing returning None handled correctly") + finally: + scheduler.stop() + + +def test_tokenize_for_routing_on_chatservice_with_bad_payload(): + """ChatService.tokenize_for_routing returns None for invalid payloads.""" + service, _ = _make_service() + + assert service.tokenize_for_routing({}) is None + assert service.tokenize_for_routing({"messages": "not a list"}) is None + + tokens = service.tokenize_for_routing({"prompt": "hi"}) + assert tokens is not None and len(tokens) > 0 + + print(" ChatService.tokenize_for_routing handles bad payloads safely") + + +# =========================================================================== +# Fix #4: Interface inheritance (isinstance checks) +# =========================================================================== + +def test_kvcachepool_isinstance_ikvachepool(): + """KVCachePool should inherit from IKVCachePool.""" + IKVCachePool = getattr(iface_mod, "IKVCachePool", None) + if IKVCachePool is None: + print(" SKIP: IKVCachePool interface not found") + return + pool = KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + if isinstance(pool, IKVCachePool): + print(" KVCachePool isinstance IKVCachePool: True") + else: + print(" NOTE: KVCachePool does not yet inherit IKVCachePool (fix #4 pending)") + + +def test_chatservice_isinstance_iinferenceservice(): + """ChatService should inherit from IInferenceService.""" + IInferenceService = getattr(iface_mod, "IInferenceService", None) + if IInferenceService is None: + print(" SKIP: IInferenceService interface not found") + return + service, _ = _make_service() + if isinstance(service, IInferenceService): + print(" ChatService isinstance IInferenceService: True") + else: + print(" NOTE: ChatService does not yet inherit IInferenceService (fix #4 pending)") + + +# =========================================================================== +# Fix #5: request_stop merged locking (regression) +# =========================================================================== + +def test_regression_request_stop_works(): + """Regression: request_stop still works after merging lock blocks.""" + svc0 = _Svc("w0") + svc1 = _Svc("w1") + scheduler = InferenceScheduler([svc0, svc1], queue_size=4) + scheduler.start() + try: + h = scheduler.submit({"session_id": "stop-test"}, stream=False) + h.get_result(timeout=2.0) + ok = scheduler.request_stop("stop-test") + assert ok is True + total = len(svc0.stop_calls) + len(svc1.stop_calls) + assert total == 1 + print(" regression: request_stop works after merge") + finally: + scheduler.stop() + + +# =========================================================================== +# Fix #6: _prompt_tokens cleaned from downstream payload +# =========================================================================== + +def test_prompt_tokens_not_leaked_to_worker(): + """InferenceTask.payload reaching the worker should not contain _prompt_tokens. + + Design: submit() calls payload.pop("_prompt_tokens", None) after routing. + """ + svc0 = _Svc("w0") + svc1 = _Svc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + try: + h = scheduler.submit( + {"session_id": "s-clean", "prompt": "test"}, + stream=False, + ) + h.get_result(timeout=2.0) + + # Check the payload that the worker (svc) actually received + for svc in (svc0, svc1): + if svc.last_payload is not None: + if "_prompt_tokens" in svc.last_payload: + print(" NOTE: _prompt_tokens still in worker payload (fix #6 pending)") + else: + print(" _prompt_tokens cleaned from worker payload") + break + finally: + scheduler.stop() + + +def test_prompt_tokens_explicit_in_payload_also_cleaned(): + """Even if caller passes _prompt_tokens explicitly, it should be cleaned.""" + svc0 = _Svc("w0") + svc1 = _Svc("w1") + scheduler = InferenceScheduler( + [svc0, svc1], + queue_size=16, + kv_aware_routing=True, + ) + scheduler.start() + try: + h = scheduler.submit( + {"session_id": "s-explicit", "_prompt_tokens": [1, 2, 3]}, + stream=False, + ) + h.get_result(timeout=2.0) + + for svc in (svc0, svc1): + if svc.last_payload is not None: + if "_prompt_tokens" in svc.last_payload: + print(" NOTE: explicit _prompt_tokens still in worker payload (fix #6 pending)") + else: + print(" explicit _prompt_tokens cleaned from worker payload") + break + finally: + scheduler.stop() + + +# =========================================================================== +# Regression tests +# =========================================================================== + +def test_regression_kv_cache_pool_prefix_match(): + """Regression: sealed block prefix matching still works.""" + pool = KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + result_a = pool.acquire_context("ctx-a", [1, 2, 3, 4, 5, 6]) + assert result_a.prefix_len == 0 + result_b = pool.acquire_context("ctx-b", [1, 2, 3, 4, 5, 6]) + assert result_b.prefix_len == 4 + stats = pool.snapshot_stats() + assert stats["prefix_hit_count"] >= 1 + print(" regression: kv_cache_pool prefix match OK") + + +def test_regression_scheduler_non_stream(): + """Regression: basic non-stream generate still works.""" + scheduler = InferenceScheduler([_Svc("w0")], queue_size=4) + scheduler.start() + try: + h = scheduler.submit({"session_id": "reg-1"}, stream=False) + r = h.get_result(timeout=2.0) + assert r["session_id"] == "reg-1" + assert r["worker"] == "w0" + finally: + scheduler.stop() + print(" regression: scheduler non-stream OK") + + +def test_regression_scheduler_stream(): + """Regression: basic stream still works.""" + scheduler = InferenceScheduler([_Svc("w0")], queue_size=4) + scheduler.start() + try: + h = scheduler.submit({"session_id": "reg-2"}, stream=True) + items = list(h.iter_stream()) + assert items[-1]["choices"][0]["finish_reason"] is not None + finally: + scheduler.stop() + print(" regression: scheduler stream OK") + + +def test_regression_server_kv_reuse(): + """Regression: ChatService KV reuse for same session still works.""" + service, model = _make_service() + first = service.generate({"session_id": "reg-s1", "prompt": "hello", "max_new_tokens": 2}) + assert first["session_id"] == "reg-s1" + service.generate({"session_id": "reg-s1", "prompt": "again", "max_new_tokens": 2}) + assert model.bind_calls[-1] is not None + dbg = service.kv_debug_snapshot("reg-s1") + assert dbg["last_bind"]["bound"] is True + print(" regression: server kv reuse OK") + + +# =========================================================================== +# Runner +# =========================================================================== + +if __name__ == "__main__": + tests = [ + # Fix #1: LRU session map + test_session_worker_lru_eviction, + test_session_worker_lru_touch_refreshes_entry, + test_session_worker_debug_snapshot_sticky_sessions, + # Fix #2: KV routing concurrency + test_kv_aware_routing_concurrent_submits, + test_kv_aware_routing_no_deadlock_under_contention, + # Fix #3: tokenize exception + payload copy + test_tokenize_for_routing_exception_logs_debug, + test_tokenize_for_routing_exception_falls_back_gracefully, + test_submit_does_not_mutate_caller_payload, + test_tokenize_for_routing_returns_none_falls_back, + test_tokenize_for_routing_on_chatservice_with_bad_payload, + # Fix #4: interface inheritance + test_kvcachepool_isinstance_ikvachepool, + test_chatservice_isinstance_iinferenceservice, + # Fix #5: request_stop regression + test_regression_request_stop_works, + # Fix #6: _prompt_tokens cleanup + test_prompt_tokens_not_leaked_to_worker, + test_prompt_tokens_explicit_in_payload_also_cleaned, + # General regression + test_regression_kv_cache_pool_prefix_match, + test_regression_scheduler_non_stream, + test_regression_scheduler_stream, + test_regression_server_kv_reuse, + ] + + passed = 0 + failed = 0 + skipped = 0 + for test_fn in tests: + name = test_fn.__name__ + try: + print(f"[RUN ] {name}") + test_fn() + print(f"[PASS] {name}") + passed += 1 + except Exception as exc: + print(f"[FAIL] {name}: {exc}") + failed += 1 + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed, {passed + failed} total") + if failed > 0: + print("SOME TESTS FAILED") + sys.exit(1) + else: + print("ALL TESTS PASSED") diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b874..489cbde99 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -81,7 +81,7 @@ def llaisys_infer( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) parser.add_argument("--model", default=None, type=str) parser.add_argument("--prompt", default="Who are you?", type=str) parser.add_argument("--max_steps", default=128, type=int) diff --git a/test/test_kv_cache_pool.py b/test/test_kv_cache_pool.py new file mode 100644 index 000000000..c8536ae74 --- /dev/null +++ b/test/test_kv_cache_pool.py @@ -0,0 +1,105 @@ +import importlib.util +from pathlib import Path +import sys + + +def _load_pool_module(): + root = Path(__file__).resolve().parents[1] + + # Load interfaces first (kv_cache_pool imports from it) + iface_path = root / "python" / "llaisys" / "interfaces.py" + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(iface_path)) + if iface_spec is not None and iface_spec.loader is not None: + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + module_path = root / "python" / "llaisys" / "kv_cache_pool.py" + spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(module_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool module") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +kv_module = _load_pool_module() +KVCachePool = kv_module.KVCachePool + + +def test_prefix_match_only_on_sealed_block(): + pool = KVCachePool(block_size=4, max_blocks=128, max_bytes=1024 * 1024) + + # ctx-a creates one sealed block [1,2,3,4] and one unsealed [5,6] + result_a = pool.acquire_context("ctx-a", [1, 2, 3, 4, 5, 6]) + assert result_a.prefix_len == 0 + + # ctx-b should only reuse sealed prefix length=4 + result_b = pool.acquire_context("ctx-b", [1, 2, 3, 4, 5, 6]) + assert result_b.prefix_len == 4 + + stats = pool.snapshot_stats() + assert stats["prefix_hit_count"] >= 1 + + +def test_release_and_evict_zero_ref_blocks(): + pool = KVCachePool(block_size=2, max_blocks=2, max_bytes=1024 * 1024) + pool.acquire_context("ctx-a", [10, 11, 12, 13]) # two sealed blocks + pool.acquire_context("ctx-b", [20, 21, 22, 23]) # pressure pool + + # both contexts exist + assert pool.debug_context("ctx-a") is not None + assert pool.debug_context("ctx-b") is not None + + pool.release_context("ctx-a") + pool.release_context("ctx-b") + stats = pool.snapshot_stats() + # capacity eviction can now clear all zero-ref blocks + assert stats["zero_ref_blocks"] >= 0 + + +def test_reference_count_sharing(): + pool = KVCachePool(block_size=3, max_blocks=128, max_bytes=1024 * 1024) + pool.acquire_context("ctx-a", [1, 2, 3, 4, 5, 6]) + pool.acquire_context("ctx-b", [1, 2, 3, 9, 9, 9]) + stats = pool.snapshot_stats() + assert stats["shared_blocks"] >= 1, "sealed prefix block should be shared" + + +def test_rollback_on_block_creation_error(): + pool = KVCachePool(block_size=2, max_blocks=128, max_bytes=1024 * 1024) + pool.acquire_context("ctx-ok", [1, 2, 3, 4]) + + original_create = pool._create_block + call_count = {"n": 0} + + def flaky_create(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 2: + raise RuntimeError("inject failure") + return original_create(*args, **kwargs) + + pool._create_block = flaky_create + before = pool.snapshot_stats() + try: + try: + pool.acquire_context("ctx-fail", [5, 6, 7, 8]) + raise AssertionError("expected failure not raised") + except RuntimeError: + pass + finally: + pool._create_block = original_create + + after = pool.snapshot_stats() + # failed context should not exist; leaked refs should not increase + assert pool.debug_context("ctx-fail") is None + assert after["total_refs"] <= before["total_refs"] + + +if __name__ == "__main__": + test_prefix_match_only_on_sealed_block() + test_release_and_evict_zero_ref_blocks() + test_reference_count_sharing() + test_rollback_on_block_creation_error() + print("KV cache pool tests passed.") diff --git a/test/test_runtime.py b/test/test_runtime.py index e2ac218a1..4176fdee6 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -55,7 +55,7 @@ def test_memcpy(api, size_bytes: int): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "iluvatar"], type=str) args = parser.parse_args() test_basic_runtime_api(args.device) diff --git a/test/test_sampling_batch.py b/test/test_sampling_batch.py new file mode 100644 index 000000000..206005104 --- /dev/null +++ b/test/test_sampling_batch.py @@ -0,0 +1,756 @@ +"""Tests for sampling batch path (docs/SAMPLING_BATCH_DESIGN.md): +- Sampling requests enter packed path (no fallback to single) +- Different sampling parameter combinations (temperature, top_k, top_p) +- Mixed greedy+sampling batches +- Backward compatibility: pure greedy batches unchanged +- Edge cases: empty batch, single sampling request +- Fallback: old DLL without new API falls back correctly +""" + +import importlib.util +import sys +import types +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# --------------------------------------------------------------------------- +# Module loading (same pattern as existing tests) +# --------------------------------------------------------------------------- + +def _load_modules(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # interfaces + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is None or iface_spec.loader is None: + raise RuntimeError("failed to load interfaces") + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + # kv_cache_pool + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + # scheduler + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # session_manager + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # kv_runtime_bridge + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + # fake llaisys package + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.interfaces = iface_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + fake_llaisys.__path__ = [str(root / "python" / "llaisys")] + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + sys.modules["llaisys.interfaces"] = iface_mod + if session_mgr_mod: + sys.modules["llaisys.session_manager"] = session_mgr_mod + if kv_bridge_mod: + sys.modules["llaisys.kv_runtime_bridge"] = kv_bridge_mod + + # fake libllaisys (must be registered before server.py imports it) + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _FakeSamplingParams: + """Mimics ctypes LlaisysSamplingParams Structure.""" + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _FakeSamplingParams + sys.modules["llaisys.libllaisys"] = fake_libllaisys + fake_llaisys.libllaisys = fake_libllaisys + + # fake models + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + # server + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + server_mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = server_mod + spec.loader.exec_module(server_mod) + + return iface_mod, kv_mod, scheduler_mod, server_mod + + +iface_mod, kv_mod, scheduler_mod, server_mod = _load_modules() +ChatService = server_mod.ChatService + + +# --------------------------------------------------------------------------- +# Fake model / tokenizer helpers +# --------------------------------------------------------------------------- + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self): + self.end_token = _EndToken(-1) + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class FakeModel: + """Model mock that tracks which packed methods are called.""" + + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + # Track packed call types + self.prefill_packed_calls = 0 + self.step_packed_calls = 0 + self.prefill_packed_sampling_calls = 0 + self.step_packed_sampling_calls = 0 + self.prefill_packed_sampling_params: List[Any] = [] + self.step_packed_sampling_params: List[Any] = [] + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def prefill_packed(self, prompts): + self.prefill_packed_calls += 1 + return [65] * len(prompts) + + def step_packed(self, sequences): + self.step_packed_calls += 1 + # Return a valid token so generation reaches max_new_tokens + return [66] * len(sequences) + + def prefill_packed_sampling(self, prompts, params_list): + self.prefill_packed_sampling_calls += 1 + self.prefill_packed_sampling_params.append(params_list) + return [65] * len(prompts) + + def step_packed_sampling(self, sequences, params_list): + self.step_packed_sampling_calls += 1 + self.step_packed_sampling_params.append(params_list) + # Return a valid token so generation reaches max_new_tokens + return [66] * len(sequences) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +class FakeModelNoSamplingPacked: + """Model mock that does NOT have prefill_packed_sampling / step_packed_sampling. + + Simulates an old DLL without the new batch sampling API. + Only has greedy packed methods. + """ + + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + self.prefill_packed_calls = 0 + self.step_packed_calls = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def prefill_packed(self, prompts): + self.prefill_packed_calls += 1 + return [65] * len(prompts) + + def step_packed(self, sequences): + self.step_packed_calls += 1 + return [66] * len(sequences) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_service(model=None, **kwargs): + if model is None: + model = FakeModel() + tok = FakeTokenizer() + service = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", True), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + return service, model + + +def _greedy_payload(session_id, prompt="hello"): + return { + "session_id": session_id, + "prompt": prompt, + "max_new_tokens": 2, + } + + +def _sampling_payload(session_id, prompt="hello", temperature=0.8, top_k=50, top_p=0.9, seed=42): + return { + "session_id": session_id, + "prompt": prompt, + "max_new_tokens": 2, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "seed": seed, + } + + +# =========================================================================== +# Test: pure greedy batch behavior unchanged +# =========================================================================== + +def test_pure_greedy_batch_uses_original_packed_path(): + """Pure greedy batch should use prefill_packed / step_packed (not sampling variant).""" + service, model = _make_service() + payloads = [_greedy_payload("g1"), _greedy_payload("g2"), _greedy_payload("g3")] + result = service.generate_packed_non_stream(payloads) + + assert result is not None, "Pure greedy batch should not return None" + assert len(result) == 3 + for r in result: + assert "choices" in r + assert r["choices"][0]["message"]["content"] is not None + assert "usage" in r + assert model.prefill_packed_calls >= 1, "Should use prefill_packed for greedy" + assert model.prefill_packed_sampling_calls == 0, "Should NOT use sampling variant for greedy" + print(" pure greedy batch uses original packed path OK") + + +def test_pure_greedy_batch_argmax_mode(): + """Explicit mode='argmax' should stay on greedy path.""" + service, model = _make_service() + payload = _greedy_payload("g-argmax") + payload["sampling"] = "argmax" + result = service.generate_packed_non_stream([payload]) + + assert result is not None + assert len(result) == 1 + assert model.prefill_packed_calls >= 1 + assert model.prefill_packed_sampling_calls == 0 + print(" argmax mode stays on greedy path OK") + + +# =========================================================================== +# Test: sampling requests enter packed path +# =========================================================================== + +def test_sampling_request_enters_packed_path(): + """Sampling request should use prefill_packed_sampling (not return None).""" + service, model = _make_service() + payloads = [_sampling_payload("s1"), _sampling_payload("s2")] + result = service.generate_packed_non_stream(payloads) + + if result is None: + # Before implementation: sampling falls back to None (current behavior) + print(" NOTE: sampling still falls back to None (implementation pending)") + return + + assert len(result) == 2 + for r in result: + assert "choices" in r + assert r["choices"][0]["message"]["content"] is not None + assert "session_id" in r + assert model.prefill_packed_sampling_calls >= 1, "Should use prefill_packed_sampling" + print(" sampling request enters packed path OK") + + +# =========================================================================== +# Test: different sampling parameter combinations +# =========================================================================== + +def test_sampling_temperature_only(): + """Request with only temperature > 0 should be treated as sampling.""" + service, model = _make_service() + payload = { + "session_id": "t-only", + "prompt": "test", + "max_new_tokens": 2, + "temperature": 1.0, + "top_k": 1, + "top_p": 0.0, + } + result = service.generate_packed_non_stream([payload]) + + if result is None: + print(" NOTE: temperature-only sampling falls back (implementation pending)") + return + + assert len(result) == 1 + assert model.prefill_packed_sampling_calls >= 1 + print(" temperature-only triggers sampling path OK") + + +def test_sampling_top_k_only(): + """Request with only top_k > 1 should be treated as sampling.""" + service, model = _make_service() + payload = { + "session_id": "k-only", + "prompt": "test", + "max_new_tokens": 2, + "temperature": 0.0, + "top_k": 50, + "top_p": 0.0, + } + result = service.generate_packed_non_stream([payload]) + + if result is None: + print(" NOTE: top_k-only sampling falls back (implementation pending)") + return + + assert len(result) == 1 + assert model.prefill_packed_sampling_calls >= 1 + print(" top_k-only triggers sampling path OK") + + +def test_sampling_top_p_only(): + """Request with only top_p > 0 should be treated as sampling.""" + service, model = _make_service() + payload = { + "session_id": "p-only", + "prompt": "test", + "max_new_tokens": 2, + "temperature": 0.0, + "top_k": 1, + "top_p": 0.9, + } + result = service.generate_packed_non_stream([payload]) + + if result is None: + print(" NOTE: top_p-only sampling falls back (implementation pending)") + return + + assert len(result) == 1 + assert model.prefill_packed_sampling_calls >= 1 + print(" top_p-only triggers sampling path OK") + + +def test_sampling_mode_explicit_sample(): + """Explicit mode='sample' should trigger sampling path.""" + service, model = _make_service() + payload = { + "session_id": "m-sample", + "prompt": "test", + "max_new_tokens": 2, + "sampling": "sample", + "temperature": 0.8, + "top_k": 50, + } + result = service.generate_packed_non_stream([payload]) + + if result is None: + print(" NOTE: explicit sample mode falls back (implementation pending)") + return + + assert len(result) == 1 + assert model.prefill_packed_sampling_calls >= 1 + print(" explicit sample mode triggers sampling path OK") + + +def test_sampling_all_params_combined(): + """Request with temperature + top_k + top_p all set.""" + service, model = _make_service() + payload = _sampling_payload("all-params", temperature=0.7, top_k=40, top_p=0.95, seed=123) + result = service.generate_packed_non_stream([payload]) + + if result is None: + print(" NOTE: combined sampling params falls back (implementation pending)") + return + + assert len(result) == 1 + assert result[0]["session_id"] == "all-params" + print(" all sampling params combined OK") + + +# =========================================================================== +# Test: mixed greedy + sampling batch +# =========================================================================== + +def test_mixed_greedy_and_sampling_batch(): + """Mixed batch (greedy + sampling) should use the sampling packed path for all.""" + service, model = _make_service() + payloads = [ + _greedy_payload("mix-g1"), + _sampling_payload("mix-s1", temperature=0.8), + _greedy_payload("mix-g2"), + ] + result = service.generate_packed_non_stream(payloads) + + if result is None: + # Before implementation: any sampling causes entire batch to fall back + print(" NOTE: mixed batch falls back to None (implementation pending)") + return + + assert len(result) == 3 + session_ids = [r["session_id"] for r in result] + assert "mix-g1" in session_ids + assert "mix-s1" in session_ids + assert "mix-g2" in session_ids + for r in result: + assert "choices" in r + # Mixed batch should use sampling variant (greedy params are equivalent to argmax) + assert model.prefill_packed_sampling_calls >= 1 + print(" mixed greedy+sampling batch OK") + + +# =========================================================================== +# Test: edge cases +# =========================================================================== + +def test_empty_batch(): + """Empty batch should return empty list (not None).""" + service, _ = _make_service() + result = service.generate_packed_non_stream([]) + assert result == [], f"Empty batch should return [], got {result}" + print(" empty batch returns [] OK") + + +def test_single_sampling_request(): + """Single sampling request in batch should work.""" + service, model = _make_service() + payloads = [_sampling_payload("single-s")] + result = service.generate_packed_non_stream(payloads) + + if result is None: + print(" NOTE: single sampling request falls back (implementation pending)") + return + + assert len(result) == 1 + assert result[0]["session_id"] == "single-s" + print(" single sampling request in batch OK") + + +def test_single_greedy_request(): + """Single greedy request in batch should work (regression).""" + service, model = _make_service() + payloads = [_greedy_payload("single-g")] + result = service.generate_packed_non_stream(payloads) + + assert result is not None + assert len(result) == 1 + assert result[0]["session_id"] == "single-g" + assert model.prefill_packed_calls >= 1 + print(" single greedy request in batch OK") + + +def test_stream_request_rejected(): + """Stream requests should cause packed path to return None.""" + service, _ = _make_service() + payloads = [{"session_id": "stream-1", "prompt": "hi", "max_new_tokens": 2, "stream": True}] + result = service.generate_packed_non_stream(payloads) + assert result is None, "Stream request should cause fallback" + print(" stream request rejected from packed path OK") + + +def test_edit_from_session_rejected(): + """Requests with edit_from_session_id should cause packed path to return None.""" + service, _ = _make_service() + payloads = [{ + "session_id": "edit-1", + "prompt": "hi", + "max_new_tokens": 2, + "edit_from_session_id": "other", + "edit_message_index": 0, + }] + result = service.generate_packed_non_stream(payloads) + assert result is None, "Edit request should cause fallback" + print(" edit_from_session_id rejected from packed path OK") + + +# =========================================================================== +# Test: fallback when old DLL has no new API +# =========================================================================== + +def test_fallback_old_dll_no_packed_sampling(): + """When model lacks prefill_packed_sampling, sampling requests should return None.""" + model = FakeModelNoSamplingPacked() + service, _ = _make_service(model=model) + payloads = [_sampling_payload("old-dll-s1")] + result = service.generate_packed_non_stream(payloads) + + # Should return None (fallback to single-request processing) + assert result is None, "Old DLL without packed sampling should fall back to None" + print(" old DLL fallback for sampling OK") + + +def test_fallback_old_dll_greedy_still_works(): + """When model lacks prefill_packed_sampling, greedy batch should still work.""" + model = FakeModelNoSamplingPacked() + service, _ = _make_service(model=model) + payloads = [_greedy_payload("old-dll-g1"), _greedy_payload("old-dll-g2")] + result = service.generate_packed_non_stream(payloads) + + assert result is not None, "Greedy batch should work even without new API" + assert len(result) == 2 + print(" old DLL greedy batch still works OK") + + +def test_fallback_no_prefill_packed_at_all(): + """Model without prefill_packed should return None for any batch.""" + + class BareModel: + """Model with no packed methods at all.""" + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return 65 + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return 66 + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + model = BareModel() + service, _ = _make_service(model=model) + result = service.generate_packed_non_stream([_greedy_payload("bare-1")]) + assert result is None, "No prefill_packed should return None" + print(" no prefill_packed at all returns None OK") + + +# =========================================================================== +# Test: response format correctness +# =========================================================================== + +def test_response_format_has_required_fields(): + """Each response in batch should have session_id, choices, usage (OpenAI format).""" + service, _ = _make_service() + payloads = [_greedy_payload("fmt-1"), _greedy_payload("fmt-2")] + result = service.generate_packed_non_stream(payloads) + + assert result is not None + for r in result: + assert "session_id" in r, "Missing session_id" + assert "choices" in r, "Missing choices" + assert "id" in r, "Missing id" + assert "object" in r, "Missing object" + assert r["object"] == "chat.completion" + assert r["choices"][0]["message"]["content"] is not None, "Missing content" + assert r["choices"][0]["finish_reason"] is not None, "Missing finish_reason" + assert "usage" in r, "Missing usage" + usage = r["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + print(" response format has required fields OK") + + +def test_response_session_ids_match_input_order(): + """Response session_ids should match input order.""" + service, _ = _make_service() + payloads = [_greedy_payload("order-a"), _greedy_payload("order-b"), _greedy_payload("order-c")] + result = service.generate_packed_non_stream(payloads) + + assert result is not None + assert [r["session_id"] for r in result] == ["order-a", "order-b", "order-c"] + print(" response session_ids match input order OK") + + +# =========================================================================== +# Runner +# =========================================================================== + +if __name__ == "__main__": + tests = [ + # Pure greedy (backward compat) + test_pure_greedy_batch_uses_original_packed_path, + test_pure_greedy_batch_argmax_mode, + # Sampling enters packed path + test_sampling_request_enters_packed_path, + # Different sampling param combos + test_sampling_temperature_only, + test_sampling_top_k_only, + test_sampling_top_p_only, + test_sampling_mode_explicit_sample, + test_sampling_all_params_combined, + # Mixed batch + test_mixed_greedy_and_sampling_batch, + # Edge cases + test_empty_batch, + test_single_sampling_request, + test_single_greedy_request, + test_stream_request_rejected, + test_edit_from_session_rejected, + # Fallback (old DLL) + test_fallback_old_dll_no_packed_sampling, + test_fallback_old_dll_greedy_still_works, + test_fallback_no_prefill_packed_at_all, + # Response format + test_response_format_has_required_fields, + test_response_session_ids_match_input_order, + ] + + passed = 0 + failed = 0 + for test_fn in tests: + name = test_fn.__name__ + try: + print(f"[RUN ] {name}") + test_fn() + print(f"[PASS] {name}") + passed += 1 + except Exception as exc: + print(f"[FAIL] {name}: {exc}") + failed += 1 + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed, {passed + failed} total") + if failed > 0: + print("SOME TESTS FAILED") + sys.exit(1) + else: + print("ALL TESTS PASSED") diff --git a/test/test_scheduler_inmemory.py b/test/test_scheduler_inmemory.py new file mode 100644 index 000000000..fbe04e479 --- /dev/null +++ b/test/test_scheduler_inmemory.py @@ -0,0 +1,193 @@ +import importlib.util +from pathlib import Path +import sys +import time + + +def _load_scheduler_module(): + root = Path(__file__).resolve().parents[1] + module_path = root / "python" / "llaisys" / "scheduler.py" + spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(module_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load scheduler module") + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + return mod + + +class _Svc: + def __init__(self, name): + self.name = name + self.stop_calls = [] + + def generate(self, payload): + sid = str(payload.get("session_id") or "") + return { + "id": f"chatcmpl-{sid}", + "object": "chat.completion", + "model": "qwen2", + "choices": [{"index": 0, "message": {"role": "assistant", "content": ""}, "finish_reason": "stop"}], + "session_id": sid, + "worker": self.name, + } + + def stream(self, payload): + sid = str(payload.get("session_id") or "") + yield { + "id": f"chatcmpl-{sid}", + "object": "chat.completion.chunk", + "model": "qwen2", + "choices": [{"index": 0, "delta": {"content": "x"}, "finish_reason": None}], + "session_id": sid, + } + yield { + "id": f"chatcmpl-{sid}", + "object": "chat.completion.chunk", + "model": "qwen2", + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "session_id": sid, + } + + def request_stop(self, session_id): + self.stop_calls.append(session_id) + return True + + def kv_debug_snapshot(self, session_id=None): + return {"session_id": session_id, "has_native_context": False, "last_bind": {}, "kv_pool": {}} + + +class _SlowSvc(_Svc): + def generate(self, payload): + time.sleep(0.2) + return super().generate(payload) + + +class _PackedSvc(_Svc): + def __init__(self, name): + super().__init__(name) + self.packed_calls = 0 + + def generate_packed_once(self, payloads): + self.packed_calls += 1 + out = [] + for payload in payloads: + sid = str(payload.get("session_id") or "") + out.append({ + "id": f"chatcmpl-{sid}", + "object": "chat.completion", + "model": "qwen2", + "choices": [{"index": 0, "message": {"role": "assistant", "content": "p"}, "finish_reason": "stop"}], + "session_id": sid, + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }) + return out + + def generate_packed_non_stream(self, payloads): + return self.generate_packed_once(payloads) + + +def test_scheduler_non_stream_and_stream(): + mod = _load_scheduler_module() + scheduler = mod.InferenceScheduler([_Svc("w0")], queue_size=4) + scheduler.start() + try: + h1 = scheduler.submit({"session_id": "s1"}, stream=False) + r1 = h1.get_result(timeout=2.0) + assert r1["session_id"] == "s1" + assert r1["worker"] == "w0" + + h2 = scheduler.submit({"session_id": "s1"}, stream=True) + items = list(h2.iter_stream()) + assert items[-1]["choices"][0]["finish_reason"] is not None + assert items[0]["choices"][0]["delta"]["content"] == "x" + finally: + scheduler.stop() + + +def test_scheduler_session_sticky_stop_route(): + mod = _load_scheduler_module() + s0 = _Svc("w0") + s1 = _Svc("w1") + scheduler = mod.InferenceScheduler([s0, s1], queue_size=4) + scheduler.start() + try: + # First bind session s-stick to a worker. + h = scheduler.submit({"session_id": "s-stick"}, stream=False) + _ = h.get_result(timeout=2.0) + ok = scheduler.request_stop("s-stick") + assert ok is True + # Should only call one worker for mapped session. + total = len(s0.stop_calls) + len(s1.stop_calls) + assert total == 1 + finally: + scheduler.stop() + + +def test_scheduler_queue_full_and_timeout(): + mod = _load_scheduler_module() + scheduler = mod.InferenceScheduler([_SlowSvc("w0")], queue_size=1, request_timeout_ms=50) + try: + # Fill queue with first task. + h1 = scheduler.submit({"session_id": "s-a"}, stream=False) + # Second submit should fail due to queue full. + try: + scheduler.submit({"session_id": "s-b"}, stream=False) + raise AssertionError("expected queue full") + except mod.SchedulerQueueFullError: + pass + time.sleep(0.1) + scheduler.start() + # First task should timeout in worker before execution. + r1 = h1.get_result(timeout=1.0) + assert r1.get("code") == "timeout" + finally: + scheduler.stop() + + +def test_scheduler_continuous_batching_non_stream_path(): + mod = _load_scheduler_module() + scheduler = mod.InferenceScheduler([_Svc("w0")], queue_size=4, request_timeout_ms=1000, continuous_batching=True) + scheduler.start() + try: + h = scheduler.submit({"session_id": "s-cb"}, stream=False) + r = h.get_result(timeout=2.0) + assert r["session_id"] == "s-cb" + assert "choices" in r + assert r["choices"][0]["message"]["content"] is not None + snap = scheduler.debug_snapshot() + assert snap["continuous_batching"] is True + assert snap["metrics"]["batch_rounds"] >= 1.0 + assert snap["metrics"]["prefill_rounds"] >= 1.0 + assert snap["metrics"]["decode_rounds"] >= 1.0 + finally: + scheduler.stop() + + +def test_scheduler_continuous_batching_packed_prefill_path(): + mod = _load_scheduler_module() + svc = _PackedSvc("w0") + scheduler = mod.InferenceScheduler([svc], queue_size=8, request_timeout_ms=1000, continuous_batching=True) + scheduler.start() + try: + h1 = scheduler.submit({"session_id": "a", "max_new_tokens": 1}, stream=False) + h2 = scheduler.submit({"session_id": "b", "max_new_tokens": 1}, stream=False) + r1 = h1.get_result(timeout=2.0) + r2 = h2.get_result(timeout=2.0) + assert r1["choices"][0]["message"]["content"] == "p" + assert r2["choices"][0]["message"]["content"] == "p" + snap = scheduler.debug_snapshot() + assert snap["metrics"]["packed_prefill_batches"] >= 1.0 + assert snap["metrics"]["packed_prefill_tasks"] >= 2.0 + assert svc.packed_calls >= 1 + finally: + scheduler.stop() + + +if __name__ == "__main__": + test_scheduler_non_stream_and_stream() + test_scheduler_session_sticky_stop_route() + test_scheduler_queue_full_and_timeout() + test_scheduler_continuous_batching_non_stream_path() + test_scheduler_continuous_batching_packed_prefill_path() + print("scheduler tests passed") diff --git a/test/test_server_kv_reuse_integration.py b/test/test_server_kv_reuse_integration.py new file mode 100644 index 000000000..0076b6075 --- /dev/null +++ b/test/test_server_kv_reuse_integration.py @@ -0,0 +1,241 @@ +import importlib.util +import sys +import types +from pathlib import Path + + +def _load_server_module(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # Load interfaces first (kv_cache_pool and server import from it) + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is not None and iface_spec.loader is not None: + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # Load session_manager (server.py imports from it) + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # Load kv_runtime_bridge (server.py imports from it) + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + + # fake libllaisys with stub LlaisysSamplingParams + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _StubSamplingParams: + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _StubSamplingParams + fake_llaisys.libllaisys = fake_libllaisys + sys.modules["llaisys.libllaisys"] = fake_libllaisys + + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + return mod + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self): + self.end_token = _EndToken(-1) + + +class FakeModel: + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 # "A" + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 # "B" + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_service(): + server_mod = _load_server_module() + model = FakeModel() + tok = FakeTokenizer() + service = server_mod.ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=True, + block_size=4, + max_blocks=256, + max_bytes=1024 * 1024, + ) + return service, model + + +def test_kv_reuse_same_session_binds_native_context(): + service, model = _make_service() + + first = service.generate({"session_id": "s1", "prompt": "你好", "max_new_tokens": 2}) + assert first["session_id"] == "s1" + # first request has no prefix hit; should bind None + assert model.bind_calls and model.bind_calls[0] is None + assert len(model.export_calls) == 1 + + service.generate({"session_id": "s1", "prompt": "继续", "max_new_tokens": 2}) + # second request should bind non-null native context + assert model.bind_calls[-1] is not None + dbg = service.kv_debug_snapshot("s1") + assert dbg["last_bind"]["bound"] is True + assert dbg["last_bind"]["source_session_id"] == "s1" + assert dbg["last_bind"]["prefix_len"] > 0 + + +def test_kv_reuse_cross_session_can_use_donor_context(): + service, _ = _make_service() + + service.generate({"session_id": "donor", "prompt": "同一个问题", "max_new_tokens": 2}) + service.generate( + { + "session_id": "receiver", + "messages": [{"role": "user", "content": "同一个问题"}], + "max_new_tokens": 2, + } + ) + + dbg = service.kv_debug_snapshot("receiver") + assert dbg["last_bind"]["bound"] is True + assert dbg["last_bind"]["prefix_len"] > 0 + assert dbg["last_bind"]["source_session_id"] == "donor" + + +def test_cancelled_request_does_not_export_native_kv(): + service, model = _make_service() + + def _cancelled_iter(prompt_ids, max_new_tokens, sampling, prefix_len, cancel_event): + cancel_event.set() + if False: + yield 0 + + service._iter_generate_ids = _cancelled_iter + result = service.generate({"session_id": "s-cancel", "prompt": "会取消", "max_new_tokens": 2}) + assert result.get("stopped") is True + assert result["choices"][0]["finish_reason"] == "stop" + assert len(model.export_calls) == 0 + + +if __name__ == "__main__": + test_kv_reuse_same_session_binds_native_context() + test_kv_reuse_cross_session_can_use_donor_context() + test_cancelled_request_does_not_export_native_kv() + print("server kv reuse integration tests passed") + diff --git a/test/test_shared_model.py b/test/test_shared_model.py new file mode 100644 index 000000000..920cda937 --- /dev/null +++ b/test/test_shared_model.py @@ -0,0 +1,661 @@ +"""Tests for shared model pool, shared KV pool, and KV memory-aware flow control. + +Covers: +- Shared model + KV pool: multiple ChatService instances share the same objects +- Cross-worker prefix reuse via shared KV pool +- KV memory_pressure() correctness +- KV memory-aware flow control in scheduler (reject when pressure > threshold) +- Shared pool routing optimization in scheduler +- debug_snapshot includes kv_memory_pressure and kv_memory_threshold +""" + +import importlib.util +import sys +import threading +import types +from pathlib import Path +from typing import Any, Dict, List, Optional + + +# --------------------------------------------------------------------------- +# Module loading (same pattern as existing tests) +# --------------------------------------------------------------------------- + +def _load_modules(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # interfaces + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is None or iface_spec.loader is None: + raise RuntimeError("failed to load interfaces") + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + # kv_cache_pool + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + # scheduler + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # session_manager + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # kv_runtime_bridge + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + # fake llaisys package + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.interfaces = iface_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + fake_llaisys.__path__ = [str(root / "python" / "llaisys")] + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + sys.modules["llaisys.interfaces"] = iface_mod + if session_mgr_mod: + sys.modules["llaisys.session_manager"] = session_mgr_mod + if kv_bridge_mod: + sys.modules["llaisys.kv_runtime_bridge"] = kv_bridge_mod + + # fake libllaisys + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _FakeSamplingParams: + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _FakeSamplingParams + sys.modules["llaisys.libllaisys"] = fake_libllaisys + fake_llaisys.libllaisys = fake_libllaisys + + # fake models + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + # server + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + server_mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = server_mod + spec.loader.exec_module(server_mod) + + return iface_mod, kv_mod, scheduler_mod, kv_bridge_mod, server_mod + + +iface_mod, kv_mod, scheduler_mod, kv_bridge_mod, server_mod = _load_modules() +ChatService = server_mod.ChatService +KVCachePool = kv_mod.KVCachePool +KVRuntimeBridge = kv_bridge_mod.KVRuntimeBridge +InferenceScheduler = scheduler_mod.InferenceScheduler +SchedulerQueueFullError = scheduler_mod.SchedulerQueueFullError + + +# --------------------------------------------------------------------------- +# Fake model / tokenizer helpers +# --------------------------------------------------------------------------- + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self, eos=-1): + self.end_token = _EndToken(eos) + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class FakeModel: + def __init__(self, eos=-1): + self._meta = _Meta(eos) + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + self.prefill_packed_calls = 0 + self.step_packed_calls = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def prefill_packed(self, prompts): + self.prefill_packed_calls += 1 + return [65] * len(prompts) + + def step_packed(self, sequences): + self.step_packed_calls += 1 + return [66] * len(sequences) + + def prefill_packed_sampling(self, prompts, params_list): + return [65] * len(prompts) + + def step_packed_sampling(self, sequences, params_list): + return [66] * len(sequences) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_shared_services(worker_count=2, **kwargs): + """Create multiple ChatService instances sharing the same model, lock, KV pool, and KV bridge.""" + model = FakeModel() + tok = FakeTokenizer() + shared_lock = threading.RLock() + shared_kv_pool = KVCachePool( + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + shared_kv_bridge = KVRuntimeBridge(model, enabled=kwargs.get("enable_kv_runtime_reuse", True)) + services = [] + for _ in range(worker_count): + svc = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", True), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + model_lock=shared_lock, + kv_pool=shared_kv_pool, + kv_bridge=shared_kv_bridge, + ) + services.append(svc) + return services, model, shared_kv_pool, shared_lock, shared_kv_bridge + + +def _make_independent_services(worker_count=2, **kwargs): + """Create multiple ChatService instances with independent resources.""" + services = [] + models = [] + for _ in range(worker_count): + model = FakeModel() + tok = FakeTokenizer() + svc = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", False), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + services.append(svc) + models.append(model) + return services, models + + +# =========================================================================== +# Test 1: Shared instances are the same object +# =========================================================================== + +def test_shared_instances_identity(): + """All ChatService instances should share the same model, lock, KV pool, and KV bridge.""" + services, model, shared_pool, shared_lock, shared_bridge = _make_shared_services(3) + + for i, svc in enumerate(services): + assert svc.model is model, f"Service {i} should share the model" + assert svc._model_lock is shared_lock, f"Service {i} should share the model lock" + assert svc._kv_pool is shared_pool, f"Service {i} should share the KV pool" + assert svc._kv_bridge is shared_bridge, f"Service {i} should share the KV bridge" + + # Each service should have its own SessionManager + assert services[0]._session_mgr is not services[1]._session_mgr + + print(" shared instances identity OK") + + +# =========================================================================== +# Test 2: Independent instances are distinct +# =========================================================================== + +def test_independent_instances_distinct(): + """Independent ChatService instances should have separate resources.""" + services, models = _make_independent_services(2) + + assert services[0].model is not services[1].model + assert services[0]._model_lock is not services[1]._model_lock + assert services[0]._kv_pool is not services[1]._kv_pool + assert services[0]._kv_bridge is not services[1]._kv_bridge + + print(" independent instances distinct OK") + + +# =========================================================================== +# Test 3: memory_pressure() correctness +# =========================================================================== + +def test_memory_pressure_empty(): + """Empty pool should have 0.0 pressure.""" + pool = KVCachePool(block_size=4, max_blocks=100, max_bytes=1024 * 1024) + assert pool.memory_pressure() == 0.0 + print(" memory_pressure empty OK") + + +def test_memory_pressure_increases(): + """Pressure should increase as blocks are allocated.""" + pool = KVCachePool(block_size=4, max_blocks=10, max_bytes=1024 * 1024) + assert pool.memory_pressure() == 0.0 + + # Acquire contexts to fill blocks + for i in range(5): + tokens = list(range(i * 4, (i + 1) * 4)) + pool.acquire_context(f"ctx-{i}", tokens) + + pressure = pool.memory_pressure() + assert pressure > 0.0, f"Pressure should be > 0 after allocations, got {pressure}" + assert pressure <= 1.0, f"Pressure should be <= 1.0, got {pressure}" + + print(" memory_pressure increases OK") + + +def test_memory_pressure_interface(): + """memory_pressure should be available via IKVCachePool interface.""" + IKVCachePool = iface_mod.IKVCachePool + pool = KVCachePool(block_size=4, max_blocks=100, max_bytes=1024 * 1024) + assert isinstance(pool, IKVCachePool) + assert hasattr(pool, "memory_pressure") + assert callable(pool.memory_pressure) + print(" memory_pressure interface OK") + + +# =========================================================================== +# Test 4: Cross-worker prefix reuse via shared KV pool +# =========================================================================== + +def test_shared_pool_cross_worker_prefix_reuse(): + """With shared KV pool, a context created by one worker should be visible to another.""" + services, model, shared_pool, _, _ = _make_shared_services(2) + + # Worker 0 generates with a prompt + services[0].generate({"session_id": "shared-s1", "prompt": "hello world", "max_new_tokens": 2}) + + # Worker 1 should see the prefix from worker 0's context in the shared pool + prefix_len = shared_pool.query_prefix_len( + services[1].tokenizer.encode("User: hello world\nAssistant:") + ) + # The exact prefix_len depends on block alignment, but should be > 0 + # since worker 0 already created blocks for the same prompt pattern + assert prefix_len >= 0, "Shared pool should allow cross-worker prefix queries" + + # Worker 1 generates with the same prompt - should benefit from shared pool + services[1].generate({ + "session_id": "shared-s2", + "messages": [{"role": "user", "content": "hello world"}], + "max_new_tokens": 2, + }) + + stats = shared_pool.snapshot_stats() + assert stats["acquire_count"] == 2.0, "Both workers should have acquired from the same pool" + + print(" shared pool cross-worker prefix reuse OK") + + +# =========================================================================== +# Test 5: KV memory flow control - reject when pressure > threshold +# =========================================================================== + +def test_kv_memory_flow_control_rejects(): + """Scheduler should reject requests when KV memory pressure exceeds threshold.""" + # Use a tiny pool so pressure rises quickly + services, model, shared_pool, _, _ = _make_shared_services( + 1, max_blocks=2, max_bytes=64, block_size=4, + ) + + # Fill the pool to create pressure + for i in range(3): + tokens = list(range(i * 4, (i + 1) * 4)) + shared_pool.acquire_context(f"fill-{i}", tokens) + + pressure = shared_pool.memory_pressure() + assert pressure > 0.5, f"Pool should be under pressure, got {pressure}" + + # Create scheduler with low threshold + scheduler = InferenceScheduler( + services, + queue_size=8, + request_timeout_ms=5000, + kv_memory_threshold=0.1, # very low threshold + ) + + rejected = False + try: + scheduler.submit({"session_id": "reject-test", "prompt": "test", "max_new_tokens": 1}, stream=False) + except SchedulerQueueFullError as exc: + assert "KV memory pressure" in str(exc) + rejected = True + + assert rejected, "Should have rejected due to KV memory pressure" + print(" KV memory flow control rejects OK") + + +def test_kv_memory_flow_control_allows_when_below(): + """Scheduler should allow requests when KV memory pressure is below threshold.""" + services, model, shared_pool, _, _ = _make_shared_services(1) + + scheduler = InferenceScheduler( + services, + queue_size=8, + request_timeout_ms=5000, + kv_memory_threshold=0.85, + ) + scheduler.start() + try: + handle = scheduler.submit( + {"session_id": "allow-test", "prompt": "test", "max_new_tokens": 2}, + stream=False, + ) + result = handle.get_result(timeout=5.0) + assert "choices" in result + finally: + scheduler.stop() + + print(" KV memory flow control allows when below OK") + + +def test_kv_memory_flow_control_disabled(): + """When threshold is 0.0, flow control should be disabled.""" + services, model, shared_pool, _, _ = _make_shared_services( + 1, max_blocks=2, max_bytes=64, block_size=4, + ) + + # Fill pool + for i in range(3): + tokens = list(range(i * 4, (i + 1) * 4)) + shared_pool.acquire_context(f"fill-{i}", tokens) + + scheduler = InferenceScheduler( + services, + queue_size=8, + request_timeout_ms=5000, + kv_memory_threshold=0.0, # disabled + ) + scheduler.start() + try: + # Should not reject even with high pressure + handle = scheduler.submit( + {"session_id": "no-fc", "prompt": "test", "max_new_tokens": 2}, + stream=False, + ) + result = handle.get_result(timeout=5.0) + assert "choices" in result + finally: + scheduler.stop() + + print(" KV memory flow control disabled OK") + + +# =========================================================================== +# Test 6: KV memory metrics in debug_snapshot +# =========================================================================== + +def test_debug_snapshot_kv_memory_fields(): + """debug_snapshot should include kv_memory_threshold and kv_memory_pressure.""" + services, _, _, _, _ = _make_shared_services(1) + scheduler = InferenceScheduler( + services, + queue_size=8, + kv_memory_threshold=0.85, + ) + snap = scheduler.debug_snapshot() + assert "kv_memory_threshold" in snap, "Should have kv_memory_threshold" + assert "kv_memory_pressure" in snap, "Should have kv_memory_pressure" + assert snap["kv_memory_threshold"] == 0.85 + assert snap["kv_memory_pressure"] == 0.0 # empty pool + + print(" debug_snapshot KV memory fields OK") + + +def test_debug_snapshot_kv_memory_rejected_metric(): + """kv_memory_rejected metric should increment on rejection.""" + services, _, shared_pool, _, _ = _make_shared_services( + 1, max_blocks=2, max_bytes=64, block_size=4, + ) + for i in range(3): + tokens = list(range(i * 4, (i + 1) * 4)) + shared_pool.acquire_context(f"fill-{i}", tokens) + + scheduler = InferenceScheduler( + services, + queue_size=8, + kv_memory_threshold=0.1, + ) + + try: + scheduler.submit({"prompt": "test"}, stream=False) + except SchedulerQueueFullError: + pass + + snap = scheduler.debug_snapshot() + assert snap["metrics"]["kv_memory_rejected"] >= 1.0 + + print(" debug_snapshot kv_memory_rejected metric OK") + + +# =========================================================================== +# Test 7: Shared pool routing optimization +# =========================================================================== + +def test_shared_pool_kv_debug_snapshot_no_double_count(): + """With shared pool, kv_debug_snapshot should not double-count stats.""" + services, _, shared_pool, _, _ = _make_shared_services(2) + + # Generate on one worker + services[0].generate({"session_id": "snap-s1", "prompt": "hello", "max_new_tokens": 2}) + + scheduler = InferenceScheduler(services, queue_size=8) + snap = scheduler.kv_debug_snapshot() + + # With shared pool, acquire_count should be 1 (not 2) + assert snap["kv_pool"]["acquire_count"] == 1.0, ( + f"Shared pool should report 1 acquire, got {snap['kv_pool']['acquire_count']}" + ) + + print(" shared pool kv_debug_snapshot no double count OK") + + +# =========================================================================== +# Test 8: Shared model concurrent generate +# =========================================================================== + +def test_shared_model_concurrent_generate(): + """Multiple workers sharing a model should serialize via the shared lock.""" + services, model, _, _, _ = _make_shared_services(2) + + results = [None, None] + errors = [] + + def _worker(idx): + try: + r = services[idx].generate({ + "session_id": f"concurrent-{idx}", + "prompt": f"msg-{idx}", + "max_new_tokens": 2, + }) + results[idx] = r + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=_worker, args=(i,)) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10.0) + + assert len(errors) == 0, f"Concurrent errors: {errors}" + for i, r in enumerate(results): + assert r is not None, f"Worker {i} should have produced a result" + assert "choices" in r + + print(" shared model concurrent generate OK") + + +# =========================================================================== +# Test 9: Shared model + scheduler end-to-end +# =========================================================================== + +def test_shared_model_scheduler_e2e(): + """End-to-end: scheduler with shared model services.""" + services, model, _, _, _ = _make_shared_services(2) + scheduler = InferenceScheduler( + services, + queue_size=8, + request_timeout_ms=5000, + continuous_batching=True, + max_batch_size=4, + ) + scheduler.start() + try: + handles = [] + for i in range(4): + h = scheduler.submit( + {"session_id": f"e2e-shared-{i}", "prompt": f"test-{i}", "max_new_tokens": 2, "stream": True}, + stream=True, + ) + handles.append(h) + + for i, h in enumerate(handles): + items = list(h.iter_stream(timeout=5.0)) + assert len(items) > 0, f"Stream {i} should produce chunks" + last = items[-1] + assert last["choices"][0]["finish_reason"] is not None + finally: + scheduler.stop() + + print(" shared model scheduler e2e OK") + + +# =========================================================================== +# Runner +# =========================================================================== + +if __name__ == "__main__": + tests = [ + test_shared_instances_identity, + test_independent_instances_distinct, + test_memory_pressure_empty, + test_memory_pressure_increases, + test_memory_pressure_interface, + test_shared_pool_cross_worker_prefix_reuse, + test_kv_memory_flow_control_rejects, + test_kv_memory_flow_control_allows_when_below, + test_kv_memory_flow_control_disabled, + test_debug_snapshot_kv_memory_fields, + test_debug_snapshot_kv_memory_rejected_metric, + test_shared_pool_kv_debug_snapshot_no_double_count, + test_shared_model_concurrent_generate, + test_shared_model_scheduler_e2e, + ] + + passed = 0 + failed = 0 + for test_fn in tests: + name = test_fn.__name__ + try: + print(f"[RUN ] {name}") + test_fn() + print(f"[PASS] {name}") + passed += 1 + except Exception as exc: + import traceback + print(f"[FAIL] {name}: {exc}") + traceback.print_exc() + failed += 1 + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed, {passed + failed} total") + if failed > 0: + print("SOME TESTS FAILED") + sys.exit(1) + else: + print("ALL TESTS PASSED") diff --git a/test/test_streaming_batch.py b/test/test_streaming_batch.py new file mode 100644 index 000000000..f7f46af6f --- /dev/null +++ b/test/test_streaming_batch.py @@ -0,0 +1,789 @@ +"""Tests for streaming batch processing (Phase 1-3): +- Streaming batch produces correct SSE chunks (multi-sequence parallel) +- Non-stream requests via batch path +- Mixed stream + non-stream in same batch +- Single sequence cancellation while others continue +- Different max_new_tokens (partial early finish) +- Batch size limit enforcement +- Dynamic shrink verification +- Fallback to single path when no packed API +- All existing test suites pass (regression) +""" + +import importlib.util +import sys +import threading +import time +import types +from pathlib import Path +from typing import Any, Dict, List, Optional + + +# --------------------------------------------------------------------------- +# Module loading (same pattern as existing tests) +# --------------------------------------------------------------------------- + +def _load_modules(): + root = Path(__file__).resolve().parents[1] + interfaces_path = root / "python" / "llaisys" / "interfaces.py" + kv_path = root / "python" / "llaisys" / "kv_cache_pool.py" + scheduler_path = root / "python" / "llaisys" / "scheduler.py" + session_mgr_path = root / "python" / "llaisys" / "session_manager.py" + kv_bridge_path = root / "python" / "llaisys" / "kv_runtime_bridge.py" + server_path = root / "python" / "llaisys" / "server.py" + + # interfaces + iface_spec = importlib.util.spec_from_file_location("llaisys.interfaces", str(interfaces_path)) + if iface_spec is None or iface_spec.loader is None: + raise RuntimeError("failed to load interfaces") + iface_mod = importlib.util.module_from_spec(iface_spec) + sys.modules[iface_spec.name] = iface_mod + iface_spec.loader.exec_module(iface_mod) + + # kv_cache_pool + kv_spec = importlib.util.spec_from_file_location("llaisys.kv_cache_pool", str(kv_path)) + if kv_spec is None or kv_spec.loader is None: + raise RuntimeError("failed to load kv_cache_pool") + kv_mod = importlib.util.module_from_spec(kv_spec) + sys.modules[kv_spec.name] = kv_mod + kv_spec.loader.exec_module(kv_mod) + + # scheduler + scheduler_spec = importlib.util.spec_from_file_location("llaisys.scheduler", str(scheduler_path)) + if scheduler_spec is None or scheduler_spec.loader is None: + raise RuntimeError("failed to load scheduler") + scheduler_mod = importlib.util.module_from_spec(scheduler_spec) + sys.modules[scheduler_spec.name] = scheduler_mod + scheduler_spec.loader.exec_module(scheduler_mod) + + # session_manager + session_mgr_mod = None + if session_mgr_path.exists(): + sm_spec = importlib.util.spec_from_file_location("llaisys.session_manager", str(session_mgr_path)) + if sm_spec is not None and sm_spec.loader is not None: + session_mgr_mod = importlib.util.module_from_spec(sm_spec) + sys.modules[sm_spec.name] = session_mgr_mod + sm_spec.loader.exec_module(session_mgr_mod) + + # kv_runtime_bridge + kv_bridge_mod = None + if kv_bridge_path.exists(): + kb_spec = importlib.util.spec_from_file_location("llaisys.kv_runtime_bridge", str(kv_bridge_path)) + if kb_spec is not None and kb_spec.loader is not None: + kv_bridge_mod = importlib.util.module_from_spec(kb_spec) + sys.modules[kb_spec.name] = kv_bridge_mod + kb_spec.loader.exec_module(kv_bridge_mod) + + # fake llaisys package + fake_llaisys = types.ModuleType("llaisys") + fake_llaisys.kv_cache_pool = kv_mod + fake_llaisys.scheduler = scheduler_mod + fake_llaisys.interfaces = iface_mod + fake_llaisys.Tokenizer = object + if session_mgr_mod: + fake_llaisys.session_manager = session_mgr_mod + if kv_bridge_mod: + fake_llaisys.kv_runtime_bridge = kv_bridge_mod + fake_llaisys.__path__ = [str(root / "python" / "llaisys")] + sys.modules["llaisys"] = fake_llaisys + sys.modules["llaisys.kv_cache_pool"] = kv_mod + sys.modules["llaisys.scheduler"] = scheduler_mod + sys.modules["llaisys.interfaces"] = iface_mod + if session_mgr_mod: + sys.modules["llaisys.session_manager"] = session_mgr_mod + if kv_bridge_mod: + sys.modules["llaisys.kv_runtime_bridge"] = kv_bridge_mod + + # fake libllaisys + fake_libllaisys = types.ModuleType("llaisys.libllaisys") + + class _FakeSamplingParams: + def __init__(self, top_k=1, top_p=0.0, temperature=0.0, seed=0): + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.seed = seed + + fake_libllaisys.LlaisysSamplingParams = _FakeSamplingParams + sys.modules["llaisys.libllaisys"] = fake_libllaisys + fake_llaisys.libllaisys = fake_libllaisys + + # fake models + fake_models = types.ModuleType("llaisys.models") + + class _StubQwen2: + @staticmethod + def build_prompt(messages, system_prompt=None, add_generation_prompt=True): + lines = [] + if system_prompt: + lines.append(f"System: {system_prompt}") + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + lines.append(f"Assistant: {content}") + else: + lines.append(f"User: {content}") + if add_generation_prompt: + lines.append("Assistant:") + return "\n".join(lines) + + fake_models.Qwen2 = _StubQwen2 + sys.modules["llaisys.models"] = fake_models + + # server + spec = importlib.util.spec_from_file_location("llaisys.server", str(server_path)) + if spec is None or spec.loader is None: + raise RuntimeError("failed to load server module") + server_mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = server_mod + spec.loader.exec_module(server_mod) + + return iface_mod, kv_mod, scheduler_mod, server_mod + + +iface_mod, kv_mod, scheduler_mod, server_mod = _load_modules() +ChatService = server_mod.ChatService +BatchSequenceState = server_mod.BatchSequenceState +BatchState = server_mod.BatchState +StepResult = server_mod.StepResult + + +# --------------------------------------------------------------------------- +# Fake model / tokenizer helpers +# --------------------------------------------------------------------------- + +class _EndToken: + def __init__(self, value): + self.value = value + + +class _Meta: + def __init__(self, eos=-1): + self.end_token = _EndToken(eos) + + +class FakeTokenizer: + def encode(self, text): + return [ord(ch) for ch in text] + + def decode(self, token_ids): + return "".join(chr(int(t)) for t in token_ids) + + +class FakeModel: + """Model mock with packed API support.""" + + def __init__(self, eos=-1): + self._meta = _Meta(eos) + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + self.prefill_packed_calls = 0 + self.step_packed_calls = 0 + self.prefill_packed_sampling_calls = 0 + self.step_packed_sampling_calls = 0 + # Track decode_inputs sizes for shrink verification + self.step_packed_input_sizes: List[int] = [] + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.prefill(prompt_ids) + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return self.step(token_ids) + + def prefill_packed(self, prompts): + self.prefill_packed_calls += 1 + return [65] * len(prompts) + + def step_packed(self, sequences): + self.step_packed_calls += 1 + self.step_packed_input_sizes.append(len(sequences)) + return [66] * len(sequences) + + def prefill_packed_sampling(self, prompts, params_list): + self.prefill_packed_sampling_calls += 1 + return [65] * len(prompts) + + def step_packed_sampling(self, sequences, params_list): + self.step_packed_sampling_calls += 1 + self.step_packed_input_sizes.append(len(sequences)) + return [66] * len(sequences) + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +class FakeModelNoPacked: + """Model without any packed methods.""" + + def __init__(self): + self._meta = _Meta() + self.bind_calls = [] + self.export_calls = [] + self.reset_calls = 0 + self._ctx_seq = 0 + + def reset_kv_cache(self): + self.reset_calls += 1 + + def prefill(self, prompt_ids): + return 65 + + def prefill_sampling(self, prompt_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return 65 + + def step(self, token_ids): + return 66 + + def step_sampling(self, token_ids, top_k=1, top_p=0.0, temperature=0.0, seed=0): + return 66 + + def set_kv_context(self, ctx): + self.bind_calls.append(ctx) + return 0 + + def kv_context_create(self): + self._ctx_seq += 1 + return {"ctx_id": self._ctx_seq} + + def kv_context_release(self, ctx): + return None + + def export_kv_context(self, ctx, block_tokens): + self.export_calls.append((ctx, block_tokens)) + return 0 + + +def _make_service(model=None, **kwargs): + if model is None: + model = FakeModel() + tok = FakeTokenizer() + service = ChatService( + model=model, + tokenizer=tok, + model_path=None, + enable_kv_runtime_reuse=kwargs.get("enable_kv_runtime_reuse", False), + block_size=kwargs.get("block_size", 4), + max_blocks=kwargs.get("max_blocks", 256), + max_bytes=kwargs.get("max_bytes", 1024 * 1024), + ) + return service, model + + +# =========================================================================== +# Test 1: Streaming batch produces correct SSE chunks +# =========================================================================== + +def test_streaming_batch_correct_chunks(): + """prepare_batch + step_batch should produce correct delta text for multiple sequences.""" + service, model = _make_service() + payloads = [ + {"session_id": "s1", "prompt": "hi", "max_new_tokens": 3}, + {"session_id": "s2", "prompt": "yo", "max_new_tokens": 3}, + ] + state = service.prepare_batch(payloads) + assert state is not None, "prepare_batch should return BatchState" + assert len(state.sequences) == 2 + assert state.sequences[0].context_id == "s1" + assert state.sequences[1].context_id == "s2" + + # First token already generated in prefill + for seq in state.sequences: + assert len(seq.generated_ids) == 1 + assert seq.generated_ids[0] == 65 # 'A' + + # Step until all done + all_deltas: Dict[str, str] = {"s1": "", "s2": ""} + rounds = 0 + while not all(s.finished for s in state.sequences): + results = service.step_batch(state) + for sr in results: + seq = state.sequences[sr.seq_index] + all_deltas[seq.context_id] += sr.delta_text + rounds += 1 + assert rounds < 20, "Too many decode rounds" + + # Each sequence should have generated max_new_tokens tokens + for seq in state.sequences: + assert len(seq.generated_ids) == 3 + assert seq.finish_reason == "length" + + print(" streaming batch correct chunks OK") + + +# =========================================================================== +# Test 2: Non-stream requests via batch path +# =========================================================================== + +def test_non_stream_via_batch_path(): + """Non-stream payloads should work through prepare_batch/step_batch.""" + service, model = _make_service() + payloads = [ + {"session_id": "ns1", "prompt": "hello", "max_new_tokens": 2}, + {"session_id": "ns2", "prompt": "world", "max_new_tokens": 2}, + ] + state = service.prepare_batch(payloads) + assert state is not None + + while not all(s.finished for s in state.sequences): + service.step_batch(state) + + for i, seq in enumerate(state.sequences): + assert seq.finished + assert len(seq.generated_ids) == 2 + service.finalize_sequence(state, i) + + print(" non-stream via batch path OK") + + +# =========================================================================== +# Test 3: Mixed stream + non-stream in same batch +# =========================================================================== + +def test_mixed_stream_non_stream_batch(): + """Both stream and non-stream payloads can be batched together.""" + service, model = _make_service() + payloads = [ + {"session_id": "mix-s", "prompt": "hi", "max_new_tokens": 2, "stream": True}, + {"session_id": "mix-ns", "prompt": "yo", "max_new_tokens": 2}, + ] + state = service.prepare_batch(payloads) + assert state is not None + assert len(state.sequences) == 2 + + while not all(s.finished for s in state.sequences): + service.step_batch(state) + + for seq in state.sequences: + assert seq.finished + assert len(seq.generated_ids) == 2 + + print(" mixed stream+non-stream batch OK") + + +# =========================================================================== +# Test 4: Single sequence cancellation +# =========================================================================== + +def test_single_sequence_cancellation(): + """Cancelling one sequence should not affect others.""" + service, model = _make_service() + payloads = [ + {"session_id": "cancel-1", "prompt": "hi", "max_new_tokens": 5}, + {"session_id": "cancel-2", "prompt": "yo", "max_new_tokens": 5}, + ] + state = service.prepare_batch(payloads) + assert state is not None + + # Cancel first sequence after prefill + state.sequences[0].cancel_event.set() + + results = service.step_batch(state) + # First sequence should be marked as cancelled + cancelled = [r for r in results if r.seq_index == 0] + assert len(cancelled) == 1 + assert cancelled[0].finished + assert cancelled[0].stopped + + # Second sequence should still be active + assert not state.sequences[1].finished + + # Continue stepping until second finishes + rounds = 0 + while not all(s.finished for s in state.sequences): + service.step_batch(state) + rounds += 1 + assert rounds < 20 + + assert state.sequences[1].finished + assert state.sequences[1].finish_reason == "length" + assert len(state.sequences[1].generated_ids) == 5 + + print(" single sequence cancellation OK") + + +# =========================================================================== +# Test 5: Different max_new_tokens (partial early finish) +# =========================================================================== + +def test_different_max_new_tokens(): + """Sequences with different max_new_tokens should finish at different times.""" + service, model = _make_service() + payloads = [ + {"session_id": "short", "prompt": "hi", "max_new_tokens": 1}, + {"session_id": "long", "prompt": "yo", "max_new_tokens": 4}, + ] + state = service.prepare_batch(payloads) + assert state is not None + + # Short sequence should finish after prefill (1 token generated) + assert state.sequences[0].finished, "1-token sequence should finish after prefill" + assert state.sequences[0].finish_reason == "length" + assert not state.sequences[1].finished + + # Step until long finishes + rounds = 0 + while not all(s.finished for s in state.sequences): + service.step_batch(state) + rounds += 1 + assert rounds < 20 + + assert state.sequences[1].finished + assert len(state.sequences[1].generated_ids) == 4 + + print(" different max_new_tokens OK") + + +# =========================================================================== +# Test 6: Batch size limit enforcement (via scheduler) +# =========================================================================== + +def test_batch_size_limit(): + """Scheduler should respect max_batch_size.""" + service, model = _make_service() + InferenceScheduler = scheduler_mod.InferenceScheduler + scheduler = InferenceScheduler( + [service], + queue_size=16, + request_timeout_ms=5000, + continuous_batching=True, + max_batch_size=2, + ) + scheduler.start() + try: + handles = [] + for i in range(4): + h = scheduler.submit( + {"session_id": f"bs-{i}", "prompt": "test", "max_new_tokens": 2, "stream": True}, + stream=True, + ) + handles.append(h) + + for h in handles: + items = list(h.iter_stream(timeout=5.0)) + assert len(items) > 0 + last = items[-1] + assert last["choices"][0]["finish_reason"] is not None + + snap = scheduler.debug_snapshot() + # Should have done multiple prefill batches since max_batch_size=2 and 4 tasks + assert snap["max_batch_size"] == 2 + finally: + scheduler.stop() + + print(" batch size limit OK") + + +# =========================================================================== +# Test 7: Dynamic shrink verification +# =========================================================================== + +def test_dynamic_shrink(): + """step_batch should only pass active sequences to model (dynamic shrinking).""" + service, model = _make_service() + payloads = [ + {"session_id": "shrink-1", "prompt": "hi", "max_new_tokens": 1}, # finishes after prefill + {"session_id": "shrink-2", "prompt": "yo", "max_new_tokens": 3}, + ] + state = service.prepare_batch(payloads) + assert state is not None + assert state.sequences[0].finished # 1 token = done after prefill + + model.step_packed_input_sizes.clear() + + # Step: only sequence 1 should be active + rounds = 0 + while not all(s.finished for s in state.sequences): + service.step_batch(state) + rounds += 1 + assert rounds < 20 + + # All step_packed calls should have received only 1 sequence (the active one) + for size in model.step_packed_input_sizes: + assert size == 1, f"Expected 1 active sequence in step_packed, got {size}" + + print(" dynamic shrink OK") + + +# =========================================================================== +# Test 8: Fallback to single path when no packed API +# =========================================================================== + +def test_fallback_no_packed_api(): + """prepare_batch should return None when model has no packed methods.""" + model = FakeModelNoPacked() + service, _ = _make_service(model=model) + payloads = [ + {"session_id": "fb-1", "prompt": "hi", "max_new_tokens": 2}, + ] + result = service.prepare_batch(payloads) + assert result is None, "Should return None without packed API" + print(" fallback no packed API OK") + + +def test_fallback_edit_from_session(): + """prepare_batch should return None for edit_from_session_id requests.""" + service, _ = _make_service() + payloads = [ + {"session_id": "edit-1", "prompt": "hi", "max_new_tokens": 2, + "edit_from_session_id": "other", "edit_message_index": 0}, + ] + result = service.prepare_batch(payloads) + assert result is None, "Should return None for edit requests" + print(" fallback edit_from_session OK") + + +# =========================================================================== +# Test 9: Scheduler integration - streaming batch end-to-end +# =========================================================================== + +def test_scheduler_streaming_batch_e2e(): + """Full end-to-end: scheduler uses prepare_batch/step_batch for streaming.""" + service, model = _make_service() + InferenceScheduler = scheduler_mod.InferenceScheduler + scheduler = InferenceScheduler( + [service], + queue_size=8, + request_timeout_ms=5000, + continuous_batching=True, + max_batch_size=4, + ) + scheduler.start() + try: + # Submit multiple stream requests + handles = [] + for i in range(3): + h = scheduler.submit( + {"session_id": f"e2e-{i}", "prompt": "test", "max_new_tokens": 3, "stream": True}, + stream=True, + ) + handles.append(h) + + # Collect all chunks + for i, h in enumerate(handles): + items = list(h.iter_stream(timeout=5.0)) + assert len(items) > 0, f"Stream {i} should produce chunks" + last = items[-1] + assert last["choices"][0]["finish_reason"] is not None, f"Stream {i} should have finish_reason" + assert last["session_id"] == f"e2e-{i}" + + snap = scheduler.debug_snapshot() + metrics = snap["metrics"] + # Should have used the batch path + assert metrics["stream_batch_prefill_batches"] >= 1.0 or metrics["stream_batch_fallback_tasks"] >= 1.0 + finally: + scheduler.stop() + + print(" scheduler streaming batch e2e OK") + + +# =========================================================================== +# Test 10: Scheduler non-stream via batch path +# =========================================================================== + +def test_scheduler_non_stream_batch(): + """Non-stream requests through continuous batching scheduler.""" + service, model = _make_service() + InferenceScheduler = scheduler_mod.InferenceScheduler + scheduler = InferenceScheduler( + [service], + queue_size=8, + request_timeout_ms=5000, + continuous_batching=True, + max_batch_size=4, + ) + scheduler.start() + try: + h = scheduler.submit( + {"session_id": "ns-sched", "prompt": "test", "max_new_tokens": 2}, + stream=False, + ) + result = h.get_result(timeout=5.0) + assert result["session_id"] == "ns-sched" + assert "choices" in result + finally: + scheduler.stop() + + print(" scheduler non-stream batch OK") + + +# =========================================================================== +# Test 11: Scheduler fallback path (no packed API) +# =========================================================================== + +def test_scheduler_fallback_path(): + """Scheduler should fall back to legacy iterator path when prepare_batch returns None.""" + model = FakeModelNoPacked() + service, _ = _make_service(model=model) + InferenceScheduler = scheduler_mod.InferenceScheduler + scheduler = InferenceScheduler( + [service], + queue_size=8, + request_timeout_ms=5000, + continuous_batching=True, + max_batch_size=4, + ) + scheduler.start() + try: + h = scheduler.submit( + {"session_id": "fb-sched", "prompt": "test", "max_new_tokens": 2, "stream": True}, + stream=True, + ) + items = list(h.iter_stream(timeout=5.0)) + assert len(items) > 0 + last = items[-1] + assert last["choices"][0]["finish_reason"] is not None + + snap = scheduler.debug_snapshot() + assert snap["metrics"]["stream_batch_fallback_tasks"] >= 1.0 + finally: + scheduler.stop() + + print(" scheduler fallback path OK") + + +# =========================================================================== +# Test 12: finalize_sequence saves messages +# =========================================================================== + +def test_finalize_saves_messages(): + """finalize_sequence should save assistant message to session history.""" + service, model = _make_service() + payloads = [ + {"session_id": "fin-1", "prompt": "hello", "max_new_tokens": 2}, + ] + state = service.prepare_batch(payloads) + assert state is not None + + while not all(s.finished for s in state.sequences): + service.step_batch(state) + + service.finalize_sequence(state, 0) + + # Verify session has saved messages + msgs = service._session_mgr.get_messages("fin-1") + assert len(msgs) >= 2, "Should have user + assistant messages" + assert msgs[-1]["role"] == "assistant" + assert len(msgs[-1]["content"]) > 0 + + print(" finalize saves messages OK") + + +# =========================================================================== +# Test 13: finalize_sequence on cancelled does not save +# =========================================================================== + +def test_finalize_cancelled_no_save(): + """finalize_sequence on a cancelled sequence should not save assistant message.""" + service, model = _make_service() + payloads = [ + {"session_id": "fin-cancel", "prompt": "hello", "max_new_tokens": 5}, + ] + state = service.prepare_batch(payloads) + assert state is not None + + # Cancel immediately + state.sequences[0].cancel_event.set() + service.step_batch(state) + service.finalize_sequence(state, 0) + + # Should not have saved assistant message + msgs = service._session_mgr.get_messages("fin-cancel") + has_assistant = any(m["role"] == "assistant" for m in msgs) + assert not has_assistant, "Cancelled sequence should not save assistant message" + + print(" finalize cancelled no save OK") + + +# =========================================================================== +# Test 14: Sampling batch via prepare_batch +# =========================================================================== + +def test_sampling_batch_prepare(): + """Sampling requests should use prefill_packed_sampling in prepare_batch.""" + service, model = _make_service() + payloads = [ + {"session_id": "samp-1", "prompt": "hi", "max_new_tokens": 2, "temperature": 0.8, "top_k": 50}, + {"session_id": "samp-2", "prompt": "yo", "max_new_tokens": 2, "temperature": 0.8, "top_k": 50}, + ] + state = service.prepare_batch(payloads) + assert state is not None + assert state.any_sampling + assert model.prefill_packed_sampling_calls >= 1 + + while not all(s.finished for s in state.sequences): + service.step_batch(state) + + assert model.step_packed_sampling_calls >= 1 + + print(" sampling batch prepare OK") + + +# =========================================================================== +# Runner +# =========================================================================== + +if __name__ == "__main__": + tests = [ + test_streaming_batch_correct_chunks, + test_non_stream_via_batch_path, + test_mixed_stream_non_stream_batch, + test_single_sequence_cancellation, + test_different_max_new_tokens, + test_batch_size_limit, + test_dynamic_shrink, + test_fallback_no_packed_api, + test_fallback_edit_from_session, + test_scheduler_streaming_batch_e2e, + test_scheduler_non_stream_batch, + test_scheduler_fallback_path, + test_finalize_saves_messages, + test_finalize_cancelled_no_save, + test_sampling_batch_prepare, + ] + + passed = 0 + failed = 0 + for test_fn in tests: + name = test_fn.__name__ + try: + print(f"[RUN ] {name}") + test_fn() + print(f"[PASS] {name}") + passed += 1 + except Exception as exc: + import traceback + print(f"[FAIL] {name}: {exc}") + traceback.print_exc() + failed += 1 + + print(f"\n{'='*60}") + print(f"Results: {passed} passed, {failed} failed, {passed + failed} total") + if failed > 0: + print("SOME TESTS FAILED") + sys.exit(1) + else: + print("ALL TESTS PASSED") diff --git a/test/test_tokenizer.py b/test/test_tokenizer.py new file mode 100644 index 000000000..587510416 --- /dev/null +++ b/test/test_tokenizer.py @@ -0,0 +1,47 @@ +import argparse +import os +from ctypes import c_char_p, c_int64, c_size_t, create_string_buffer + +from llaisys.libllaisys import LIB_LLAISYS + + +def test_sentencepiece(model_path: str, text: str): + tokenizer = LIB_LLAISYS.llaisysTokenizerCreateSentencePiece(model_path.encode("utf-8")) + if not tokenizer: + print("SentencePiece tokenizer not available or model load failed. Skipped.") + return + + # query required length + needed = LIB_LLAISYS.llaisysTokenizerEncode(tokenizer, text.encode("utf-8"), None, c_size_t(0)) + assert needed > 0 + + ids = (c_int64 * needed)() + n = LIB_LLAISYS.llaisysTokenizerEncode(tokenizer, text.encode("utf-8"), ids, c_size_t(needed)) + assert n > 0 + + # query decode length + decode_needed = LIB_LLAISYS.llaisysTokenizerDecode(tokenizer, ids, c_size_t(n), None, c_size_t(0)) + assert decode_needed > 0 + + out = create_string_buffer(decode_needed) + nbytes = LIB_LLAISYS.llaisysTokenizerDecode(tokenizer, ids, c_size_t(n), out, c_size_t(decode_needed)) + assert nbytes >= 0 + decoded = out.value.decode("utf-8") + assert decoded != "" + + LIB_LLAISYS.llaisysTokenizerDestroy(tokenizer) + print("Encoded ids:", list(ids)[: min(8, n)], "...") + print("Decoded text:", decoded) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default=os.environ.get("LLAISYS_TOKENIZER_MODEL", ""), type=str) + parser.add_argument("--text", default="我喜欢人工智能", type=str) + args = parser.parse_args() + + if not args.model: + print("No SentencePiece model path provided. Set --model or LLAISYS_TOKENIZER_MODEL. Skipped.") + else: + test_sentencepiece(args.model, args.text) + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_utils.py b/test/test_utils.py index 0f38f0c8e..597ee861c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,10 +1,12 @@ +from typing import Tuple + import llaisys import torch def random_tensor( shape, dtype_name, device_name, device_id=0, scale=None, bias=None -) -> tuple[torch.Tensor, llaisys.Tensor]: +) -> Tuple[torch.Tensor, llaisys.Tensor]: torch_tensor = torch.rand( shape, dtype=torch_dtype(dtype_name), @@ -64,7 +66,7 @@ def random_int_tensor(shape, device_name, dtype_name="i64", device_id=0, low=0, def zero_tensor( shape, dtype_name, device_name, device_id=0 -) -> tuple[torch.Tensor, llaisys.Tensor]: +) -> Tuple[torch.Tensor, llaisys.Tensor]: torch_tensor = torch.zeros( shape, dtype=torch_dtype(dtype_name), @@ -92,7 +94,7 @@ def zero_tensor( def arrange_tensor( start, end, device_name, device_id=0 -) -> tuple[torch.Tensor, llaisys.Tensor]: +) -> Tuple[torch.Tensor, llaisys.Tensor]: torch_tensor = torch.arange(start, end, device=torch_device(device_name, device_id)) llaisys_tensor = llaisys.Tensor( (end - start,), @@ -186,7 +188,7 @@ def time_op(func): def torch_device(device_name: str, device_id=0): if device_name == "cpu": return torch.device("cpu") - elif device_name == "nvidia": + elif device_name == "nvidia" or device_name == "iluvatar": return torch.device(f"cuda:{device_id}") else: raise ValueError(f"Unsupported device name: {device_name}") @@ -197,6 +199,8 @@ def llaisys_device(device_name: str): return llaisys.DeviceType.CPU elif device_name == "nvidia": return llaisys.DeviceType.NVIDIA + elif device_name == "iluvatar": + return llaisys.DeviceType.ILUVATAR else: raise ValueError(f"Unsupported device name: {device_name}") @@ -206,6 +210,8 @@ def device_name(llaisys_device: llaisys.DeviceType): return "cpu" elif llaisys_device == llaisys.DeviceType.NVIDIA: return "nvidia" + elif llaisys_device == llaisys.DeviceType.ILUVATAR: + return "iluvatar" else: raise ValueError(f"Unsupported llaisys device: {llaisys_device}") diff --git a/xmake.lua b/xmake.lua index 1f65f7a95..d72ee2bda 100644 --- a/xmake.lua +++ b/xmake.lua @@ -13,11 +13,29 @@ option("nv-gpu") set_description("Whether to compile implementations for Nvidia GPU") option_end() +option("sentencepiece") + set_default(false) + set_showmenu(true) + set_description("Enable SentencePiece tokenizer support") +option_end() + if has_config("nv-gpu") then add_defines("ENABLE_NVIDIA_API") includes("xmake/nvidia.lua") end +-- ILUVATAR -- +option("iluvatar-gpu") + set_default(false) + set_showmenu(true) + set_description("Whether to compile implementations for Iluvatar CoreX GPU") +option_end() + +if has_config("iluvatar-gpu") then + add_defines("ENABLE_ILUVATAR_API") + includes("xmake/iluvatar.lua") +end + target("llaisys-utils") set_kind("static") @@ -37,6 +55,12 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-device-nvidia") + end + if has_config("iluvatar-gpu") then + add_deps("llaisys-device-iluvatar") + end set_languages("cxx17") set_warnings("all", "error") @@ -83,6 +107,12 @@ target_end() target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-ops-nvidia") + end + if has_config("iluvatar-gpu") then + add_deps("llaisys-ops-iluvatar") + end set_languages("cxx17") set_warnings("all", "error") @@ -106,8 +136,38 @@ target("llaisys") set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") + add_files("src/llaisys/*/*.cpp") + add_files("src/models/*/*.cpp") + add_files("src/models/*/*/*.cpp") + add_files("src/tokenizer/*/*.cpp") set_installdir(".") + if has_config("sentencepiece") then + add_defines("LLAISYS_ENABLE_SENTENCEPIECE") + add_links("sentencepiece") + end + if has_config("nv-gpu") then + set_languages("cxx17", "cuda") + set_policy("build.cuda.devlink", true) + add_links("cudadevrt", "cudart") + add_files("src/device/nvidia/devlink_stub.cu") + elseif has_config("iluvatar-gpu") then + -- No .cu files in this target, no CUDA toolchain + -- Use add_shflags to control exact link order: + -- 1. whole-archive iluvatar static libs (defines nvidia:: symbols) + -- 2. -lcudart AFTER the .a files (so cudart symbols are resolved) + add_shflags( + "-Wl,--whole-archive", + "build/linux/x86_64/release/libllaisys-ops-iluvatar.a", + "build/linux/x86_64/release/libllaisys-device-iluvatar.a", + "-Wl,--no-whole-archive", + "-L/usr/local/corex/lib64", + "-Wl,-rpath,/usr/local/corex/lib64", + "-lcudart", + {force = true} + ) + end + after_install(function (target) -- copy shared library to python package diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua new file mode 100644 index 000000000..9afdf0d9e --- /dev/null +++ b/xmake/iluvatar.lua @@ -0,0 +1,117 @@ +-- Iluvatar CoreX GPU targets +-- Uses clang++ with CUDA frontend, NOT nvcc +-- We use on_build to completely bypass xmake's CUDA toolchain detection + +target("llaisys-device-iluvatar") + set_kind("static") + add_deps("llaisys-utils") + set_languages("cxx17") + set_warnings("all", "error") + + -- Do NOT add .cu files via add_files - that triggers xmake CUDA toolchain + -- Instead, build everything in on_build + on_build(function (target) + import("core.project.depend") + + local sourcedir = path.absolute("src/device/iluvatar") + local sources = { + path.join(sourcedir, "iluvatar_runtime_api.cu"), + path.join(sourcedir, "iluvatar_resource.cu"), + } + + local objectfiles = {} + for _, sourcefile in ipairs(sources) do + local objectfile = target:objectfile(sourcefile) + local objectdir = path.directory(objectfile) + if not os.isdir(objectdir) then + os.mkdir(objectdir) + end + + local dependfile = target:dependfile(objectfile) + depend.on_changed(function () + local argv = { + "-x", "cuda", + "--cuda-gpu-arch=ivcore10", + "--cuda-path=/usr/local/corex", + "-std=c++17", + "-fPIC", + "-O3", + "-DENABLE_ILUVATAR_API", + "-Iinclude", + "-I/usr/local/corex/include", + "-c", + "-o", objectfile, + sourcefile + } + os.vrunv("/usr/local/corex/bin/clang++", argv) + end, {dependfile = dependfile, files = {sourcefile}}) + + table.insert(objectfiles, objectfile) + end + + -- Archive into static library + local targetfile = target:targetfile() + local targetdir = path.directory(targetfile) + if not os.isdir(targetdir) then + os.mkdir(targetdir) + end + os.vrunv("ar", {"-cr", targetfile, table.unpack(objectfiles)}) + end) + + on_install(function (target) end) +target_end() + +target("llaisys-ops-iluvatar") + set_kind("static") + add_deps("llaisys-tensor") + set_languages("cxx17") + set_warnings("all", "error") + + -- Do NOT add .cu files via add_files + on_build(function (target) + import("core.project.depend") + + -- Find all .cu files under src/ops/*/nvidia/ + local sources = os.files("src/ops/*/nvidia/*.cu") + + local objectfiles = {} + for _, sourcefile in ipairs(sources) do + local objectfile = target:objectfile(sourcefile) + local objectdir = path.directory(objectfile) + if not os.isdir(objectdir) then + os.mkdir(objectdir) + end + + local dependfile = target:dependfile(objectfile) + depend.on_changed(function () + local argv = { + "-x", "cuda", + "--cuda-gpu-arch=ivcore10", + "--cuda-path=/usr/local/corex", + "-std=c++17", + "-fPIC", + "-O3", + "-DENABLE_ILUVATAR_API", + "-Iinclude", + "-I/usr/local/corex/include", + "-c", + "-o", objectfile, + sourcefile + } + os.vrunv("/usr/local/corex/bin/clang++", argv) + end, {dependfile = dependfile, files = {sourcefile}}) + + table.insert(objectfiles, objectfile) + end + + -- Archive into static library + local targetfile = target:targetfile() + local targetdir = path.directory(targetfile) + if not os.isdir(targetdir) then + os.mkdir(targetdir) + end + os.vrunv("ar", {"-cr", targetfile, table.unpack(objectfiles)}) + end) + + on_install(function (target) end) +target_end() diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 000000000..208d096b3 --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,40 @@ +target("llaisys-device-nvidia") + set_kind("static") + add_deps("llaisys-utils") + set_languages("cxx17", "cuda") + set_warnings("all", "error") + if is_plat("windows") then + set_runtimes("MD") + add_cuflags("--compiler-options=/MD", "-rdc=true", {force = true}) + end + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-rdc=true", "--compiler-options=-fPIC") + end + add_links("cudart") + add_links("cudadevrt") + add_links("nccl") + add_files("../src/device/nvidia/nvidia_runtime_api.cu") + add_files("../src/device/nvidia/nvidia_resource.cu") + add_files("../src/device/nvidia/nvidia_comm.cu") + on_install(function (target) end) +target_end() + +target("llaisys-ops-nvidia") + set_kind("static") + add_deps("llaisys-tensor") + set_languages("cxx17", "cuda") + set_warnings("all", "error") + if is_plat("windows") then + set_runtimes("MD") + add_cuflags("--compiler-options=/MD", "-rdc=true", {force = true}) + end + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-rdc=true", "--compiler-options=-fPIC") + end + add_links("cudart") + add_links("cudadevrt") + add_files("../src/ops/*/nvidia/*.cu") + on_install(function (target) end) +target_end()