diff --git a/README.md b/README.md
index 456067c82..7704dbd5b 100644
--- a/README.md
+++ b/README.md
@@ -1,148 +1,147 @@
-# Welcome to LLAISYS
+# Welcome to LLAISYS
English |
中文
-## Introduction
+## Introduction
-LLAISYS (Let's Learn AI SYStem) is an educational project that aims to provide a platform for new and future AI engineers to learn how to build AI systems from scratch. LLAISYS consists of several assignments, which help students learn and build the basic modules, and projects that challenge them to add more fancy features to their systems. LLAISYS uses C++ as primary programming language for system backend, and is compiled into shared libraries exposing C language APIs. Frontend codes are written in Python which calls these APIs to provide more convenient testing and interaction with other architectures such as PyTorch.
+LLAISYS (Let's Learn AI SYStem) is an educational project that aims to provide a platform for new and future AI engineers to learn how to build AI systems from scratch. LLAISYS consists of several assignments, which help students learn and build the basic modules, and projects that challenge them to add more advanced features to their systems. LLAISYS uses C++ as the primary programming language for the system backend, which is compiled into shared libraries exposing C language APIs. The frontend code is written in Python and calls these APIs to provide more convenient testing and interaction with other frameworks such as PyTorch.
-### Project Structure Overview
+### Project Structure Overview
-- `\include`: directory that contains of the header files which defines all the C APIs exposed by the shared library. (Functions declarations start with `__export`)
+- `\include`: a directory containing the header files that define all C APIs exposed by the shared library. (Function declarations start with `__export`.)
-- `\src`: C++ source files.
- - `\src\llaisys` contains all the direct implementation of waht are defined in the header files and follows the same directory structure as the `\include`. This is also as far as C++ codes can go.
- - other directories contain the actual implementaion of different modules.
+- `\src`: C++ source files.
+ - `\src\llaisys` contains the direct implementations of what is defined in the header files and follows the same directory structure as `\include`. This is also as far as the C++ code goes.
+ - Other directories contain the actual implementations of the different modules.
-- `xmake.lua`: build rules for llaisys backend. `\xmake` directory contains the sub-xmake files for different devices. You may add `nvidia.lua` in the directory in the future for instance to support CUDA.
+- `xmake.lua`: build rules for the llaisys backend. The `\xmake` directory contains the sub-xmake files for different devices. For example, you may add an `nvidia.lua` file to this directory in the future to support CUDA.
-- `\python`: Python source files.
- - `\python\llaisys\libllaisys` contains all the ctypes wrapper functions of llaisys APIs. It basically matches the structure of C header files.
- - `\python\llaisys` contains Python warppers of the ctypes functions to make the package more Python-like.
+- `\python`: Python source files.
+ - `\python\llaisys\libllaisys` contains the ctypes wrapper functions for the llaisys APIs. It largely matches the structure of the C header files.
+ - `\python\llaisys` contains Python wrappers of the ctypes functions that make the package more Pythonic.
-- `\test`: Python test files that import llaisys python package.
+- `\test`: Python test files that import the llaisys Python package.
-## Assignment #0: Getting Started
+## Assignment #0: Getting Started
-### Task-0.1 Install Prerequisites
+### Task-0.1 Install Prerequisites
-- Compile Tool: [Xmake](https://xmake.io/)
-- C++ Compiler: MSVC (Windows) or Clang or GCC
-- Python >= 3.9 (PyTorch, Transformers, etc.)
-- Clang-Format-16 (Optional): for formatting C++ codes.
+- Build tool: [Xmake](https://xmake.io/)
+- C++ compiler: MSVC (Windows), Clang, or GCC
+- Python >= 3.9 (PyTorch, Transformers, etc.)
+- Clang-Format-16 (optional): for formatting C++ code.
-### Task-0.2 Fork and Build LLAISYS
+### Task-0.2 Fork and Build LLAISYS
-- FORK LLAISYS Repository and Clone it to your local machine. Both Windows and Linux are supported.
+- Fork the LLAISYS repository and clone it to your local machine. Both Windows and Linux are supported.
-- Compile and Install
+- Compile and install
```bash
- # compile c++ codes
+ # compile the C++ code
xmake
- # install llaisys shared library
+ # install the llaisys shared library
xmake install
- # install llaisys python package
+ # install the llaisys Python package
pip install ./python/
```
-- Github Auto Tests
+- GitHub auto tests
- LLAISYS uses Github Actions to run automated tests on every push and pull request. You can see testing results on your repo page. All tests should pass once you have finished all assignment tasks.
+ LLAISYS uses GitHub Actions to run automated tests on every push and pull request. You can see the test results on your repository page. All tests should pass once you have finished all assignment tasks.
-### Task-0.3 Run LLAISYS for the First Time
+### Task-0.3 Run LLAISYS for the First Time
-- Run cpu runtime tests
+- Run the CPU runtime tests
```bash
python test/test_runtime.py --device cpu
```
- You should see the test passed.
+ You should see the test pass.
-### Task-0.4 Download test model
+### Task-0.4 Download the Test Model
-- The model we use for assignments is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
+- The model we use for the assignments is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
-- Run an inference test with the model using PyTorch
+- Run an inference test with the model using PyTorch
```bash
python test/test_infer.py --model [dir_path/to/model]
```
- You can see that PyTorch is able to load the model and perform inference with the sample input. You can debug into `transformers` library codes to see how what is going on behind. Right now, your code cannot do anything yet, but you are going to build a system that can achieve the same functionality in the assignments.
+ You can see that PyTorch is able to load the model and perform inference on the sample input. You can step into the `transformers` library code to see what is going on behind the scenes. Right now, your code cannot do anything yet, but in the assignments you are going to build a system that achieves the same functionality.
-## Assignment #1: Tensor
+## Assignment #1: Tensor
-Tensor is a data structure that represents multi-dimensional data. It is the basic building block of LLAISYS, and most AI frameworks such as PyTorch. In this assignment, you will learn how to implement a basic tensor class.
+A tensor is a data structure that represents multi-dimensional data. It is the basic building block of LLAISYS and of most AI frameworks such as PyTorch. In this assignment, you will learn how to implement a basic tensor class.
-A Tensor object has the following fields:
+A Tensor object has the following fields:
-- `storage`: a shared pointer to a memory block that stores the tensor's data. It can be shared by multiple tensors. Check storage class for more details.
-- `offset`: the starting index (in bytes) of the tensor in the storage.
-- `meta`: metadata that describes the tensor's shape, data type, and strides.
+- `storage`: a shared pointer to the memory block that stores the tensor's data. It can be shared by multiple tensors. Check the storage class for more details.
+- `offset`: the starting index (in bytes) of the tensor within the storage.
+- `meta`: metadata describing the tensor's shape, data type, and strides.
-Implement the following functions defined in the `src/tensor/tensor.hpp`:
+Implement the following functions defined in `src/tensor/tensor.hpp`:
-### Task-1.1
+### Task-1.1
```c++
void load(const void *src);
```
-Load host (cpu) data to the tensor (can be on device). Check contructor to see how to get runtime apis of the current device context, and do a memcpy from host to device.
+Load host (CPU) data into the tensor (which can be on a device). Check the constructor to see how to get the runtime APIs of the current device context, and perform a memcpy from host to device.
-### Task-1.2
+### Task-1.2
```c++
bool isContiguous() const;
```
-Check shape and strides of the tensor, and tell wether it is contiguous in memory.
+Check the shape and strides of the tensor and tell whether it is contiguous in memory.
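To make the contiguity rule concrete, here is a minimal Python sketch of the check (an illustration only, assuming strides are counted in elements with the innermost dimension last; the actual LLAISYS tensor meta may differ):

```python
def is_contiguous(shape, strides):
    # A tensor is contiguous if its strides match the row-major strides
    # computed from its shape, walking from the innermost dimension outward.
    expected = 1
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim == 1:
            continue  # size-1 dimensions impose no constraint
        if stride != expected:
            return False
        expected *= dim
    return True

print(is_contiguous((2, 3, 5), (15, 5, 1)))   # row-major layout -> True
print(is_contiguous((2, 3, 5), (30, 10, 1)))  # padded rows -> False
```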
-### Task-1.3
+### Task-1.3
```c++
tensor_t view(const std::vector<size_t> &shape) const;
```
-Create a new tensor which reshapes the original tensor to the given shape by splitting or merging the original dimensions. No data transfer is involved. For example change a tensor of shape (2, 3, 5) to (2, 15) by merging the last two dimensions.
+Create a new tensor that reshapes the original tensor to the given shape by splitting or merging the original dimensions. No data transfer is involved. For example, change a tensor of shape (2, 3, 5) to (2, 15) by merging the last two dimensions.
-This function is not as easy as simply changing the shape of the tensor, although the test will pass. It should raise an error if new view is not compatible with the original tensor. Think about a tensor of shape (2, 3, 5) and strides (30, 10, 1). Can you still reshape it to (2, 15) without data transfer?
+This function is not as simple as just changing the shape of the tensor, even though the test would pass. It should raise an error if the new view is not compatible with the original tensor. Think about a tensor of shape (2, 3, 5) with strides (30, 10, 1). Can you still reshape it to (2, 15) without transferring data?
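The (2, 3, 5) / strides (30, 10, 1) example can be reproduced with NumPy, which faces the same constraint (illustrative only: where a zero-copy view is impossible, NumPy silently copies, while `view` here should raise an error instead):

```python
import numpy as np

# A contiguous (2, 3, 5) tensor can be viewed as (2, 15) without copying:
c = np.arange(30).reshape(2, 3, 5)
w = c.reshape(2, 15)
print(np.shares_memory(c, w))    # True: zero-copy view

# The strided case above: shape (2, 3, 5), element strides (30, 10, 1).
base = np.arange(60).reshape(2, 3, 10)
t = base[:, :, :5]               # a slice, not contiguous
v = t.reshape(2, 15)             # NumPy falls back to a copy here
print(np.shares_memory(t, v))    # False: no zero-copy view exists
```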
-### Task-1.4
+### Task-1.4
```c++
tensor_t permute(const std::vector<size_t> &order) const;
```
-Create a new tensor which changes the order of the dimensions of original tensor. Transpose can be achieved by this function without moving data around.
+Create a new tensor that changes the order of the dimensions of the original tensor. Transpose can be achieved with this function without moving any data.
-### Task-1.5
+### Task-1.5
```c++
tensor_t slice(size_t dim, size_t start, size_t end) const;
```
-Create a new tensor which slices the original tensor along the given dimension,
-start (inclusive) and end (exclusive) indices.
+Create a new tensor that slices the original tensor along the given dimension, from index start (inclusive) to end (exclusive).
-### Task-1.6
+### Task-1.6
-Run tensor tests.
+Run the tensor tests.
```bash
python test/test_tensor.py
```
-You should see all tests passed. Commit and push your changes. You should see the auto tests for assignment #1 passed.
+You should see all tests pass. Commit and push your changes. You should then see the auto tests for assignment #1 pass.
-## Assignment #2: Operators
+## Assignment #2: Operators
-In this assignment, you will implement the cpu verision the following operators:
+In this assignment, you will implement the CPU versions of the following operators:
- argmax
- embedding
@@ -152,102 +151,102 @@ In this assignment, you will implement the cpu verision the following operators:
- self_attention
- swiglu
-Read the codes in `src/ops/add/` to see how "add" operator is implemented. Make sure you understand how the operator codes are organized, compiled, linked, and exposed to Python frontend. **Your operators should at least support Float32, Float16 and BFloat16 data types**. A helper function for naive type casting is provided in `src/utils/`. All python tests are in `test/ops`, you implementation should at least pass these tests. Try running the test script for "add" operator for starting.
+Read the code in `src/ops/add/` to see how the "add" operator is implemented. Make sure you understand how the operator code is organized, compiled, linked, and exposed to the Python frontend. **Your operators should at least support the Float32, Float16, and BFloat16 data types**. A helper function for naive type casting is provided in `src/utils/`. All Python tests are in `test/ops`; your implementation should at least pass these tests. To get started, try running the test script for the "add" operator.
-### Task-2.1 argmax
+### Task-2.1 Argmax
```c++
void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals);
```
-Get the max value and its index of tensor `vals`, and store them in `max_val` and `max_idx` respectively. You can assume that `vals` is a 1D tensor for now, and `max_idx` and `max_val` are both 1D tensors with a single element (, which means the dimension of `vals` is kept).
+Get the max value of tensor `vals` and its index, and store them in `max_val` and `max_idx` respectively. For now, you can assume that `vals` is a 1D tensor, and that `max_idx` and `max_val` are both 1D tensors with a single element (which means the dimensionality of `vals` is kept).
-You should be able to pass the test cases in `test/ops/argmax.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/argmax.py` after you finish the implementation.
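As a reference for the expected semantics, a NumPy sketch (illustrative only, not the LLAISYS C++ code; the helper name `argmax_op` is made up here):

```python
import numpy as np

def argmax_op(vals):
    # Return (max_idx, max_val) as single-element 1D arrays, mirroring the
    # 1-D in / 1-D out contract described above.
    idx = int(np.argmax(vals))
    return np.array([idx], dtype=np.int64), np.array([vals[idx]], dtype=vals.dtype)

max_idx, max_val = argmax_op(np.array([0.1, 2.5, -1.0, 2.5], dtype=np.float32))
print(max_idx, max_val)  # ties resolve to the first occurrence: index 1
```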
-### Task-2.2 embedding
+### Task-2.2 Embedding
```c++
void embedding(tensor_t out, tensor_t index, tensor_t weight);
```
-Copy the rows in `index` (1-D) from `weight` (2-D) to `output` (2-D). `index` must be of type Int64 (the default data type for int of PyTorch).
+Copy the rows given by `index` (1-D) from `weight` (2-D) into `out` (2-D). `index` must be of type Int64 (PyTorch's default integer data type).
-You should be able to pass the test cases in `test/ops/embedding.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/embedding.py` after you finish the implementation.
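The row-gather can be sketched in NumPy like this (illustrative only; the real operator works on raw strided buffers):

```python
import numpy as np

def embedding(index, weight):
    # Gather one row of `weight` for each id in `index`.
    out = np.empty((index.shape[0], weight.shape[1]), dtype=weight.dtype)
    for i, row in enumerate(index):
        out[i] = weight[row]
    return out

weight = np.arange(12, dtype=np.float32).reshape(4, 3)  # vocab=4, hidden=3
index = np.array([2, 0, 2], dtype=np.int64)
print(embedding(index, weight))
```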
-### Task-2.3 linear
+### Task-2.3 Linear
```c++
void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias);
```
-Compute the following:
+Compute the following:
$$
Y = xW^T + b
$$
-- `out`: output $Y$ . You can assume output is a 2D contiguous tensor and no broadcasting is involved for now.
-- `input`: input $X$ . You can assume input is a 2D contiguous tensor and no broadcasting is involved for now.
-- `weight`: weight $W$ . 2D contiguous tensor. Note that weight tensor is not transposed. You need to deal with this during your calculation.
-- `bias` (optional): bias $b$ . 1D tensor. You need to support the situation where bias is not provided.
+- `out`: output $Y$ . For now, you can assume the output is a 2D contiguous tensor and that no broadcasting is involved.
+- `in`: input $X$ . For now, you can assume the input is a 2D contiguous tensor and that no broadcasting is involved.
+- `weight`: weight $W$ . A 2D contiguous tensor. Note that the weight tensor is not transposed; you need to account for this in your calculation.
+- `bias` (optional): bias $b$ . A 1D tensor. You need to support the case where no bias is provided.
-You should be able to pass the test cases in `test/ops/linear.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/linear.py` after you finish the implementation.
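The untransposed weight layout is the detail most worth spelling out. A NumPy reference (a sketch, not the LLAISYS implementation):

```python
import numpy as np

def linear(x, weight, bias=None):
    # weight is stored as [out_features, in_features] (not transposed),
    # so the product is x @ weight.T, matching Y = x W^T + b.
    y = x @ weight.T
    if bias is not None:
        y = y + bias
    return y

x = np.array([[1.0, 2.0]], dtype=np.float32)
w = np.array([[3.0, 4.0], [5.0, 6.0]], dtype=np.float32)  # [out=2, in=2]
b = np.array([0.5, -0.5], dtype=np.float32)
print(linear(x, w, b))  # [[11.5, 16.5]]
```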
-### Task-2.4 rms normalization
+### Task-2.4 RMS Normalization
```c++
void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps);
```
-Compute the following for each row:
+Compute the following for each row:
$$
Y_i = \frac{W_i \times X_i}{\sqrt{\frac{1}{d}(\sum_{j=1}^d X_j^2) + \epsilon}}
$$
-- `out`: output $Y$ . You can assume output is a 2D contiguous tensor and no broadcasting is involved for now.
-- `input`: input $X$ . You can assume input is a 2D contiguous tensor and no broadcasting is involved for now. The normalization is performed along the last dimension (a.k.a. each row of length $d$ ) of the input tensor.
-- `weight`: weight $W$ . 1D tensor, same length as a row of input tensor.
-- `eps`: small value $\epsilon$ to avoid division by zero.
+- `out`: output $Y$ . For now, you can assume the output is a 2D contiguous tensor and that no broadcasting is involved.
+- `in`: input $X$ . For now, you can assume the input is a 2D contiguous tensor and that no broadcasting is involved. The normalization is performed along the last dimension (i.e., each row of length $d$ ) of the input tensor.
+- `weight`: weight $W$ . A 1D tensor, the same length as a row of the input tensor.
+- `eps`: a small value $\epsilon$ to avoid division by zero.
-You should be able to pass the test cases in `test/ops/rms_norm.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/rms_norm.py` after you finish the implementation.
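The formula above, as a NumPy reference (illustrative; accumulating the sum of squares in Float32 is an assumption worth copying when the inputs are Float16/BFloat16):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    # Normalize each row by its root-mean-square, then scale by weight.
    # The mean of squares is computed in float32 for numerical stability.
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True) + eps)
    return weight * x / rms

x = np.array([[3.0, 4.0]], dtype=np.float32)
w = np.ones(2, dtype=np.float32)
print(rms_norm(x, w))  # rms = sqrt((9 + 16) / 2 + eps), about 3.5355
```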
-### Task-2.5 rope
+### Task-2.5 RoPE (Rotary Position Embedding)
```c++
void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta);
```
-Compute the following for each vector of input tensor `in`, corresponding to a position id in `pos_ids`:
+Compute the following for each vector of the input tensor `in`, where each vector corresponds to a position id in `pos_ids`:
-Let $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ be the input vector and $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ be the output vector at index $i$, where $\mathbf{a}_i, \mathbf{b}_i,\mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$ .
+Let $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ be the input vector and $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ be the output vector at index $i$, where $\mathbf{a}_i, \mathbf{b}_i, \mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$ .
-Let $\theta$ be a fixed base (e.g. $\theta = 10000$) and $j = 0, 1, \ldots, d/2 - 1$.
+Let $\theta$ be a fixed base (e.g. $\theta = 10000$) and $j = 0, 1, \ldots, d/2 - 1$.
-Let $p_i \in \mathbb{N}$ is the position id for token at input index i.
+Let $p_i \in \mathbb{N}$ be the position id of the token at input index $i$.
-Then the angle for RoPE is $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$
+Then the angle for RoPE is $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$ .
-The output vector $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ is computed as follows:
+The output vector $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ is computed as follows:
$$a_{i,j}' = a_{i,j} \cos(\phi_{i,j}) - b_{i,j} \sin(\phi_{i,j})$$
$$b_{i,j}' = b_{i,j} \cos(\phi_{i,j}) + a_{i,j} \sin(\phi_{i,j})$$
-- `out`: the resulting **q** or **k** tensor. Shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. You can assume that the tensor is contiguous for now.
-- `in`: the orignal **q** or **k** tensor. Shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. You can assume that the tensor is contiguous for now.
-- `pos_ids`: the position id (index in the whole context) for each token in the input sequence. Shape should be [seqlen,], dtype should be int64.
-- `theta`: the base value for the frequency vector.
+- `out`: the resulting **q** or **k** tensor. The shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. For now, you can assume the tensor is contiguous.
+- `in`: the original **q** or **k** tensor. The shape should be [seqlen, nhead, d] or [seqlen, nkvhead, d]. For now, you can assume the tensor is contiguous.
+- `pos_ids`: the position id (the index within the whole context) of each token in the input sequence. The shape should be [seqlen,] and the dtype should be int64.
+- `theta`: the base value for the frequency vector.
-You should be able to pass the test cases in `test/ops/rope.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/rope.py` after you finish the implementation.
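The half-split rotation above can be sketched in NumPy as follows (a reference implementation under the split-halves convention described here, not the interleaved-pairs variant some models use):

```python
import numpy as np

def rope(x, pos_ids, theta=10000.0):
    # x: [seqlen, nhead, d]; a = x[..., :d//2], b = x[..., d//2:].
    seqlen, nhead, d = x.shape
    j = np.arange(d // 2)
    phi = pos_ids[:, None].astype(np.float64) / theta ** (2.0 * j / d)  # [seqlen, d//2]
    cos, sin = np.cos(phi)[:, None, :], np.sin(phi)[:, None, :]
    a, b = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([a * cos - b * sin, b * cos + a * sin], axis=-1)

x = np.ones((2, 1, 4))
out = rope(x, np.array([0, 1], dtype=np.int64))
print(out[0])  # position 0: all angles are zero, so the vector is unchanged
```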
-### Task-2.6 self-attention
+### Task-2.6 Self-Attention
```c++
void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale);
```
-Compute the self-attention for query tensor `q`, key tensor `k`, and value tensor `v`. You should concat kvcache tensors, if needed, before doing this calculation.
+Compute the self-attention for the query tensor `q`, key tensor `k`, and value tensor `v`. If needed, you should concatenate the KV-cache tensors before doing this calculation.
$$
A = Q K^\top * scale \\
@@ -257,108 +256,109 @@ $$
Y = \mathrm{causalsoftmax}(A) \cdot V \\
$$
-- `attn_val`: the resulting attention value tensor. Shape should be [seqlen, nhead, dv]. You can assume that the tensor is contiguous for now.
-- `q`: the query tensor. Shape should be [seqlen, nhead, d]. You can assume that the tensor is contiguous for now.
-- `k`: the key tensor. Shape should be [total_len, nkvhead, d]. You can assume that the tensor is contiguous for now.
-- `v`: the value tensor. Shape should be [total_len, nkvhead, dv]. You can assume that the tensor is contiguous for now.
-- `scale`: a scaling factor. It is set to $\frac{1}{\sqrt{d}}$ in most cases.
+- `attn_val`: the resulting attention value tensor. The shape should be [seqlen, nhead, dv]. For now, you can assume the tensor is contiguous.
+- `q`: the query tensor. The shape should be [seqlen, nhead, d]. For now, you can assume the tensor is contiguous.
+- `k`: the key tensor. The shape should be [total_len, nkvhead, d]. For now, you can assume the tensor is contiguous.
+- `v`: the value tensor. The shape should be [total_len, nkvhead, dv]. For now, you can assume the tensor is contiguous.
+- `scale`: a scaling factor. It is set to $\frac{1}{\sqrt{d}}$ in most cases.
-You should be able to pass the test cases in `test/ops/self_attention.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/self_attention.py` after you finish the implementation.
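A NumPy sketch of the causal-softmax attention above (illustrative only; it assumes nhead == nkvhead, whereas grouped-query attention would repeat the kv heads). Note how the causal mask is offset when total_len > seqlen, i.e. when a KV-cache prefix is present:

```python
import numpy as np

def self_attention(q, k, v, scale):
    # q: [seqlen, nhead, d], k: [total_len, nkvhead, d], v: [total_len, nkvhead, dv].
    seqlen, nhead, _ = q.shape
    total_len = k.shape[0]
    out = np.empty((seqlen, nhead, v.shape[2]), dtype=q.dtype)
    offset = total_len - seqlen  # length of the cached prefix
    for h in range(nhead):
        scores = (q[:, h, :] @ k[:, h, :].T) * scale          # [seqlen, total_len]
        # Causal mask: query i may attend to keys 0 .. offset + i.
        mask = np.arange(total_len)[None, :] > (np.arange(seqlen)[:, None] + offset)
        scores = np.where(mask, -np.inf, scores)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[:, h, :] = weights @ v[:, h, :]
    return out

q = k = np.ones((2, 1, 1))
v = np.array([1.0, 2.0]).reshape(2, 1, 1)
print(self_attention(q, k, v, scale=1.0)[:, 0, 0])  # [1.  1.5]
```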
-### Task-2.7 swiglu
+### Task-2.7 SwiGLU
```c++
void swiglu(tensor_t out, tensor_t gate, tensor_t up);
```
-This is an element-wise function that computes the following:
+This is an element-wise function that computes the following:
$$
out_{i} = up_{i} \circ \frac { gate_{i}}{1 + e^{-gate_{i}}}
$$
-`out`, `up` and `gate` are 2D contiguous tensors with the same shape [seqlen, intermediate_size].
+`out`, `up`, and `gate` are 2D contiguous tensors with the same shape [seqlen, intermediate_size].
-You should be able to pass the test cases in `test/ops/swiglu.py` after you finish the implementation.
+You should be able to pass the test cases in `test/ops/swiglu.py` after you finish the implementation.
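The formula, as a NumPy one-liner (a sketch; the gate factor is just SiLU applied to `gate`):

```python
import numpy as np

def swiglu(gate, up):
    # out = up * gate * sigmoid(gate), element-wise (SiLU gating).
    return up * gate / (1.0 + np.exp(-gate))

gate = np.array([[0.0, 1.0]], dtype=np.float32)
up = np.array([[2.0, 3.0]], dtype=np.float32)
print(swiglu(gate, up))  # gate=0 contributes 0; gate=1 gives 3 * 1/(1+e^-1)
```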
-### Task-2.8
+### Task-2.8
-Run operator tests.
+Run the operator tests.
```bash
python test/test_ops.py
```
-You should see all tests passed. Commit and push your changes. You should see the auto tests for assignment #2 passed.
+You should see all tests pass. Commit and push your changes. You should then see the auto tests for assignment #2 pass.
-### Task-2.9 (Optional) rearrange
+### Task-2.9 (Optional) Rearrange
-This is a bonus task. You may or may not need it for model inference.
+This is a bonus task. You may or may not need it for model inference.
```c++
void rearrange(tensor_t out, tensor_t in);
```
-This operator is used to copy data from a tensor to another tensor with the same shape but different strides. With this, you can easily implement `contiguous` functionality for tensors.
+This operator copies data from one tensor to another tensor with the same shape but different strides. With it, you can easily implement `contiguous` functionality for tensors.
-## Assignment #3: Large Language Model Inference
+## Assignment #3: Large Language Model Inference
-Finally, it is the time for you to achieve text generation with LLAISYS.
+Finally, it is time for you to achieve text generation with LLAISYS.
-- In `test/test_infer.py`, your implementation should be able to generate the same texts as PyTorch, using argmax sampling. The model we use for this assignment is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
+- In `test/test_infer.py`, your implementation should be able to generate the same text as PyTorch using argmax sampling. The model we use for this assignment is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
-- The python wrapper of your implementation is in `python/llaisys/models/qwen2.py`. You are NOT allowed to implement your model infer logic here using any python based frameworks, such as PyTorch. Instead, you need to implement the model with C/C++ in LLAISYS backend. The script loads each tensor in the safetensors file, and you will need to load data from them into your model backend.
+- The Python wrapper of your implementation is in `python/llaisys/models/qwen2.py`. You are NOT allowed to implement your model inference logic here using any Python-based framework such as PyTorch. Instead, you need to implement the model in C/C++ in the LLAISYS backend. The script loads each tensor in the safetensors file, and you will need to load their data into your model backend.
-- In `include/llaisys/models/qwen2.h`, a prototype is defined for you. Feel free to modify the codes as you want, but you should at least provide basic APIs for model creation, destruction, data loading, and infer. Implement your C APIs in `src/llaisys/` and organize your C++ codes as other modules in `src/`. Remember to define the compiling procedures in `xmake.lua`.
+- In `include/llaisys/models/qwen2.h`, a prototype is defined for you. Feel free to modify the code as you like, but you should at least provide basic APIs for model creation, destruction, data loading, and inference. Implement your C APIs in `src/llaisys/` and organize your C++ code like the other modules in `src/`. Remember to define the build steps in `xmake.lua`.
-- In `python/llaisys/libllaisys/`, define the ctypes wrapper functions for your C APIs. Implement `python/llaisys/models/qwen2.py` with your wrapper functions.
+- In `python/llaisys/libllaisys/`, define the ctypes wrapper functions for your C APIs. Implement `python/llaisys/models/qwen2.py` with your wrapper functions.
-- You need to implement KV Cache, or your model will be too slow.
+- You need to implement a KV-Cache, or your model inference will be too slow.
-- Debug until your model works. Take advantage of tensor's `debug` function which prints the tensor data. It allows you to compare the data of any tensor during the model inference with PyTorch.
+- Debug until your model works. Take advantage of the tensor's `debug` function, which prints the tensor data. It allows you to compare the data of any tensor with PyTorch during model inference.
-After you finish the implementation, you can run the following command to test your model:
+After you finish the implementation, you can run the following command to test your model:
```bash
python test/test_infer.py --model [dir_path/to/model] --test
```
-Commit and push your changes. You should see the auto tests for assignment #3 passed.
+Commit and push your changes. You should see the auto tests for assignment #3 pass.
+## You can proceed to the projects only after you finish the assignments.
-## You can proceed to the projects only after you finish the assignments.
+## Project #1: Optimize LLAISYS for CPU Inference
-## Project #1: Optimize LLAISYS for CPU
-You probably have already noticed that your model inference is very slow compared to PyTorch. This is mostly because your operators are not optimized. Run your operater test scripts with "--profile" flag to see how your operators perform. You would probably see that `linear` operation is much slower than PyTorch. This operator is mainly a matrix multiplication, and is the most time consuming operation in transformer-based models.
+You have probably already noticed that your model inference is very slow compared to PyTorch. This is mostly because your operators are not optimized. Run your operator test scripts with the `--profile` flag to see how your operators perform. You will probably see that the `linear` operator is much slower than PyTorch's. This operator is essentially a matrix multiplication, which is the most time-consuming operation in transformer-based models.
-There are several ways to optimize your operators for CPU:
+There are several ways to optimize your operators for CPU:
-### SIMD instructions
+### SIMD Instructions
-SIMD (Single Instruction Multiple Data) instructions are instructions that can perform the same operation on multiple data elements in a single instruction. Modern CPUs have support for SIMD instructions. Look for online materials to learn about compiler intrinsics (such as AVX2, AVX-512, NEON, SVE) to vectorize your operations.
+SIMD (Single Instruction, Multiple Data) instructions perform the same operation on multiple data elements in a single instruction. Modern CPUs support SIMD instructions. Look for online materials on compiler intrinsics (such as AVX2, AVX-512, NEON, and SVE) to vectorize your operators.
-### Use OpenMP for parallelism
+### Use OpenMP for Parallelism
-You can use multi-threading to parallelize your operators. OpenMP is a popular library for multi-threading in C/C++. Add OpenMP support for LLAISYS to parallelize your `linear` and other operators.
+You can use multi-threading to parallelize your operators. OpenMP is a popular multi-threading library for C/C++. Add OpenMP support to LLAISYS to parallelize `linear` and your other operators.
-### 3rd-party Libraries
+### Third-party Libraries
-There are several libraries that can help you optimize your operators for CPU. Look for libraries like Eigen, OpenBLAS, MKL, etc. to optimize your linear algebra operations. Note that some libraries are supported only for certain hardware platforms. Check their documentations and use them in your codes with care. You can also try to dig out how PyTorch implement these operators and see if you can use them.
+Several libraries can help you optimize your operators for CPU. Look into libraries such as Eigen, OpenBLAS, and MKL to optimize your linear algebra operations. Note that some libraries only support certain hardware platforms; check their documentation and use them in your code with care. You can also dig into how PyTorch implements these operators and see whether you can reuse its approach.
-Optimize your implementation with any methods you like and report your performance improvement.
+Optimize your implementation with any methods you like and report your performance improvement.
-## Project #2: Intigrate CUDA into LLAISYS
+## Project #2: Integrate CUDA into LLAISYS
-This project does not depend on **Project #1**. You should choose two CUDA/CUDA-ish hardware platforms from Nvidia, Iluvatar, Metax, and Moore Threads.
+This project does not depend on Project #1. You should choose at least two CUDA or CUDA-like platforms (collectively referred to as CUDA below) from Nvidia, Iluvatar, Moore Threads, and Metax.
-This camp session provides computation resources from the four platforms above, access to which is granted based on applications from the official website. You can accelerate your model with CUDA on these GPU platforms. Before doing that, let's dive deeper into LLAISYS framework.
+This camp session provides computing resources for the four platforms above; access is granted based on applications through the official website. You can accelerate your model inference with CUDA on these GPU platforms. Before doing that, let's dive deeper into the LLAISYS framework.
-LLAISYS is actually a framework with homogeous hardware support. When using LLAISYS, each thread will create a thread-local `Context` object which manages all the device `Runtime` objects used by this thread. A `Runtime` object is a resource manager for a device, and `Context` will create (with lazy initialization) a single `Runtime` object for each device. You can set and switch between them using `setDevice` function in `Context`. Only one device will be active at a time for each thread. Check `src/core/context.hpp` for more details.
+LLAISYS is actually a framework with homogeneous hardware support. When using LLAISYS, each thread creates a thread-local `Context` object that manages all the device `Runtime` objects used by that thread. A `Runtime` object is the resource manager for a device, and `Context` creates (with lazy initialization) a single `Runtime` object for each device. You can set and switch between devices with the `setDevice` function of `Context`; only one device is active at a time for each thread. Check `src/core/context.hpp` for more details.
-### Implement CUDA Runtime APIs
-Each `Runtime` object is intialized with a set of generic functions called `Runtime APIs`. You will need to implement CUDA version of these APIS. Check `src/device/cpu/cpu_runtime_api.cpp` to see how these functions are implemented for CPU and look for CUDA APIs to use in [`CUDA Runtime documentation`](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html).
+### Implement CUDA Runtime APIs
-You can see in `src/device/runtime_api.hpp` that `nvidia::getRuntimeAPI()` is guarded by `ENABLE_NVIDIA_API` macro.
+Each `Runtime` object is initialized with a set of generic functions called the Runtime APIs. You will need to implement CUDA versions of these APIs. Check `src/device/cpu/cpu_runtime_api.cpp` to see how these functions are implemented for the CPU, and look for the CUDA APIs to use in the [CUDA Runtime documentation](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html).
+
+You can see in `src/device/runtime_api.hpp` that `nvidia::getRuntimeAPI()` is guarded by the `ENABLE_NVIDIA_API` macro:
```c++
#ifdef ENABLE_NVIDIA_API
@@ -368,9 +368,9 @@ const LlaisysRuntimeAPI *getRuntimeAPI();
#endif
```
-This macro is defined in `xmake.lua` as a switch to enable/disable CUDA support. CUDA codes will not be compiled if the switch is off. In `xmake/` directory, create a `nvidia.lua` that configs your compiling process. (Similar to `cpu.lua` for CPU.) Search online to learn how to do it with Xmake.
+This macro is defined in `xmake.lua` as a switch to enable or disable CUDA support; if it is off, the CUDA code will not be compiled. Create an `nvidia.lua` under `xmake/` that configures your build process (similar to `cpu.lua` for the CPU). Search online to learn how to do this with Xmake.
-After you implement the CUDA Runtime APIs, config your xmake with `--nv-gpu=y` to enable CUDA support and recompile your program. Run runtime tests to see if your implementation works.
+After you implement the CUDA Runtime APIs, configure xmake with `--nv-gpu=y` to enable CUDA support, recompile your program, and run the runtime tests:
```bash
xmake f --nv-gpu=y -cv
@@ -379,53 +379,54 @@ xmake install
python test/test_runtime.py --device nvidia
```
-### Implement CUDA Operators
-Create a `nvdia/` sub-directory in each operator source directory and implement a cuda version. Check `src/ops/add/op.cpp` to see how to include your cuda implementations. Remeber to define the compiling procedures in the xmake files. Run the operator tests with `--device nvidia` flag to test your CUDA implementation.
+### Implement CUDA Operators
+
+Create an `nvidia/` sub-directory in each operator's source directory and implement a CUDA version there. Check `src/ops/add/op.cpp` to see how to include your CUDA implementations. Remember to define the build steps in the xmake files. Run the operator tests with the `--device nvidia` flag to test your CUDA implementation.
-You can use CUDA libraries like cuBLAS, cuDNN, etc. to accelerate your operators. Check their documentations to see how to use them. You can store extra device resources in `src/device/nvidia/nvidia_resource.cu`.
+You can use CUDA libraries such as cuBLAS and cuDNN to accelerate your operators; check their documentation to see how to use them. You can store extra device resources in `src/device/nvidia/nvidia_resource.cu`.
-Modify your model codes to support CUDA inference.
+Finally, modify your model code to support CUDA inference:
```bash
python test/test_infer.py --model [dir_path/to/model] --test --device nvidia
```
-## Project #3: Build an AI chatbot
+## Project #3: Build an AI Chatbot
-In this project you will build an AI chatbot that can do live conversations with single user with LLAISYS.
+In this project, you will use LLAISYS to build an AI chatbot that can hold a live conversation with a single user.
-### Random Sampling
+### Random Sampling
-So far we have been testing our model with argmax sampling. This is good enough for testing, but a chatbot should be able to generate more natural responses. Implement a random sample operator. Try to add supports for **Temperature**, **Top-K** and **Top-P**.
+So far we have only tested our model with argmax sampling. This is good enough for testing, but a chatbot should generate more natural responses. Implement a random sampling operator, and try to add support for **Temperature**, **Top-K**, and **Top-P**.
-### Build a Chatbot Server
+### Build a Chatbot Server
-In your Python frontend, implement a server that can receive http requests from user and send responses back. You can use frameworks like FastAPI to build the server. You should follow the OpenAI chat-completion APIs. Try to support streaming responses if you can. You can assume, for now, that the server is only serving one user, and block the endpoint until the previous request is served.
+In your Python frontend, implement a server that receives HTTP requests from the user and sends responses back. You can use a framework such as FastAPI to build the server. You should follow the OpenAI chat-completion APIs. Try to support streaming responses if you can. For now, you can assume the server serves only one user and block the endpoint until the previous request has been served.
+### Interactive Chat UI
-### Interactive Chat UI
+Build a UI that sends requests to and receives responses from the chatbot server. It can be a simple command-line interface or a fancy web interface. You should be able to keep a conversation going with the chatbot by sending messages and receiving responses consecutively.
-Build a UI that send requests to and receive responses from the chatbot server. You can build a simple command-line interface or a fancy web interface. You should be able to keep a conversation going with the chatbot by sending messages and receiving responses consecutively.
+### (Optional) Chat Session Management
-### (Optional) Chat Session Management
+In real-world AI applications, users can start multiple conversations and switch between them. Users can also edit a past question and let the AI regenerate an answer. Enhance your UI to support these features. Implement a KV-Cache pool with prefix matching to reuse past results as much as possible.
-In real-world AI applications, users are allowed to start new conversations and switch between them. Users can also edit a past question and let the AI regenerate an answer. Enhance your UI to support these features. Implement a KV-Cache pool with prefix matching to reuse past results as much as possible.
+## Project #4: Multi-user Inference Service
+Before starting this project, you need to finish Project #3 and support streaming responses.
-## Project #4: Multi-user Inference Service
+### Serving Multiple Users
-You need to finish **Project #2** and achieve streaming response first before proceeding to this project.
+In real-world scenarios, an inference service serves multiple users, and requests can come in at any time. Your endpoint should add each new request to a request pool or queue and have a separate looping thread or process serve the requests.
-### Serving Multiple Users
+### Continuous Batching
-In real-world scenarios, an inference service will serve multiple users. Requests can come in at any time, and the service should be able to handle them concurrently. Your endpoint should add a new request to a request pool or queue and have a another looping process or thread to serve the requests.
+To maximize throughput, you need to batch requests instead of serving them one by one. Since each request can have a different length, you need a continuous, iteration-level batching mechanism: in each iteration, extract several requests from the pool to form a batch, do one round of batched inference, and then return the unfinished requests to the pool. Use batched matrix multiplication when possible to speed up inference. Note that every request in the batch needs to bind to a different KV-Cache; you should build a KV-Cache pool with prefix matching to reuse past results as much as possible.
-### Continous Batching
-To maximize the throughput of your inference service, you need to batch your requests instead of serving them one by one. Since each request can have different length, you will need a continous and iteration-level batching mechanism. For each interation you extract several requests from pool to form a batch, do one round of batch inference, and then return the unfinished requests back to the pool. Use batched matrix multiplication when possible to speed up your inference. Note that every request in the batch need to bind with a different KV-Cache. You should build a KV-Cache pool with prefix matching to reuse past results as much as possible.
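The iteration-level loop described above can be sketched as a toy scheduler (illustrative only, not the LLAISYS API; the `step_fn` decode step and the request fields are hypothetical placeholders):

```python
from collections import deque

def serve(pool: deque, max_batch: int, step_fn):
    # Each round takes up to `max_batch` requests from the pool, runs one
    # batched decode step, and returns unfinished requests to the pool so
    # newly arrived requests can join the next batch.
    finished = []
    while pool:
        batch = [pool.popleft() for _ in range(min(max_batch, len(pool)))]
        for req in step_fn(batch):  # one round of batched inference
            (finished if req["done"] else pool).append(req)
    return finished

def step_fn(batch):
    # Stand-in for one batched decode step: emit one token per request.
    for req in batch:
        req["tokens"] += 1
        req["done"] = req["tokens"] >= req["target"]
    return batch

pool = deque({"id": i, "tokens": 0, "target": n, "done": False}
             for i, n in enumerate([2, 1, 3]))
done = serve(pool, max_batch=2, step_fn=step_fn)
print([r["id"] for r in done])  # short requests finish first
```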
+## Project #5: Distributed Inference
-## Project #5: Distributed Inference
-Introduce Tensor Parallelism to LLAISYS. Shard your model across multiple devices and implement distributed model inference. Support NCCL in LLAISYS if your are uing Nvidia GPUs, or MPI if you are using CPUs.
+Introduce tensor parallelism to LLAISYS. Shard your model across multiple devices and implement distributed model inference. Support NCCL in LLAISYS if you are using Nvidia GPUs, or MPI if you are using CPUs.
-## Project #6: Support New Models
+## Project #6: Support New Models
-Support another model type than the one we use for homework in LLAISYS.
+Add support in LLAISYS for a model type other than the one used in the assignments.
diff --git a/README_ZN.md b/README_ZN.md
deleted file mode 100644
index 7704dbd5b..000000000
--- a/README_ZN.md
+++ /dev/null
@@ -1,432 +0,0 @@
-# 欢迎使用 LLAISYS
-
-
-English |
-中文
-
-
-## 简介
-
-LLAISYS(Let's Learn AI SYStem)是一个教育项目,旨在为新手和未来的AI工程师提供一个从零开始构建AI系统的学习平台。LLAISYS包含多个作业,帮助学生学习和构建基础模块;以及一些项目挑战,让他们为系统添加更多高级功能。LLAISYS使用C++作为系统后端的主要编程语言,并编译成共享库,提供C语言API。前端代码使用Python编写,调用这些API以提供更便捷的测试和与其他架构(如PyTorch)的交互。
-
-### 项目结构概览
-
-- `\include`:包含所有定义共享库提供的C API的头文件的目录。(函数声明以`__export`开头)
-
-- `\src`:C++源文件。
- - `\src\llaisys`包含头文件中定义的所有直接实现,并遵循与`\include`相同的目录结构。这也是C++代码的边界。
- - 其他目录包含不同模块的实际实现。
-
-- `xmake.lua`:llaisys后端的构建规则。`\xmake`目录包含不同设备的子xmake文件。例如,将来可以在目录中添加`nvidia.lua`来支持CUDA。
-
-- `\python`:Python源文件。
- - `\python\llaisys\libllaisys`包含llaisys API的所有ctypes封装函数。它基本上与C头文件的结构相匹配。
- - `\python\llaisys`包含ctypes函数的Python包装器,使包更符合Python风格。
-
-- `\test`:导入llaisys python包的Python测试文件。
-
-## 作业 #0:入门
-
-### 任务-0.1 安装必备组件
-
-- 编译工具:[Xmake](https://xmake.io/)
-- C++编译器:MSVC(Windows)或Clang或GCC
-- Python >= 3.9(PyTorch、Transformers等)
-- Clang-Format-16(可选):用于格式化C++代码。
-
-### 任务-0.2 Fork并构建LLAISYS
-
-- Fork LLAISYS仓库并克隆到本地机器。支持Windows和Linux。
-
-- 编译和安装
-
- ```bash
- # 编译c++代码
- xmake
- # 安装llaisys共享库
- xmake install
- # 安装llaisys python包
- pip install ./python/
- ```
-
-- Github自动测试
-
- LLAISYS使用Github Actions在每次推送和拉取请求时运行自动化测试。你可以在仓库页面上看到测试结果。完成所有作业任务后,所有测试都应该通过。
-
-### 任务-0.3 首次运行LLAISYS
-
-- 运行cpu运行时测试
-
- ```bash
- python test/test_runtime.py --device cpu
- ```
-
- 你应该看到测试通过。
-
-### Task-0.4 Download the test model
-
-- The model we use for the assignments is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
-
-- Run the model inference test with PyTorch
-
- ```bash
- python test/test_infer.py --model [dir_path/to/model]
- ```
-
-  You can see that PyTorch loads the model and performs inference on a sample input. You can step into the `transformers` library code to see how it works internally. For now, your code cannot do anything yet, but in the following assignments you will build a system that achieves the same functionality.
-
-## Assignment #1: Tensor
-
-A tensor is a data structure representing multi-dimensional data. It is the basic building block of LLAISYS and of most AI frameworks such as PyTorch. In this assignment, you will implement a basic tensor class.
-
-A tensor object has the following fields:
-
-- `storage`: a shared pointer to the memory block storing the tensor's data. It can be shared by multiple tensors. See the storage class for details.
-- `offset`: the starting index of the tensor within its storage, in bytes.
-- `meta`: metadata describing the tensor's shape, data type, and strides.
-
-Implement the following functions defined in `src/tensor/tensor.hpp`:
-
-### Task-1.1
-
-```c++
-void load(const void *src);
-```
-
-Load host (CPU) data into the tensor (which may live on a device). See the constructor for how to obtain the runtime API of the current device context and perform a host-to-device memory copy.
-
-### Task-1.2
-
-```c++
-bool isContiguous() const;
-```
-
-Check the tensor's shape and strides to determine whether it is contiguous in memory.
-
-### Task-1.3
-
-```c++
-tensor_t view(const std::vector<size_t> &shape) const;
-```
-
-Create a new tensor that reshapes the original to the given shape by splitting or merging dimensions. No data transfer is involved. For example, a tensor of shape (2, 3, 5) can be changed to (2, 15) by merging the last two dimensions.
-
-This function is not as simple as just overwriting the tensor's shape, even though the tests would still pass if you did. It should raise an error if the new view is incompatible with the original tensor. Consider a tensor of shape (2, 3, 5) with strides (30, 10, 1): can you still reshape it to (2, 15) without moving data?
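The compatibility question above comes down to a stride condition: two adjacent dimensions can be merged without copying only if the outer stride equals the inner extent times the inner stride. A pure-Python sketch (`can_merge` is an illustrative helper, not part of the LLAISYS API):

```python
def can_merge(shape, strides, d):
    """True if dims d and d+1 occupy one contiguous block of memory,
    i.e. they can be merged into a single dim without moving data.
    (Illustrative helper, not the LLAISYS API.)"""
    return strides[d] == shape[d + 1] * strides[d + 1]

# A contiguous (2, 3, 5) tensor has strides (15, 5, 1):
print(can_merge((2, 3, 5), (15, 5, 1), 1))   # merging dims 1 and 2 -> (2, 15): OK
# After e.g. slicing, the same shape may have strides (30, 10, 1):
print(can_merge((2, 3, 5), (30, 10, 1), 1))  # cannot view this as (2, 15)
```
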
-
-### Task-1.4
-
-```c++
-tensor_t permute(const std::vector<size_t> &order) const;
-```
-
-Create a new tensor with the dimensions of the original reordered. Transposition can be implemented with this function without moving any data.
-
-### Task-1.5
-
-```c++
-tensor_t slice(size_t dim, size_t start, size_t end) const;
-```
-
-Create a new tensor that slices the original along the given dimension, from index start (inclusive) to end (exclusive).
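Both `permute` and `slice` only rewrite metadata; the data stays where it is. A sketch over a plain (shape, strides, offset) triple — the helper names are hypothetical, and strides/offset are in element counts here, whereas LLAISYS stores the offset in bytes:

```python
def permute_meta(shape, strides, order):
    # Reorder dimensions; only the metadata changes.
    return tuple(shape[i] for i in order), tuple(strides[i] for i in order)

def slice_meta(shape, strides, offset, dim, start, end):
    # start inclusive, end exclusive; only shape and offset change.
    new_shape = list(shape)
    new_shape[dim] = end - start
    return tuple(new_shape), strides, offset + start * strides[dim]

# Transpose a (2, 3) tensor with strides (3, 1):
print(permute_meta((2, 3), (3, 1), (1, 0)))        # ((3, 2), (1, 3))
# Take rows 1..3 of a (4, 5) tensor:
print(slice_meta((4, 5), (5, 1), 0, 0, 1, 3))      # ((2, 5), (5, 1), 5)
```
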
-
-### Task-1.6
-
-Run the tensor tests.
-
-```bash
-python test/test_tensor.py
-```
-
-You should see all tests pass. Commit and push your changes; the automated tests for Assignment #1 should pass.
-
-## Assignment #2: Operators
-
-In this assignment, you will implement CPU versions of the following operators:
-
-- argmax
-- embedding
-- linear
-- rms_norm
-- rope
-- self_attention
-- swiglu
-
-Read the code in `src/ops/add/` to see how the "add" operator is implemented. Make sure you understand how operator code is organized, compiled, linked, and exposed to the Python frontend. **Your operators should support at least the Float32, Float16, and BFloat16 data types.** A helper function for simple type casting is provided in `src/utils/`. All Python tests are in `test/ops`; your implementation should at least pass them. Start by running the test script for the "add" operator.
-
-### Task-2.1 Argmax
-
-```c++
-void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals);
-```
-
-Find the maximum value of the tensor `vals` and its index, and store them in `max_val` and `max_idx` respectively. For now you may assume `vals` is a 1-D tensor, and that `max_idx` and `max_val` are both 1-D tensors with a single element (i.e., the dimensionality of `vals` is preserved).
-
-Once implemented, you should pass the test cases in `test/ops/argmax.py`.
-
-### Task-2.2 Embedding
-
-```c++
-void embedding(tensor_t out, tensor_t index, tensor_t weight);
-```
-
-Copy the rows of `weight` (2-D) selected by `index` (1-D) into `out` (2-D). `index` must be of type Int64 (the default integer dtype in PyTorch).
-
-Once implemented, you should pass the test cases in `test/ops/embedding.py`.
-
-### Task-2.3 Linear
-
-```c++
-void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias);
-```
-
-Compute the following:
-
-$$
-Y = X W^\top + b
-$$
-
-- `out`: the output $Y$. For now you may assume it is a 2-D contiguous tensor with no broadcasting involved.
-- `in`: the input $X$. For now you may assume it is a 2-D contiguous tensor with no broadcasting involved.
-- `weight`: the weight $W$, a 2-D contiguous tensor. Note that the weight tensor is not transposed; you need to handle this during computation.
-- `bias` (optional): the bias $b$, a 1-D tensor. You need to support the case where no bias is provided.
-
-Once implemented, you should pass the test cases in `test/ops/linear.py`.
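The computation can be sanity-checked against a small NumPy reference; note the transpose of the stored, un-transposed weight (this is a sketch, not the LLAISYS operator, which writes into a preallocated `out` tensor):

```python
import numpy as np

def linear(x, weight, bias=None):
    """Reference linear layer: y = x @ W.T (+ b).
    `weight` has shape (out_features, in_features), stored un-transposed."""
    y = x @ weight.T
    if bias is not None:
        y = y + bias
    return y

x = np.array([[1.0, 2.0]])
w = np.array([[3.0, 4.0], [5.0, 6.0]])   # (out_features=2, in_features=2)
b = np.array([0.5, -0.5])
print(linear(x, w, b))   # [[11.5 16.5]]
```
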
-
-### Task-2.4 RMS Normalization
-
-```c++
-void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps);
-```
-
-For each row, compute the following:
-
-$$
-Y_i = \frac{W_i \times X_i}{\sqrt{\frac{1}{d}(\sum_{j=1}^d X_j^2) + \epsilon}}
-$$
-
-- `out`: the output $Y$. For now you may assume it is a 2-D contiguous tensor with no broadcasting involved.
-- `in`: the input $X$. For now you may assume it is a 2-D contiguous tensor with no broadcasting involved. Normalization is performed along the last dimension of the input (i.e., each row, of length $d$).
-- `weight`: the weight $W$, a 1-D tensor with the same length as a row of the input.
-- `eps`: a small value $\epsilon$ to avoid division by zero.
-
-Once implemented, you should pass the test cases in `test/ops/rms_norm.py`.
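A NumPy reference of the row-wise formula above (a sketch, not the LLAISYS signature):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Divide each row by its root-mean-square, then scale element-wise by weight.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return weight * x / rms

x = np.array([[3.0, 4.0]])
w = np.ones(2)
print(rms_norm(x, w))   # each row divided by sqrt((9 + 16) / 2 + eps)
```
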
-
-### Task-2.5 Rotary Positional Embedding (RoPE)
-
-```c++
-void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta);
-```
-
-For each vector of the input tensor `in` (each corresponding to a position id in `pos_ids`), compute the following:
-
-Let $\mathbf{x}_i = [\mathbf{a}_i, \mathbf{b}_i] \in \mathbb{R}^d$ be the input vector and $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i] \in \mathbb{R}^d$ the output vector at index $i$, where $\mathbf{a}_i, \mathbf{b}_i, \mathbf{a}'_i, \mathbf{b}'_i \in \mathbb{R}^{d/2}$.
-
-Let $\theta$ be a fixed base (e.g. $\theta = 10000$) and $j = 0, 1, \ldots, d/2 - 1$.
-
-Let $p_i \in \mathbb{N}$ be the position id of the token at input index $i$.
-
-The RoPE angles are then $\phi_{i,j} = \frac{p_i}{\theta^{2j/d}}$
-
-The output vector $\mathbf{y}_i = [\mathbf{a}'_i, \mathbf{b}'_i]$ is computed as:
-
-$$a_{i,j}' = a_{i,j} \cos(\phi_{i,j}) - b_{i,j} \sin(\phi_{i,j})$$
-
-$$b_{i,j}' = b_{i,j} \cos(\phi_{i,j}) + a_{i,j} \sin(\phi_{i,j})$$
-
-- `out`: the resulting **q** or **k** tensor, of shape [seqlen, nhead, d] or [seqlen, nkvhead, d]. For now you may assume the tensor is contiguous.
-- `in`: the original **q** or **k** tensor, of shape [seqlen, nhead, d] or [seqlen, nkvhead, d]. For now you may assume the tensor is contiguous.
-- `pos_ids`: the position id of each token in the input sequence (its index within the whole context), of shape [seqlen,] and dtype Int64.
-- `theta`: the base value of the frequency vector.
-
-Once implemented, you should pass the test cases in `test/ops/rope.py`.
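The half-split rotation above can be written compactly in NumPy (`rope` here is an illustrative reimplementation, not the LLAISYS signature, which writes into an output tensor):

```python
import numpy as np

def rope(x, pos_ids, theta=10000.0):
    """x: [seqlen, nhead, d] with d even; the first and second halves of the
    last dim are rotated as pairs (a_j, b_j) by angle phi[i, j]."""
    d = x.shape[-1]
    j = np.arange(d // 2)
    # phi[i, j] = pos_ids[i] / theta**(2j/d), broadcast over the head dim
    phi = pos_ids[:, None] / theta ** (2.0 * j / d)        # [seqlen, d/2]
    cos, sin = np.cos(phi)[:, None, :], np.sin(phi)[:, None, :]
    a, b = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([a * cos - b * sin, b * cos + a * sin], axis=-1)

q = np.ones((1, 1, 4))
print(rope(q, np.array([0])))   # at position 0 all angles are 0: output == input
```
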
-
-### Task-2.6 Self-Attention
-
-```c++
-void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale);
-```
-
-Compute self-attention for the query tensor `q`, key tensor `k`, and value tensor `v`. If needed, you should concatenate the KV-cache tensors before this computation.
-
-$$
-A = Q K^\top \cdot \mathrm{scale}
-$$
-
-$$
-Y = \mathrm{causalsoftmax}(A) \cdot V
-$$
-
-- `attn_val`: the resulting attention value tensor, of shape [seqlen, nhead, dv]. For now you may assume the tensor is contiguous.
-- `q`: the query tensor, of shape [seqlen, nhead, d]. For now you may assume the tensor is contiguous.
-- `k`: the key tensor, of shape [total_len, nkvhead, d]. For now you may assume the tensor is contiguous.
-- `v`: the value tensor, of shape [total_len, nkvhead, dv]. For now you may assume the tensor is contiguous.
-- `scale`: the scaling factor, in most cases $\frac{1}{\sqrt{d}}$.
-
-Once implemented, you should pass the test cases in `test/ops/self_attention.py`.
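A single-head NumPy reference of the two formulas. The causal mask is aligned so that the last query row attends to the full context (as needed when seqlen < total_len during decoding with a KV cache); the real operator must additionally map nkvhead KV heads onto nhead query heads:

```python
import numpy as np

def self_attention(q, k, v, scale):
    """Single head: q [s, d], k [t, d], v [t, dv], with t >= s."""
    s, t = q.shape[0], k.shape[0]
    scores = (q @ k.T) * scale                              # [s, t]
    # Causal mask: query i may attend keys up to position (t - s + i).
    mask = np.arange(t)[None, :] > (t - s + np.arange(s))[:, None]
    scores = np.where(mask, -np.inf, scores)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    return (e / e.sum(axis=-1, keepdims=True)) @ v

out = self_attention(np.eye(2), np.eye(2), np.eye(2), 1.0)
print(out[0])   # first query sees only key 0, so row 0 equals v[0]
```
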
-
-### Task-2.7 SwiGLU
-
-```c++
-void swiglu(tensor_t out, tensor_t gate, tensor_t up);
-```
-
-This is an element-wise function that computes:
-
-$$
-out_{i} = up_{i} \circ \frac{gate_{i}}{1 + e^{-gate_{i}}}
-$$
-
-`out`, `up`, and `gate` are 2-D contiguous tensors with the same shape [seqlen, intermediate_size].
-
-Once implemented, you should pass the test cases in `test/ops/swiglu.py`.
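A NumPy reference of the element-wise formula (the gate path is exactly SiLU, i.e. gate times its sigmoid):

```python
import numpy as np

def swiglu(gate, up):
    # Element-wise: up * gate * sigmoid(gate)
    return up * gate / (1.0 + np.exp(-gate))

print(swiglu(np.array([0.0, 1.0]), np.array([2.0, 2.0])))
# gate = 0 gives 0; gate = 1 gives 2 * sigmoid(1) ~ 1.462
```
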
-
-### Task-2.8
-
-Run the operator tests.
-
-```bash
-python test/test_ops.py
-```
-
-You should see all tests pass. Commit and push your changes; the automated tests for Assignment #2 should pass.
-
-### Task-2.9 (Optional) rearrange
-
-This is a bonus task. You may or may not need it for model inference.
-
-```c++
-void rearrange(tensor_t out, tensor_t in);
-```
-
-This operator copies data from one tensor to another tensor with the same shape but different strides. With it, you can easily implement `contiguous` functionality for tensors.
-
-## Assignment #3: Large Language Model Inference
-
-At last, it is time to implement text generation with LLAISYS.
-
-- In `test/test_infer.py`, your implementation should generate the same text as PyTorch using argmax sampling. The model we use for this assignment is [DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B).
-
-- The Python wrapper for your implementation is in `python/llaisys/models/qwen2.py`. You are not allowed to implement your model inference logic with any Python-based framework (such as PyTorch) here. Instead, implement the model in C/C++ in the LLAISYS backend. The script loads every tensor in the safetensors files; you need to load their data into your model backend.
-
-- A prototype is defined for you in `include/llaisys/models/qwen2.h`. Feel free to modify the code, but you should at least provide basic APIs for model creation, destruction, data loading, and inference. Implement your C API in `src/llaisys/` and organize your C++ code like the other modules in `src/`. Remember to define the build process in `xmake.lua`.
-
-- In `python/llaisys/libllaisys/`, define ctypes wrapper functions for your C API. Implement `python/llaisys/models/qwen2.py` with your wrappers.
-
-- You need to implement KV-Cache; otherwise model inference will be far too slow.
-
-- Debug until your model works. Use the tensor's `debug` function to print tensor data; it lets you compare any tensor's data with PyTorch during model inference.
-
-Once implemented, you can test your model by running:
-
-```bash
-python test/test_infer.py --model [dir_path/to/model] --test
-```
-
-Commit and push your changes; the automated tests for Assignment #3 should pass.
-
-## Projects may only be started after all assignments are complete.
-
-## Project #1: Optimize LLAISYS CPU Inference
-
-You have probably noticed that your model inference is very slow compared with PyTorch. This is mainly because your operators are not optimized. Run the operator test scripts with the ``--profile`` flag to see how your operators perform. You will likely find that ``linear`` is much slower than PyTorch's. This operator is essentially matrix multiplication, the most time-consuming operation in Transformer models.
-
-Here are several ways to optimize CPU operators:
-
-### Use SIMD instructions
-
-SIMD (Single Instruction, Multiple Data) instructions perform the same operation on multiple data elements within a single instruction. Modern CPUs all support SIMD. Look up compiler intrinsics (e.g. AVX2, AVX-512, NEON, SVE) to vectorize your operators.
-
-### Parallelize with OpenMP
-
-You can parallelize operators with multithreading. OpenMP is a common multithreading library for C/C++. Add OpenMP support to LLAISYS so that operators such as ``linear`` can run in parallel.
-
-### Use third-party libraries
-
-Many libraries can help optimize operators on CPU, such as Eigen, OpenBLAS, and MKL, which handle linear algebra efficiently. Note, however, that some libraries only support specific hardware platforms; read their documentation and use them carefully. You can also look at PyTorch's operator implementations to see whether anything can be reused.
-
-Optimize your inference implementation however you like, and report the performance improvement.
-
-## Project #2: Integrate CUDA into LLAISYS and support two CUDA or CUDA-like platforms (collectively "CUDA" below)
-
-This project does not depend on ``Project #1``. Choose at least two platforms from Nvidia, Iluvatar (天数), Moore Threads (摩尔), and MetaX (沐曦).
-
-The training camp provides compute for all four platforms; you can apply for it through official channels and use CUDA to accelerate model inference. Before writing code, make sure you understand the LLAISYS framework thoroughly.
-
-LLAISYS is in fact a framework that supports homogeneous hardware. At runtime, each thread creates a thread-local **Context** object that manages the **Runtime** of every device the thread uses. A **Runtime** object is a device's resource manager; the **Context** creates a unique **Runtime** for each device (lazily initialized). You can switch between devices with ``setDevice``; each thread has exactly one active device at a time. See ``src/core/context.hpp`` for details.
-
-### Implement the CUDA Runtime API
-
-Every **Runtime** object initializes a common set of **Runtime APIs**. You need to implement the CUDA version of these APIs. See ``src/device/cpu/cpu_runtime_api.cpp`` for the CPU implementation, and consult the [`CUDA Runtime documentation`](https://docs.nvidia.com/cuda/cuda-runtime-api/index.html) for the corresponding APIs.
-
-In ``src/device/runtime_api.hpp``, ``nvidia::getRuntimeAPI()`` is guarded by the ``ENABLE_NVIDIA_API`` macro:
-
-```c++
-#ifdef ENABLE_NVIDIA_API
-namespace nvidia {
-const LlaisysRuntimeAPI *getRuntimeAPI();
-}
-#endif
-```
-
-The macro is defined in ``xmake.lua`` and toggles CUDA support on and off. When off, CUDA code is not compiled. Create a new ``nvidia.lua`` under ``xmake/`` to configure the build process (see ``cpu.lua`` for reference). Look up how to configure this with Xmake.
-
-After finishing the CUDA Runtime API, enable CUDA support with ``--nv-gpu=y``, rebuild, and run the test:
-
-```bash
-xmake f --nv-gpu=y -cv
-xmake
-xmake install
-python test/test_runtime.py --device nvidia
-```
-
-### Implement CUDA operators
-
-Create a ``nvidia/`` subdirectory under each operator directory and write the CUDA implementation there. See ``src/ops/add/op.cpp`` for how a CUDA implementation is included. Remember to define the build process in the xmake files. Run the tests with the ``--device nvidia`` flag.
-
-You may use CUDA libraries such as cuBLAS and cuDNN to accelerate operators; extra device resources can go in `src/device/nvidia/nvidia_resource.cu`.
-
-Finally, modify your model code to support CUDA inference:
-
-```bash
-python test/test_infer.py --model [dir_path/to/model] --test --device nvidia
-```
-
-## Project #3: Build an AI Chatbot
-
-In this project, you will build a chatbot with LLAISYS that can converse with a single user in real time.
-
-### Random sampling
-
-So far we have only used argmax sampling, which is fine for testing, but a chatbot needs more natural replies. Implement a random sampling operator, supporting **Temperature**, **Top-K**, and **Top-P** where possible.
-
-### Build a chat server
-
-In the Python frontend, implement a server that accepts HTTP requests and returns responses. You can use a framework such as FastAPI. The interface should preferably follow OpenAI's chat-completion API. If possible, support streaming output. You may assume a single user for now, so each request can block until it is processed.
-
-### Interactive chat UI
-
-Implement a UI that sends requests to the server and receives replies. It can be a command-line or web interface. It should be able to keep a conversation going with the bot across consecutive messages.
-
-### (Optional) Session management
-
-In real applications, a user can open multiple conversations, switch between them, and edit past questions to have the AI regenerate its answers. Extend your UI to support these features. Implement a KV-Cache pool with prefix matching to reuse existing results as much as possible.
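One possible sketch of such a sampler in NumPy. The filtering order here (temperature, then top-k, then top-p) is a common design choice, not something the task prescribes:

```python
import numpy as np

def sample(logits, temperature=1.0, top_k=0, top_p=1.0, rng=None):
    """Sample a token id from a 1-D logits vector (illustrative sketch)."""
    if rng is None:
        rng = np.random.default_rng(0)
    if temperature <= 0:                     # degenerate case: plain argmax
        return int(np.argmax(logits))
    probs = np.exp((logits - logits.max()) / temperature)
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]          # token ids, most probable first
    if top_k > 0:
        order = order[:top_k]
    if top_p < 1.0:                          # smallest prefix with mass >= top_p
        cum = np.cumsum(probs[order])
        order = order[: int(np.searchsorted(cum, top_p) + 1)]
    p = probs[order] / probs[order].sum()    # renormalize over kept tokens
    return int(rng.choice(order, p=p))

print(sample(np.array([1.0, 3.0, 2.0]), top_k=1))   # top-k = 1 reduces to argmax
```
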
-
-## Project #4: Multi-User Inference Service
-
-Before starting this project, you need to finish ``Project #3`` including streaming output.
-
-### Support multiple users
-
-A real inference service serves many users at once, and requests can arrive at any time. Your server should put requests into a request pool/queue and process them in a separate loop thread/process.
-
-### Continuous batching
-
-To maximize throughput, you need batching rather than one-by-one processing. Since requests differ in length, implement continuous, iteration-level batching: each round, take a number of requests from the pool to form a batch, run one batched inference step, and return the unfinished requests to the pool. Use batched matrix multiplication during inference where possible. Note that each request needs its own KV-Cache; implement a KV-Cache pool with prefix matching to reuse results.
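The round-based loop described above can be sketched in a few lines of Python; `decode_step` is a toy stand-in for one batched forward pass, and the request representation is purely illustrative:

```python
from collections import deque

def serve(pool, max_batch_size, step_fn):
    """Iteration-level batching: each round takes up to max_batch_size
    requests, runs one batched step, and returns unfinished requests
    to the pool. step_fn splits its batch into (running, finished_ids)."""
    done = []
    while pool:
        batch = [pool.popleft() for _ in range(min(max_batch_size, len(pool)))]
        running, finished = step_fn(batch)   # one batched forward pass
        done.extend(finished)
        pool.extend(running)                 # unfinished requests rejoin the pool
    return done

def decode_step(batch):
    # Toy step: a request is (id, tokens_left); it finishes at 0.
    stepped = [(rid, n - 1) for rid, n in batch]
    return [r for r in stepped if r[1] > 0], [rid for rid, n in stepped if n <= 0]

print(serve(deque([("a", 1), ("b", 3), ("c", 2)]), 2, decode_step))
# ['a', 'c', 'b'] -- short requests finish early without blocking long ones
```
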
-
-## Project #5: Distributed Inference
-
-Introduce tensor parallelism to LLAISYS. Shard the model across multiple devices and implement distributed inference. Support NCCL if you are using Nvidia GPUs, or MPI if you are using CPUs.
-
-## Project #6: Support New Models
-
-Support a model type other than the one used for the assignments in LLAISYS.
diff --git a/REPORT.md b/REPORT.md
new file mode 100644
index 000000000..49914b741
--- /dev/null
+++ b/REPORT.md
@@ -0,0 +1,308 @@
+# LLAISYS Project Report
+
+## 1. Overview of Completed Work
+
+| Module | Description |
+|------|------|
+| Assignments #0-#3 (basics) | Tensor, operators, and model inference all complete |
+| Project #2: multi-platform CUDA support | Nvidia + Iluvatar CoreX, both platforms |
+| Project #3: AI chatbot | Server + frontend + streaming + session management + KV reuse |
+| Project #4: multi-user inference service | Scheduler + continuous batching + shared model pool + KV-aware routing |
+| Project #5: distributed inference | Communication layer + NCCL backend + tensor parallelism |
+
+## 2. Assignments
+
+**Assignment #1: Tensor** — Implemented the core tensor operations: `load`, `isContiguous`, `view`, `permute`, `slice`. All tests pass.
+
+**Assignment #2: Operators** — Implemented 9 CPU operators: `add`, `argmax`, `embedding`, `linear`, `rearrange`, `rms_norm`, `rope`, `self_attention`, `swiglu`. Float32/Float16/BFloat16 are supported; all tests pass.
+
+**Assignment #3: Large Language Model Inference** — Implemented the full inference pipeline for DeepSeek-R1-Distill-Qwen-1.5B: a C++ decoder (Transformer forward pass + KV cache), C API exports + Python ctypes bindings, and end-to-end inference output identical to PyTorch's.
+
+## 3. Projects
+
+**Project #2: Multi-Platform CUDA Support**
+
+CUDA-accelerated inference on two platforms: Nvidia GPUs and Iluvatar CoreX GPUs.
+
+Approach:
+- Nvidia: implemented the CUDA Runtime API + 9 CUDA operator kernels, compiled with nvcc
+- Iluvatar CoreX: zero-copy kernel strategy that directly reuses the CUDA kernels in the `nvidia::` namespace, compiled with `clang++ -x cuda --cuda-gpu-arch=ivcore10`
+
+Key problems and solutions:
+
+| Problem | Solution |
+|------|----------|
+| xmake invoked nvcc instead of clang++ | Control compilation manually with `on_build()` |
+| xmake injected `-lcudadevrt` | Do not register .cu files, avoiding CUDA detection |
+| Unresolved static-library symbols | Force full inclusion with `--whole-archive` |
+| Wrong `-lcudart` link order | Put it in `add_shflags()` to control ordering |
+
+Verification: runtime, operator, and end-to-end inference tests all pass on both the Nvidia and Iluvatar platforms.
+
+---
+
+**Project #3: AI Chatbot**
+
+Implemented:
+1. Random sampling: Temperature, Top-K, Top-P, and Seed (C API + Python bindings)
+2. Chat server (`python/llaisys/server.py`): HTTP service compatible with the OpenAI Chat Completion API (`/v1/chat/completions`), with streaming (SSE) and non-streaming output
+3. Frontend UI (`frontend/`): web interface with multi-turn conversation and streaming display
+4. Session management: multiple sessions, editing of history messages + fork-and-regenerate, and a prefix-matching KV cache pool reused across sessions
+
+Architecture:
+```
+Frontend (HTML/JS) → HTTP → Server (Python) → C API → C++ inference engine
+                               ↕
+                         KV Cache Pool
+```
+
+---
+
+**Project #4: Multi-User Inference Service**
+
+Implemented:
+1. Request scheduler (`python/llaisys/scheduler.py`): entry thread + scheduler + worker execution model, with multiple workers, a request queue, timeout control, and session-sticky + KV-aware routing
+2. Continuous batching: iteration-level batching, packed prefill, dynamic batch shrinking; both streaming and non-streaming requests go through the batched path
+3. Shared model pool (`--shared-model`): multiple workers share one model, reducing memory from N×model_size to 1×model_size
+4. KV-memory-aware flow control (`--kv-memory-threshold`): rejects new requests (429) when memory pressure exceeds the threshold
+
+Recommended launch parameters:
+```bash
+python -m llaisys.server --model <model_path> \
+  --workers 4 --shared-model \
+  --continuous-batching --max-batch-size 8 \
+  --kv-aware-routing --kv-memory-threshold 0.85
+```
+
+Benchmark results:
+
+| Parameters | Success rate | Throughput | Mean latency |
+|------|--------|------|----------|
+| total=20, concurrency=2, tokens=16 | 20/20 | 0.18 rps | 11.1 s |
+| total=12, concurrency=4, tokens=8 | 12/12 | 0.37 rps | 10.2 s |
+
+---
+
+**Project #5: Distributed Inference**
+
+Introduced tensor parallelism: the model is sharded across multiple GPUs for distributed inference, with NCCL for communication.
+
+Implemented:
+
+1. Communication layer (C API + C++ + NCCL backend):
+ - `include/llaisys/comm.h` → C API header (function-pointer table)
+ - `src/device/comm_api.{hpp,cpp}` → C++ dispatcher (#ifdef conditional compilation)
+ - `src/device/nvidia/nvidia_comm.cu` → NCCL backend (8 operations)
+ - Supported operations: init, destroy, get_rank, get_size, allreduce, broadcast, send, recv
+
+2. Tensor parallelism (Megatron-style): 2 AllReduces inserted per decoder layer (after the `attn_o` and `mlp_down` linear projections, before the residual add); zero overhead on a single GPU
+
+3. Weight partitioning (`python/llaisys/tensor_parallel.py`):
+
+ | Weight | Split | Notes |
+ |------|----------|------|
+ | Q/K/V/gate/up | Column split (dim 0) | Each rank gets nh/tp_size heads |
+ | attn_o/down | Row split (dim 1) | Outputs are aggregated with AllReduce |
+ | embeddings/norms | Replicated | Every rank holds a full copy |
+
+4. Multi-process launcher: rank 0 generates the NCCL unique ID → broadcast via file IPC → each rank loads its weight shard → distributed inference
+
+Verification (server with 8×A100-80GB):
+
+| Test | Result |
+|------|------|
+| Single-GPU runtime + operators | ✅ pass |
+| Communication-layer unit tests | ✅ pass |
+| 2-GPU AllReduce | ✅ pass (SUM = 3.0) |
+| 4-GPU AllReduce | ✅ pass (SUM = 10.0) |
+| 8-GPU AllReduce | ⚠️ timeout (GPU memory occupied by other processes) |
+| Tensor-parallel inference | ✅ pass (2 GPUs, identical tokens) |
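The split scheme in the table can be sketched with NumPy. `shard` is a hypothetical helper (not the `tensor_parallel.py` API), and the last lines show why row-split outputs need an AllReduce: the per-rank partial products sum to the full matmul.

```python
import numpy as np

def shard(weight, tp_rank, tp_size, dim):
    """Column split (dim 0) for Q/K/V/gate/up, row split (dim 1) for
    attn_o/down. Illustrative helper only."""
    return np.array_split(weight, tp_size, axis=dim)[tp_rank]

w = np.arange(16.0).reshape(4, 4)       # (out_features, in_features)
print(shard(w, 0, 2, 0).shape)          # (2, 4): first half of the output rows
print(shard(w, 1, 2, 1).shape)          # (4, 2): second half of the input cols

# Row split: each rank multiplies its input slice by its weight slice;
# summing the partial results (what AllReduce does) recovers x @ w.T.
x = np.arange(4.0)[None, :]
full = x @ w.T
partial = sum(x[:, r * 2:(r + 1) * 2] @ shard(w, r, 2, 1).T for r in range(2))
print(np.allclose(full, partial))       # True
```
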
+
+## 4. Code Architecture
+
+```
+llaisys/
+├── include/llaisys/         # C API headers
+│   ├── llaisys.h            # basic type definitions
+│   ├── runtime.h            # runtime API
+│   ├── comm.h               # communication API
+│   └── models/qwen2.h       # model API
+├── src/
+│   ├── device/              # device abstraction layer
+│   │   ├── cpu/             # CPU implementation
+│   │   ├── nvidia/          # CUDA implementation + NCCL communication
+│   │   └── iluvatar/        # Iluvatar CoreX implementation
+│   ├── ops/                 # operators (9, each with cpu/nvidia subdirs)
+│   ├── models/              # model implementation (Qwen2 decoder)
+│   └── core/                # runtime core (Context/Runtime/Storage)
+├── python/llaisys/          # Python frontend
+│   ├── server.py            # chat server
+│   ├── scheduler.py         # request scheduler
+│   ├── tensor_parallel.py   # weight partitioning
+│   └── libllaisys/          # ctypes bindings
+├── frontend/                # web UI
+├── scripts/                 # utility scripts (launcher, benchmarks)
+├── test/                    # tests
+└── xmake.lua                # build configuration
+```
+
+## 5. Reproduction Steps
+
+### Requirements
+
+| Dependency | Version | Purpose |
+|------|----------|------|
+| Xmake | >= 2.7 | build tool |
+| C++ compiler | GCC >= 9 / Clang >= 10 / MSVC 2019+ | backend build |
+| Python | >= 3.9 | frontend + tests |
+| PyTorch | >= 2.0 | reference checks (tests only) |
+| CUDA Toolkit | >= 11.0 | GPU inference (from Project #2 on) |
+| NCCL | >= 2.10 | distributed inference (Project #5) |
+
+### Step 0: Clone the repository + download the model
+
+```bash
+git clone https://github.com/KevinSusan/llaisys_tt.git
+cd llaisys_tt
+
+# Download the test model DeepSeek-R1-Distill-Qwen-1.5B (about 3 GB)
+pip install huggingface_hub
+huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local-dir ./model
+# Mirror for mainland China: HF_ENDPOINT=https://hf-mirror.com huggingface-cli download ...
+```
+
+> In all commands below, replace `./model` with the actual model path.
+
+---
+
+### Assignments #1-#3 verification (CPU, any machine)
+
+```bash
+# Build
+xmake build
+
+# Install the shared library (Linux)
+cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/
+# Windows: copy build\windows\x64\release\llaisys.dll python\llaisys\libllaisys\
+
+# Set the Python path
+export PYTHONPATH=$(pwd)/python:$PYTHONPATH
+
+# Assignment #1: tensor tests
+python test/test_tensor.py
+
+# Assignment #2: CPU operator tests
+python test/test_ops.py
+
+# Assignment #3: CPU end-to-end inference (output should match PyTorch exactly)
+python test/test_infer.py --model ./model --test
+```
+
+---
+
+### Project #2 verification (Nvidia GPU)
+
+Requirements: Nvidia GPU + CUDA Toolkit
+
+```bash
+# Build with Nvidia GPU support enabled
+xmake f --nv-gpu=y -c
+xmake build
+cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/
+export PYTHONPATH=$(pwd)/python:$PYTHONPATH
+
+# GPU runtime test
+python test/test_runtime.py --device nvidia
+
+# GPU operator tests (9 operators)
+python test/ops_gpu/run_all.py --device nvidia
+
+# GPU end-to-end inference (output should match PyTorch exactly)
+python test/test_infer.py --model ./model --test --device nvidia
+```
+
+### Project #2 verification (Iluvatar CoreX GPU)
+
+Requirements: Iluvatar CoreX GPU + CoreX SDK
+
+```bash
+xmake f --iluvatar-gpu=y -c
+xmake build
+cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/
+export PYTHONPATH=$(pwd)/python:/usr/local/corex/lib64/python3/dist-packages:$PYTHONPATH
+
+python test/test_runtime.py --device iluvatar
+python test/ops_gpu/run_all.py --device iluvatar
+python test/test_infer.py --model ./model --test --device iluvatar
+```
+
+---
+
+### Project #3 verification (chatbot)
+
+Requirements: same as Project #2 (GPU inference)
+
+```bash
+# Start the chat server
+python -m llaisys.server --model ./model --device nvidia
+
+# In another terminal, test the API (OpenAI-compatible)
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":32}'
+
+# Or open http://localhost:8000 in a browser to use the web UI
+```
+
+---
+
+### Project #4 verification (multi-user inference service)
+
+Requirements: same as Project #2 (GPU inference)
+
+```bash
+# Start the multi-user service (shared model + continuous batching)
+python -m llaisys.server --model ./model --device nvidia \
+  --workers 2 --shared-model --continuous-batching --max-batch-size 8
+
+# Run the scheduler tests
+python test/test_scheduler_inmemory.py
+
+# Run the concurrency benchmark
+python scripts/benchmark_chat_scheduler.py \
+  --url http://localhost:8000 --total 12 --concurrency 4 --max-new-tokens 8
+```
+
+---
+
+### Project #5 verification (distributed inference)
+
+Requirements: multiple Nvidia GPUs + NCCL
+
+```bash
+# Build (make sure --nv-gpu=y is set)
+xmake f --nv-gpu=y -c && xmake build
+cp build/linux/x86_64/release/libllaisys.so python/llaisys/libllaisys/
+export PYTHONPATH=$(pwd)/python:$PYTHONPATH
+pip install transformers safetensors
+
+# Communication-layer unit test (single GPU)
+python test/test_comm_api.py --device nvidia
+
+# Multi-GPU AllReduce integration tests
+python test/test_allreduce.py --nranks 2 --device nvidia
+python test/test_allreduce.py --nranks 4 --device nvidia
+
+# Tensor-parallel inference (2 GPUs)
+python scripts/launch_tp.py \
+  --model ./model --nranks 2 --device nvidia \
+  --prompt "Hello, world" --max-tokens 32
+```
+
+## 6. Technical Highlights
+
+1. **Cross-platform CUDA support**: with a zero-copy kernel strategy, the Iluvatar platform reuses the Nvidia CUDA kernels without modifying any kernel code
+2. **Full inference service stack**: built in-house from low-level C++ operators to the HTTP API, compatible with the OpenAI API format
+3. **Continuous batching**: iteration-level scheduling + packed prefill + dynamic batch shrinking, supporting mixed streaming and non-streaming requests
+4. **Megatron-style tensor parallelism**: an abstract communication layer supporting NCCL/IXCCL/MPI backends, with only 2 AllReduces per decoder layer
+5. **KV cache reuse**: prefix matching + cross-session donor reuse + fork editing + memory-aware flow control
diff --git a/Untitled b/Untitled
new file mode 100644
index 000000000..fbd2fd289
--- /dev/null
+++ b/Untitled
@@ -0,0 +1 @@
+xmake f -m release --nv-gpu=y --vs=2022
\ No newline at end of file
diff --git a/frontend/app.js b/frontend/app.js
new file mode 100644
index 000000000..8f766634e
--- /dev/null
+++ b/frontend/app.js
@@ -0,0 +1,366 @@
+let activeId = "";
+const conversations = [];
+
+const chat = document.getElementById("chat");
+const form = document.getElementById("chat-form");
+const promptInput = document.getElementById("prompt");
+const endpointInput = document.getElementById("endpoint");
+const maxTokensInput = document.getElementById("max-tokens");
+const samplingModeInput = document.getElementById("sampling-mode");
+const temperatureInput = document.getElementById("temperature");
+const topKInput = document.getElementById("top-k");
+const topPInput = document.getElementById("top-p");
+const seedInput = document.getElementById("seed");
+const editHint = document.getElementById("edit-hint");
+const sendButton = document.getElementById("send");
+const stopButton = document.getElementById("stop");
+const sessionList = document.getElementById("session-list");
+const newChatButton = document.getElementById("new-chat");
+let activeStreamController = null;
+let pendingEdit = null;
+
+const createLocalId = () => {
+ if (crypto && crypto.randomUUID) return crypto.randomUUID();
+ return `local-${Date.now()}-${Math.random().toString(16).slice(2)}`;
+};
+
+const dedupeAdjacentParagraphs = (text) => {
+ const parts = text
+ .split(/\n{2,}/)
+ .map((p) => p.trim())
+ .filter(Boolean);
+ const deduped = [];
+ for (const p of parts) {
+ if (deduped.length > 0 && deduped[deduped.length - 1] === p) continue;
+ deduped.push(p);
+ }
+ return deduped.join("\n\n");
+};
+
+const parseAssistantSections = (rawText) => {
+ const normalized = String(rawText || "").replaceAll("<|end_of_sentence|>", "");
+  // DeepSeek-R1 wraps its chain-of-thought in <think> … </think> tags.
+  const openTag = "<think>";
+  const closeTag = "</think>";
+ const start = normalized.indexOf(openTag);
+ const closeOnly = normalized.indexOf(closeTag);
+
+ // Tolerate outputs containing only a closing tag.
+ if (start < 0 && closeOnly >= 0) {
+ const thinking = normalized.slice(0, closeOnly).trim();
+ const answer = normalized.slice(closeOnly + closeTag.length).trim();
+ return {
+ thinking: dedupeAdjacentParagraphs(thinking.replaceAll(closeTag, "")),
+ answer: dedupeAdjacentParagraphs(answer.replaceAll(closeTag, "")),
+ };
+ }
+
+ if (start < 0) {
+ return { thinking: "", answer: dedupeAdjacentParagraphs(normalized.replaceAll(closeTag, "")) };
+ }
+ const afterOpen = start + openTag.length;
+ const end = normalized.indexOf(closeTag, afterOpen);
+ if (end < 0) {
+ return {
+ thinking: dedupeAdjacentParagraphs(normalized.slice(afterOpen).replaceAll(openTag, "")),
+ answer: "",
+ };
+ }
+ const thinking = normalized.slice(afterOpen, end).replaceAll(openTag, "").replaceAll(closeTag, "");
+ const answer = normalized.slice(end + closeTag.length).replaceAll(openTag, "").replaceAll(closeTag, "");
+ return {
+ thinking: dedupeAdjacentParagraphs(thinking),
+ answer: dedupeAdjacentParagraphs(answer),
+ };
+};
+
+const renderAssistantBubble = (bubble, rawText) => {
+ bubble.dataset.raw = rawText;
+ const { thinking, answer } = parseAssistantSections(rawText);
+ const thinkingSection = bubble.querySelector(".assistant-thinking");
+ const answerSection = bubble.querySelector(".assistant-answer");
+ const normalizedThinking = thinking.replace(/\s+/g, " ").trim();
+ const normalizedAnswer = answer.replace(/\s+/g, " ").trim();
+ const isRedundantThinking =
+ normalizedThinking &&
+ normalizedAnswer &&
+ (normalizedThinking === normalizedAnswer ||
+ normalizedAnswer.includes(normalizedThinking));
+
+ if (thinking && thinking.trim() && !isRedundantThinking) {
+ thinkingSection.style.display = "block";
+ thinkingSection.querySelector(".assistant-thinking-content").textContent = thinking.trim();
+ } else {
+ thinkingSection.style.display = "none";
+ thinkingSection.querySelector(".assistant-thinking-content").textContent = "";
+ }
+ answerSection.textContent = answer.trimStart();
+};
+
+const clearPendingEdit = () => {
+ pendingEdit = null;
+ sendButton.textContent = "发送";
+ editHint.style.display = "none";
+ editHint.textContent = "";
+};
+
+const setPendingEdit = (state) => {
+ pendingEdit = state;
+ sendButton.textContent = "分叉发送";
+ const round = Number(state.editMessageIndex) + 1;
+ editHint.textContent = `正在编辑第 ${round} 轮用户消息,发送后将创建分叉会话(Esc 可取消)`;
+ editHint.style.display = "block";
+};
+
+const appendBubble = (text, role, options = {}) => {
+ const div = document.createElement("div");
+ div.className = `bubble ${role}`;
+ if (role === "assistant") {
+    // Thinking section (toggled by renderAssistantBubble) and answer section.
+    div.innerHTML = `
+      <div class="assistant-thinking">
+        <div class="assistant-thinking-label">思考过程</div>
+        <div class="assistant-thinking-content"></div>
+      </div>
+      <div class="assistant-answer"></div>
+    `;
+ renderAssistantBubble(div, text || "");
+ } else {
+ const content = document.createElement("div");
+ content.className = "user-content";
+ content.textContent = text;
+ div.appendChild(content);
+ if (options.canEdit) {
+ const editButton = document.createElement("button");
+ editButton.type = "button";
+ editButton.className = "bubble-edit";
+ editButton.textContent = "编辑";
+ editButton.addEventListener("click", options.onEdit);
+ div.appendChild(editButton);
+ }
+ }
+ chat.appendChild(div);
+ chat.scrollTop = chat.scrollHeight;
+ return div;
+};
+
+const renderChat = (conversation) => {
+ chat.innerHTML = "";
+ for (let i = 0; i < conversation.messages.length; i += 1) {
+ const message = conversation.messages[i];
+ const canEdit = message.role === "user" && Boolean(conversation.serverId);
+ appendBubble(message.text, message.role, {
+ canEdit,
+ onEdit: () => {
+ if (!conversation.serverId) return;
+ setPendingEdit({
+ sourceLocalId: conversation.id,
+ sourceServerId: conversation.serverId,
+ editMessageIndex: i,
+ });
+ promptInput.value = message.text || "";
+ promptInput.focus();
+ },
+ });
+ }
+};
+
+const renderSessions = () => {
+ sessionList.innerHTML = "";
+ for (const convo of conversations) {
+ const item = document.createElement("div");
+ item.className = `session-item${convo.id === activeId ? " active" : ""}`;
+ item.textContent = convo.title || "新对话";
+ item.addEventListener("click", () => {
+ activeId = convo.id;
+ clearPendingEdit();
+ renderSessions();
+ renderChat(convo);
+ });
+ sessionList.appendChild(item);
+ }
+};
+
+const createConversation = () => {
+ const convo = {
+ id: createLocalId(),
+ serverId: "",
+ title: "新对话",
+ messages: [],
+ };
+ conversations.unshift(convo);
+ activeId = convo.id;
+ clearPendingEdit();
+ renderSessions();
+ renderChat(convo);
+ return convo;
+};
+
+const getActiveConversation = () => {
+ let convo = conversations.find((c) => c.id === activeId);
+ if (!convo) {
+ convo = createConversation();
+ }
+ return convo;
+};
+
+const streamChat = async (payload, bubble, convo, controller) => {
+ const res = await fetch(`${endpointInput.value}/v1/chat/completions`, {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({ ...payload, stream: true }),
+ signal: controller.signal,
+ });
+
+ if (!res.ok || !res.body) {
+ throw new Error(`请求失败:${res.status}`);
+ }
+
+ const reader = res.body.getReader();
+ const decoder = new TextDecoder("utf-8");
+ let buffer = "";
+
+ while (true) {
+ const { value, done } = await reader.read();
+ if (done) break;
+ buffer += decoder.decode(value, { stream: true });
+ const parts = buffer.split("\n\n");
+ buffer = parts.pop() || "";
+ for (const part of parts) {
+ if (!part.startsWith("data: ")) continue;
+ const payload_str = part.slice(6).trim();
+ if (payload_str === "[DONE]") return;
+ const data = JSON.parse(payload_str);
+ if (data.session_id && !convo.serverId) {
+ convo.serverId = data.session_id;
+ }
+ const delta = data.choices && data.choices[0] && data.choices[0].delta;
+ if (delta && delta.content) {
+ const raw = (bubble.dataset.raw || "") + delta.content;
+ renderAssistantBubble(bubble, raw);
+ }
+ if (data.choices && data.choices[0] && data.choices[0].finish_reason) {
+ return;
+ }
+ }
+ }
+};
+
+form.addEventListener("submit", async (event) => {
+ event.preventDefault();
+ const prompt = promptInput.value.trim();
+ if (!prompt) return;
+
+ const activeConvo = getActiveConversation();
+ let convo = activeConvo;
+ let payloadSessionId = activeConvo.serverId;
+ let payloadEditFrom = "";
+ let payloadEditIndex = -1;
+ const editState = pendingEdit;
+ const isForkEdit = Boolean(editState && editState.sourceServerId);
+
+ if (isForkEdit) {
+ const sourceConvo = conversations.find((c) => c.id === editState.sourceLocalId);
+ if (!sourceConvo || !sourceConvo.serverId) {
+ clearPendingEdit();
+ return;
+ }
+ convo = createConversation();
+ convo.title = `${sourceConvo.title || "新对话"} (分叉)`;
+ const prefix = sourceConvo.messages.slice(0, editState.editMessageIndex + 1).map((m) => ({ ...m }));
+ if (
+ prefix.length === 0 ||
+ prefix[prefix.length - 1].role !== "user"
+ ) {
+ clearPendingEdit();
+ return;
+ }
+ prefix[prefix.length - 1].text = prompt;
+ convo.messages = prefix;
+ renderSessions();
+ renderChat(convo);
+ payloadEditFrom = sourceConvo.serverId;
+ payloadEditIndex = editState.editMessageIndex;
+ payloadSessionId = "";
+ }
+
+ if (!isForkEdit) {
+ convo.messages.push({ role: "user", text: prompt });
+ appendBubble(prompt, "user", { canEdit: false });
+ }
+ promptInput.value = "";
+ const assistantBubble = appendBubble("", "assistant");
+ convo.messages.push({ role: "assistant", text: "" });
+
+ sendButton.disabled = true;
+ stopButton.disabled = false;
+ activeStreamController = new AbortController();
+ const payload = {
+ prompt,
+ max_tokens: Number(maxTokensInput.value) || 128,
+ temperature: Number(temperatureInput.value) || 0,
+ top_k: Number(topKInput.value) || 1,
+ top_p: Number(topPInput.value) || 0,
+ seed: Number(seedInput.value) || 0,
+ };
+ if (samplingModeInput.value) {
+ payload.sampling = samplingModeInput.value;
+ }
+ if (payloadSessionId) {
+ payload.session_id = payloadSessionId;
+ }
+ if (payloadEditFrom && payloadEditIndex >= 0) {
+ payload.edit_from_session_id = payloadEditFrom;
+ payload.edit_message_index = payloadEditIndex;
+ }
+
+ try {
+ await streamChat(payload, assistantBubble, convo, activeStreamController);
+ convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || "";
+ if (convo.title === "新对话") {
+ convo.title = prompt.slice(0, 12);
+ renderSessions();
+ }
+ } catch (err) {
+ if (err && err.name === "AbortError") {
+ convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || "";
+ return;
+ }
+ renderAssistantBubble(assistantBubble, `请求失败:${err.message}`);
+ convo.messages[convo.messages.length - 1].text = assistantBubble.dataset.raw || "";
+ } finally {
+ clearPendingEdit();
+ activeStreamController = null;
+ stopButton.disabled = true;
+ sendButton.disabled = false;
+ }
+});
+
+newChatButton.addEventListener("click", () => {
+ createConversation();
+});
+
+stopButton.addEventListener("click", async () => {
+ if (activeStreamController) {
+ activeStreamController.abort();
+ }
+ const convo = getActiveConversation();
+ if (!convo.serverId) {
+ return;
+ }
+ try {
+ await fetch(`${endpointInput.value}/chat/stop`, {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({ session_id: convo.serverId }),
+ });
+ } catch (_) {
+ // no-op
+ }
+});
+
+document.addEventListener("keydown", (event) => {
+ if (event.key === "Escape" && pendingEdit) {
+ clearPendingEdit();
+ }
+});
+
+createConversation();
diff --git a/frontend/index.html b/frontend/index.html
new file mode 100644
index 000000000..fd94944a8
--- /dev/null
+++ b/frontend/index.html
@@ -0,0 +1,76 @@
+
+
+
+
+
+ LLAISYS Chat
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/style.css b/frontend/style.css
new file mode 100644
index 000000000..1f8da19cb
--- /dev/null
+++ b/frontend/style.css
@@ -0,0 +1,245 @@
+* {
+ box-sizing: border-box;
+}
+
+body {
+ margin: 0;
+ font-family: "Segoe UI", "Microsoft YaHei", sans-serif;
+ background: #0f1115;
+ color: #e6e6e6;
+}
+
+.container {
+ max-width: 960px;
+ margin: 0 auto;
+ padding: 24px;
+ display: grid;
+ grid-template-columns: 220px 1fr;
+ gap: 16px;
+ min-height: 100vh;
+}
+
+.sidebar {
+ background: #141820;
+ border: 1px solid #252a35;
+ border-radius: 10px;
+ padding: 16px;
+ display: flex;
+ flex-direction: column;
+ gap: 12px;
+ height: fit-content;
+}
+
+.brand {
+ font-size: 20px;
+ font-weight: 700;
+}
+
+.new-chat {
+ width: 100%;
+}
+
+.session-label {
+ font-size: 12px;
+ color: #9aa4b2;
+ text-transform: uppercase;
+ letter-spacing: 0.08em;
+}
+
+.session-list {
+ display: flex;
+ flex-direction: column;
+ gap: 8px;
+}
+
+.session-item {
+ padding: 8px 10px;
+ border-radius: 8px;
+ background: #1b1f2a;
+ cursor: pointer;
+ border: 1px solid transparent;
+}
+
+.session-item.active {
+ border-color: #1f6feb;
+ background: #1c273d;
+}
+
+.panel {
+ display: flex;
+ flex-direction: column;
+ gap: 16px;
+}
+
+.header {
+ display: flex;
+ flex-direction: column;
+ gap: 12px;
+}
+
+.header h1 {
+ margin: 0;
+}
+
+.meta input {
+ width: 360px;
+ max-width: 100%;
+ padding: 6px 8px;
+ border-radius: 6px;
+ border: 1px solid #2b2f3a;
+ background: #171a21;
+ color: inherit;
+}
+
+.chat {
+ flex: 1;
+ background: #141820;
+ border: 1px solid #252a35;
+ border-radius: 10px;
+ padding: 16px;
+ overflow-y: auto;
+ min-height: 360px;
+ display: flex;
+ flex-direction: column;
+}
+
+.bubble {
+ padding: 10px 12px;
+ border-radius: 10px;
+ margin-bottom: 12px;
+ white-space: pre-wrap;
+ line-height: 1.5;
+}
+
+.bubble.user {
+ background: #1f6feb;
+ color: white;
+ align-self: flex-end;
+ display: flex;
+ gap: 8px;
+ align-items: flex-start;
+}
+
+.user-content {
+ white-space: pre-wrap;
+}
+
+.bubble-edit {
+ background: rgba(255, 255, 255, 0.18);
+ border: 1px solid rgba(255, 255, 255, 0.45);
+ color: #fff;
+ padding: 2px 8px;
+ border-radius: 6px;
+ font-size: 12px;
+ line-height: 1.4;
+ cursor: pointer;
+ flex: 0 0 auto;
+}
+
+.bubble.assistant {
+ background: #222836;
+ padding: 8px 10px;
+}
+
+.assistant-thinking {
+ margin-bottom: 6px;
+ padding: 6px 8px;
+ border-left: 3px solid #6b7280;
+ background: rgba(255, 255, 255, 0.03);
+}
+
+.assistant-thinking-label {
+ font-size: 12px;
+ color: #a7b0bf;
+ margin-bottom: 4px;
+}
+
+.assistant-thinking-content {
+ font-size: 13px;
+ line-height: 1.45;
+ color: #c9d1dd;
+ white-space: pre-wrap;
+ max-height: 120px;
+ overflow-y: auto;
+}
+
+.assistant-answer {
+ font-size: 18px;
+ line-height: 1.55;
+ color: #f2f5fb;
+ white-space: pre-wrap;
+}
+
+.composer {
+ display: flex;
+ flex-direction: column;
+ gap: 12px;
+}
+
+.edit-hint {
+ padding: 8px 10px;
+ border-radius: 8px;
+ border: 1px solid #375a8f;
+ background: #18263d;
+ color: #cfe3ff;
+ font-size: 13px;
+}
+
+textarea {
+ width: 100%;
+ padding: 12px;
+ border-radius: 10px;
+ border: 1px solid #2b2f3a;
+ background: #171a21;
+ color: inherit;
+ resize: vertical;
+}
+
+.actions {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ gap: 12px;
+}
+
+.action-buttons {
+ display: flex;
+ gap: 8px;
+ align-items: center;
+}
+
+.controls {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 12px;
+ align-items: center;
+}
+
+.actions input,
+.actions select {
+ width: 110px;
+ padding: 6px 8px;
+ border-radius: 6px;
+ border: 1px solid #2b2f3a;
+ background: #171a21;
+ color: inherit;
+}
+
+button {
+ padding: 8px 18px;
+ border: none;
+ border-radius: 8px;
+ background: #1f6feb;
+ color: white;
+ font-weight: 600;
+ cursor: pointer;
+}
+
+button:disabled {
+ opacity: 0.6;
+ cursor: not-allowed;
+}
+
+button.secondary {
+ background: #3a4456;
+}
diff --git a/include/llaisys.h b/include/llaisys.h
index 73ca7eead..3bd516920 100644
--- a/include/llaisys.h
+++ b/include/llaisys.h
@@ -24,6 +24,7 @@ typedef enum {
LLAISYS_DEVICE_CPU = 0,
//// TODO: Add more device types here. Numbers need to be consecutive.
LLAISYS_DEVICE_NVIDIA = 1,
+ LLAISYS_DEVICE_ILUVATAR = 2,
LLAISYS_DEVICE_TYPE_COUNT
} llaisysDeviceType_t;
diff --git a/include/llaisys/comm.h b/include/llaisys/comm.h
new file mode 100644
index 000000000..0afefa82b
--- /dev/null
+++ b/include/llaisys/comm.h
@@ -0,0 +1,50 @@
+#ifndef LLAISYS_COMM_H
+#define LLAISYS_COMM_H
+
+#include "../llaisys.h"
+
+__C {
+ // Communication Types
+ typedef void *llaisysComm_t;
+
+ typedef enum {
+ LLAISYS_COMM_NCCL = 0,
+ LLAISYS_COMM_IXCCL = 1,
+ LLAISYS_COMM_MPI = 2,
+ } llaisysCommBackend_t;
+
+ typedef enum {
+ LLAISYS_REDUCE_SUM = 0,
+ LLAISYS_REDUCE_PROD = 1,
+ LLAISYS_REDUCE_MIN = 2,
+ LLAISYS_REDUCE_MAX = 3,
+ } llaisysReduceOp_t;
+
+ #define LLAISYS_COMM_UNIQUE_ID_MAX_SIZE 128
+
+ // Communication API Functions
+ typedef int (*comm_init_api)(llaisysComm_t *, int, int, const void *);
+ typedef void (*comm_destroy_api)(llaisysComm_t);
+ typedef int (*comm_get_rank_api)(llaisysComm_t);
+ typedef int (*comm_get_size_api)(llaisysComm_t);
+ typedef void (*comm_allreduce_api)(const void *, void *, size_t, llaisysDataType_t, llaisysReduceOp_t, llaisysComm_t, llaisysStream_t);
+ typedef void (*comm_broadcast_api)(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t);
+ typedef void (*comm_send_api)(const void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t);
+ typedef void (*comm_recv_api)(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t);
+
+ struct LlaisysCommAPI {
+ comm_init_api init;
+ comm_destroy_api destroy;
+ comm_get_rank_api get_rank;
+ comm_get_size_api get_size;
+ comm_allreduce_api allreduce;
+ comm_broadcast_api broadcast;
+ comm_send_api send;
+ comm_recv_api recv;
+ };
+
+ __export const struct LlaisysCommAPI *llaisysGetCommAPI(llaisysCommBackend_t);
+ __export int llaisysCommGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size);
+}
+
+#endif // LLAISYS_COMM_H
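The header above exposes one `LlaisysCommAPI` table of function pointers per backend, selected through `llaisysGetCommAPI`. A minimal Python sketch of the same backend-table dispatch pattern (the helper names and the stub rank/size values are illustrative; only the enum values mirror `llaisysCommBackend_t`):

```python
# Toy sketch of the backend-table dispatch behind llaisysGetCommAPI.
# Each backend registers one table of callables; a single getter
# selects it by enum value, returning None for unavailable backends.
NCCL, IXCCL, MPI = 0, 1, 2  # mirrors llaisysCommBackend_t


def _mpi_rank():
    return 0


def _mpi_size():
    return 1


_BACKENDS = {
    # NCCL/IXCCL tables would be registered here when compiled in.
    MPI: {"get_rank": _mpi_rank, "get_size": _mpi_size},
}


def get_comm_api(backend):
    """Return the API table for a backend, or None if unavailable."""
    return _BACKENDS.get(backend)


api = get_comm_api(MPI)
print(api["get_rank"](), api["get_size"]())  # prints: 0 1
```

Callers check for `None` before dispatching, which is also why the C getter returns a pointer rather than a struct by value.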
diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h
index 7054626d4..f18d09a11 100644
--- a/include/llaisys/models/qwen2.h
+++ b/include/llaisys/models/qwen2.h
@@ -2,15 +2,22 @@
#define LLAISYS_MODELS_QWEN2_H
#include "../tensor.h"
+#include "../comm.h"
__C {
+ // Qwen2 model metadata
struct LlaisysQwen2Meta {
+ // data type
llaisysDataType_t dtype;
+ // model shape parameters
size_t nlayer, hs, nh, nkvh, dh, di, maxseq, voc;
+ // other parameters
float epsilon, theta;
+ // special token
int64_t end_token;
};
+ // Qwen2 model weights
struct LlaisysQwen2Weights {
llaisysTensor_t in_embed;
llaisysTensor_t out_embed;
@@ -29,14 +36,135 @@ __C {
llaisysTensor_t *mlp_down_w;
};
+ // Sampling parameters
+ struct LlaisysSamplingParams {
+ int32_t top_k; // <=1 表示贪心
+ float top_p; // (0,1],<=0 表示不启用
+ float temperature; // <=0 表示禁用温度缩放
+ uint32_t seed; // 0 表示随机
+ };
+
+ // Qwen2 model handle
struct LlaisysQwen2Model;
+ // KV block / context (experimental)
+ struct LlaisysQwen2KVBlock;
+ struct LlaisysQwen2KVContext;
+
+ struct LlaisysQwen2KVBlockMeta {
+ llaisysDataType_t dtype;
+ size_t nlayer, nh, nkvh, dh;
+ size_t max_tokens;
+ };
+ // Create a Qwen2 model instance
__export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice);
+ // Destroy a Qwen2 model instance
__export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model);
+ // Get the Qwen2 model weights
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);
+ // Run Qwen2 inference (legacy interface; prefer Prefill/Step)
__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+
+ // Run Qwen2 prefill
+ __export int64_t llaisysQwen2ModelPrefill(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+
+ // Run a single Qwen2 decode step
+ __export int64_t llaisysQwen2ModelStep(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+
+ // Run Qwen2 batched prefill over packed prompts.
+ // token_offsets has length nseq + 1, with token_offsets[0] = 0 and token_offsets[nseq] = ntoken.
+ // out_next_tokens must be a writable buffer of length nseq.
+ __export int32_t llaisysQwen2ModelPrefillPacked(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens);
+ // Run Qwen2 batched decode (packed; transitional semantics, see implementation notes)
+ __export int32_t llaisysQwen2ModelStepPacked(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens);
+
+ // Run Qwen2 prefill with sampling parameters
+ __export int64_t llaisysQwen2ModelPrefillSampling(struct LlaisysQwen2Model * model,
+ int64_t * token_ids,
+ size_t ntoken,
+ const struct LlaisysSamplingParams *params);
+
+ // Run a single Qwen2 decode step with sampling parameters
+ __export int64_t llaisysQwen2ModelStepSampling(struct LlaisysQwen2Model * model,
+ int64_t * token_ids,
+ size_t ntoken,
+ const struct LlaisysSamplingParams *params);
+
+ // Run Qwen2 inference with sampling parameters
+ __export int64_t llaisysQwen2ModelInferSampling(struct LlaisysQwen2Model * model,
+ int64_t * token_ids,
+ size_t ntoken,
+ const struct LlaisysSamplingParams *params);
+
+ // Run Qwen2 inference with sampling parameters passed by value
+ __export int64_t llaisysQwen2ModelInferSamplingEx(struct LlaisysQwen2Model * model,
+ int64_t * token_ids,
+ size_t ntoken,
+ int32_t top_k,
+ float top_p,
+ float temperature,
+ uint32_t seed);
+
+ // Reset the Qwen2 model's KV-cache
+ __export void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model * model);
+
+ // Enable or disable the KV-cache
+ __export void llaisysQwen2ModelSetKVCacheEnabled(struct LlaisysQwen2Model * model, uint8_t enabled);
+
+ // Configure tensor parallelism
+ __export int32_t llaisysQwen2ModelSetTensorParallel(struct LlaisysQwen2Model *model,
+ llaisysComm_t comm,
+ llaisysStream_t stream,
+ int tp_size);
+
+ // ===== Experimental KV block/context APIs =====
+ __export struct LlaisysQwen2KVBlock *llaisysQwen2KVBlockCreate(
+ const struct LlaisysQwen2KVBlockMeta *meta,
+ llaisysDeviceType_t device,
+ int device_id);
+ __export void llaisysQwen2KVBlockRetain(struct LlaisysQwen2KVBlock *block);
+ __export void llaisysQwen2KVBlockRelease(struct LlaisysQwen2KVBlock *block);
+ __export int32_t llaisysQwen2KVBlockSetTokenCount(struct LlaisysQwen2KVBlock *block, size_t used_tokens);
+ __export size_t llaisysQwen2KVBlockTokenCount(const struct LlaisysQwen2KVBlock *block);
+ __export llaisysTensor_t llaisysQwen2KVBlockKeyTensor(struct LlaisysQwen2KVBlock *block, size_t layer);
+ __export llaisysTensor_t llaisysQwen2KVBlockValueTensor(struct LlaisysQwen2KVBlock *block, size_t layer);
+
+ __export struct LlaisysQwen2KVContext *llaisysQwen2KVContextCreate(
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id,
+ size_t nlayer,
+ size_t nh,
+ size_t nkvh,
+ size_t dh);
+ __export void llaisysQwen2KVContextRetain(struct LlaisysQwen2KVContext *ctx);
+ __export void llaisysQwen2KVContextRelease(struct LlaisysQwen2KVContext *ctx);
+ __export int32_t llaisysQwen2KVContextAttachBlock(
+ struct LlaisysQwen2KVContext *ctx,
+ struct LlaisysQwen2KVBlock *block);
+ __export void llaisysQwen2KVContextDetachAll(struct LlaisysQwen2KVContext *ctx);
+ __export size_t llaisysQwen2KVContextBlockCount(const struct LlaisysQwen2KVContext *ctx);
+ __export size_t llaisysQwen2KVContextTokenCount(const struct LlaisysQwen2KVContext *ctx);
+
+ __export int32_t llaisysQwen2ModelSetKVContext(
+ struct LlaisysQwen2Model *model,
+ struct LlaisysQwen2KVContext *ctx);
+ __export struct LlaisysQwen2KVContext *llaisysQwen2ModelGetKVContext(
+ struct LlaisysQwen2Model *model);
+ __export int32_t llaisysQwen2ModelExportKVContext(
+ struct LlaisysQwen2Model *model,
+ struct LlaisysQwen2KVContext *ctx,
+ size_t block_tokens);
}
#endif // LLAISYS_MODELS_QWEN2_H
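The Prefill/Step split declared above replaces the single-shot Infer call: prefill consumes the whole prompt once, then each step feeds the previously sampled token back until `end_token`. A minimal Python sketch of that driver loop against a stub model (the stub class and its fixed token script are illustrative, not the real bindings):

```python
END_TOKEN = -1  # illustrative stand-in for meta.end_token


class StubModel:
    """Stand-in that replays a fixed token script, mimicking the
    Prefill/Step calling pattern (not the real Qwen2 model)."""

    def __init__(self, script):
        self._script = list(script)

    def prefill(self, token_ids):
        # real API: llaisysQwen2ModelPrefill(model, token_ids, ntoken)
        return self._script.pop(0)

    def step(self, last_token):
        # real API: llaisysQwen2ModelStep(model, &last_token, 1)
        return self._script.pop(0) if self._script else END_TOKEN


def generate(model, prompt, max_new=8):
    out = []
    tok = model.prefill(prompt)              # consume the whole prompt once
    while tok != END_TOKEN and len(out) < max_new:
        out.append(tok)
        tok = model.step(tok)                # feed the last token back
    return out


print(generate(StubModel([10, 11, 12]), [1, 2, 3]))  # prints: [10, 11, 12]
```

The legacy `Infer` entry point collapses both phases into one call; splitting them lets a server keep the KV-cache warm between steps.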
diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h
index ddb3be246..b79f074ca 100644
--- a/include/llaisys/ops.h
+++ b/include/llaisys/ops.h
@@ -12,6 +12,16 @@ __C {
__export void llaisysRmsNorm(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, float eps);
__export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta);
__export void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale);
+ // Segmented self-attention for packed batches.
+ // q_offsets/kv_offsets must both have length nseg + 1 and be non-decreasing.
+ __export void llaisysSelfAttentionSegmented(llaisysTensor_t attn_val,
+ llaisysTensor_t q,
+ llaisysTensor_t k,
+ llaisysTensor_t v,
+ float scale,
+ const int64_t *q_offsets,
+ const int64_t *kv_offsets,
+ size_t nseg);
__export void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up);
}
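The `q_offsets`/`kv_offsets` arrays above follow the usual packed-batch layout: an array of length nseg + 1 whose consecutive pairs delimit each segment. A quick Python sketch of how such an offsets array decodes into per-segment ranges (the helper name is illustrative):

```python
def segments(offsets):
    """Turn an offsets array of length nseg + 1 into (start, end) pairs.
    Offsets must be non-decreasing; empty segments are allowed."""
    assert all(a <= b for a, b in zip(offsets, offsets[1:])), "non-decreasing"
    return list(zip(offsets, offsets[1:]))


# Three packed sequences of lengths 4, 0, and 2:
print(segments([0, 4, 4, 6]))  # prints: [(0, 4), (4, 4), (4, 6)]
```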
diff --git a/include/llaisys/tokenizer.h b/include/llaisys/tokenizer.h
new file mode 100644
index 000000000..d7ff6a6c1
--- /dev/null
+++ b/include/llaisys/tokenizer.h
@@ -0,0 +1,45 @@
+#ifndef LLAISYS_TOKENIZER_H
+#define LLAISYS_TOKENIZER_H
+
+#include "../llaisys.h"
+
+__C {
+ struct LlaisysTokenizer;
+
+ // Create a SentencePiece tokenizer from model file path.
+ __export struct LlaisysTokenizer *llaisysTokenizerCreateSentencePiece(const char *model_path);
+
+ // Destroy tokenizer instance.
+ __export void llaisysTokenizerDestroy(struct LlaisysTokenizer *tokenizer);
+
+ // Encode text into token ids.
+ // If out_ids is null or max_ids is 0, returns the required length.
+ // On error returns -1.
+ __export int llaisysTokenizerEncode(struct LlaisysTokenizer *tokenizer,
+ const char *text,
+ int64_t *out_ids,
+ size_t max_ids);
+
+ // Decode token ids into text.
+ // If out_text is null or max_len is 0, returns the required length (including null terminator).
+ // On error returns -1.
+ __export int llaisysTokenizerDecode(struct LlaisysTokenizer *tokenizer,
+ const int64_t *ids,
+ size_t len,
+ char *out_text,
+ size_t max_len);
+
+ // Map a single token string to its id. Returns -1 if not found.
+ __export int64_t llaisysTokenizerTokenToId(struct LlaisysTokenizer *tokenizer, const char *token);
+
+ // Map a token id to its string.
+ // If out_token is null or max_len is 0, returns the required length (including null terminator).
+ // On error returns -1.
+ __export int llaisysTokenizerIdToToken(struct LlaisysTokenizer *tokenizer,
+ int64_t id,
+ char *out_token,
+ size_t max_len);
+
+}
+
+#endif // LLAISYS_TOKENIZER_H
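The encode/decode functions above use the common two-call length-query convention: pass a null buffer (or zero capacity) to learn the required size, then call again to fill. A toy Python sketch of that convention (`toy_encode` and its one-id-per-character tokenization are illustrative; the real call is `llaisysTokenizerEncode` through ctypes):

```python
def toy_encode(text, out_ids=None, max_ids=0):
    """Mimic llaisysTokenizerEncode's convention: call with out_ids=None
    (or max_ids=0) to get the required length, then call again to fill.
    Returns -1 on error. Tokenization here is a toy: one id per char."""
    if text is None:
        return -1
    ids = [ord(c) for c in text]
    if out_ids is None or max_ids == 0:
        return len(ids)                # length-query call
    if max_ids < len(ids):
        return -1                      # buffer too small
    out_ids[:len(ids)] = ids
    return len(ids)


n = toy_encode("abc")                  # first call: required length
buf = [0] * n
written = toy_encode("abc", buf, n)    # second call: fill the buffer
print(n, written, buf)                 # prints: 3 3 [97, 98, 99]
```

Note that decode's required length includes the null terminator, while encode's is a pure token count.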
diff --git a/python/llaisys/__init__.py b/python/llaisys/__init__.py
index de8d99f48..69ca3476e 100644
--- a/python/llaisys/__init__.py
+++ b/python/llaisys/__init__.py
@@ -5,6 +5,7 @@
from .libllaisys import llaisysStream_t as Stream
from .tensor import Tensor
from .ops import Ops
+from .tokenizer import Tokenizer
from . import models
from .models import *
@@ -16,5 +17,6 @@
"Stream",
"Tensor",
"Ops",
+ "Tokenizer",
"models",
]
diff --git a/python/llaisys/interfaces.py b/python/llaisys/interfaces.py
new file mode 100644
index 000000000..6e8461958
--- /dev/null
+++ b/python/llaisys/interfaces.py
@@ -0,0 +1,194 @@
+"""接口定义 - 解耦调度器、服务、KVCache 池之间的依赖"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence
+
+if TYPE_CHECKING:
+ from llaisys.kv_cache_pool import AcquireResult
+
+
+class IKVCachePool(ABC):
+ """KVCache 池接口
+
+ 调度器可以通过此接口查询 KV 状态,而不需要知道具体实现。
+ """
+
+ @property
+ @abstractmethod
+ def block_size(self) -> int:
+ """每个 block 的 token 数量"""
+ pass
+
+ @abstractmethod
+ def query_prefix_len(self, tokens: Sequence[int]) -> int:
+ """查询前缀命中长度(只读,不修改状态)
+
+ Args:
+ tokens: 待查询的 token 序列
+
+ Returns:
+ 命中的前缀长度(token 数量)
+ """
+ pass
+
+ @abstractmethod
+ def acquire_context(self, context_id: str, tokens: Sequence[int]) -> "AcquireResult":
+ """获取/创建上下文,返回匹配的前缀长度
+
+ Args:
+ context_id: 上下文/会话 ID
+ tokens: 当前请求的完整 token 序列
+
+ Returns:
+ AcquireResult,包含 context_id 和 prefix_len
+ """
+ pass
+
+ @abstractmethod
+ def update_context(self, context_id: str, tokens: Sequence[int]) -> None:
+ """更新上下文的 token 序列(生成结束后调用)"""
+ pass
+
+ @abstractmethod
+ def release_context(self, context_id: str) -> None:
+ """释放上下文"""
+ pass
+
+ @abstractmethod
+ def memory_pressure(self) -> float:
+ """返回 KV 内存压力值 (0.0~1.0)
+
+ 取 used_blocks/max_blocks 和 used_bytes/max_bytes 的较大值。
+ 调度器可用此值做流控决策。
+ """
+ pass
+
+ @abstractmethod
+ def snapshot_stats(self) -> Dict[str, float]:
+ """获取统计信息快照"""
+ pass
+
+
+class IInferenceService(ABC):
+ """推理服务接口
+
+ 调度器依赖此接口进行任务分发和执行。
+ """
+
+ @abstractmethod
+ def generate(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ """非流式生成
+
+ Args:
+ payload: 请求参数(prompt, session_id, max_new_tokens 等)
+
+ Returns:
+ 生成结果(response, session_id, usage 等)
+ """
+ pass
+
+ @abstractmethod
+ def stream(self, payload: Dict[str, Any]) -> Iterable[Dict[str, Any]]:
+ """流式生成
+
+ Args:
+ payload: 请求参数
+
+ Yields:
+ 流式输出的每个 chunk(delta, done, session_id 等)
+ """
+ pass
+
+ @abstractmethod
+ def request_stop(self, session_id: str) -> bool:
+ """请求停止生成
+
+ Args:
+ session_id: 要停止的会话 ID
+
+ Returns:
+ 是否成功发送停止信号
+ """
+ pass
+
+ @abstractmethod
+ def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]:
+ """获取 KV 调试快照
+
+ Args:
+ session_id: 可选,指定会话 ID
+
+ Returns:
+ 调试信息字典
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def kv_pool(self) -> IKVCachePool:
+ """暴露 KVCache 池给调度器查询"""
+ pass
+
+ def generate_packed_non_stream(
+ self, payloads: Sequence[Dict[str, Any]]
+ ) -> Optional[Sequence[Dict[str, Any]]]:
+ """批量非流式生成(可选实现)
+
+ Args:
+ payloads: 多个请求参数
+
+ Returns:
+ 批量生成结果,如果不支持则返回 None
+ """
+ return None
+
+ def tokenize_for_routing(
+ self, payload: Dict[str, Any]
+ ) -> Optional[Sequence[int]]:
+ """为 KV 感知路由进行轻量级 tokenize(可选实现)
+
+ 调度器可以调用此方法将请求转换为 token 序列,
+ 用于查询各 worker 的 KV 命中情况。
+
+ Args:
+ payload: 请求参数(prompt, messages 等)
+
+ Returns:
+ token ids 序列,如果无法 tokenize 则返回 None
+ """
+ return None
+
+ def prepare_batch(
+ self, payloads: Sequence[Dict[str, Any]]
+ ) -> Optional[Any]:
+ """准备���式批处理:prefill 所有序列,返回 BatchState(可选实现)
+
+ Args:
+ payloads: 多个请求参数
+
+ Returns:
+ BatchState 对象,如果不支持则返回 None
+ """
+ return None
+
+ def step_batch(self, state: Any) -> Optional[Sequence[Any]]:
+ """执行一步批量 decode,返回每个序列的 StepResult(可选实现)
+
+ Args:
+ state: prepare_batch 返回的 BatchState
+
+ Returns:
+ StepResult 列表,如果不支持则返回 None
+ """
+ return None
+
+ def finalize_sequence(self, state: Any, seq_index: int) -> None:
+ """完成单个序列:保存消息历史,清理状态(可选实现)
+
+ Args:
+ state: BatchState 对象
+ seq_index: 序列在 batch 中的索引
+ """
+ pass
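The optional methods on `IInferenceService` (concrete defaults that return `None`, rather than abstract methods) let a scheduler probe capabilities at call time. A small Python sketch of that dispatch pattern, with an illustrative stub service standing in for a real implementation:

```python
def dispatch_batch(service, payloads):
    """Prefer the optional batched path; fall back to per-request
    generate() when the service returns None (capability declined)."""
    results = service.generate_packed_non_stream(payloads)
    if results is None:
        results = [service.generate(p) for p in payloads]
    return results


class EchoService:
    """Illustrative stub: implements generate(), declines the batched path."""

    def generate(self, payload):
        return {"response": payload.get("prompt", "")}

    def generate_packed_non_stream(self, payloads):
        return None  # opt out, as the interface default does


out = dispatch_batch(EchoService(), [{"prompt": "hi"}, {"prompt": "yo"}])
print(out)  # prints: [{'response': 'hi'}, {'response': 'yo'}]
```

Returning `None` rather than raising keeps the fallback path on the scheduler side, so services can adopt batching incrementally.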
diff --git a/python/llaisys/kv_cache_pool.py b/python/llaisys/kv_cache_pool.py
new file mode 100644
index 000000000..7f6d952c4
--- /dev/null
+++ b/python/llaisys/kv_cache_pool.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+import threading
+import time
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from llaisys.interfaces import IKVCachePool
+
+
+@dataclass
+class KVBlock:
+ block_id: int
+ generation: int
+ parent_id: Optional[int]
+ tokens: Tuple[int, ...]
+ sealed: bool
+ ref_count: int
+ last_access: float
+ prefix_key: Optional[Tuple[int, ...]]
+
+ @property
+ def size_bytes(self) -> int:
+ # int64 token ids
+ return len(self.tokens) * 8
+
+
+@dataclass
+class ContextState:
+ block_ids: List[int]
+ tokens: Tuple[int, ...]
+ updated_at: float
+
+
+@dataclass
+class AcquireResult:
+ context_id: str
+ prefix_len: int
+
+
+class KVCachePool(IKVCachePool):
+ """In-memory token-block cache pool with reference counting.
+
+ Notes:
+ - Only sealed (full) blocks are indexed for cross-context sharing.
+ - Block IDs are monotonic and never reused.
+ """
+
+ def __init__(
+ self,
+ block_size: int = 64,
+ max_blocks: int = 4096,
+ max_bytes: int = 256 * 1024 * 1024,
+ ) -> None:
+ if block_size <= 0:
+ raise ValueError("block_size must be > 0")
+ self._block_size = int(block_size)
+ self.max_blocks = int(max_blocks)
+ self.max_bytes = int(max_bytes)
+
+ self._lock = threading.Lock()
+ self._next_block_id = 1
+ self._blocks: Dict[int, KVBlock] = {}
+ self._contexts: Dict[str, ContextState] = {}
+ # prefix(tuple(tokens up to this block)) -> (block_id, generation)
+ self._prefix_index: Dict[Tuple[int, ...], Tuple[int, int]] = {}
+ self._total_bytes = 0
+ self._acquire_count = 0
+ self._prefix_hit_count = 0
+ self._matched_tokens_total = 0
+
+ @property
+ def block_size(self) -> int:
+ return self._block_size
+
+ def acquire_context(self, context_id: str, tokens: Sequence[int]) -> AcquireResult:
+ """Bind context to current prompt tokens.
+
+ Returns matched prefix length for runtime reuse decision.
+ """
+ token_tuple = tuple(int(t) for t in tokens)
+ with self._lock:
+ _, matched_len = self._build_or_replace_context(context_id, token_tuple)
+ self._acquire_count += 1
+ self._matched_tokens_total += matched_len
+ if matched_len > 0:
+ self._prefix_hit_count += 1
+ return AcquireResult(context_id=context_id, prefix_len=matched_len)
+
+ def update_context(self, context_id: str, tokens: Sequence[int]) -> None:
+ """Update context after generation to preserve longer prefixes."""
+ token_tuple = tuple(int(t) for t in tokens)
+ with self._lock:
+ self._build_or_replace_context(context_id, token_tuple)
+
+ def release_context(self, context_id: str) -> None:
+ with self._lock:
+ old_state = self._contexts.pop(context_id, None)
+ if not old_state:
+ return
+ self._decref_chain(old_state.block_ids)
+ self._evict_if_needed()
+
+ def _build_or_replace_context(self, context_id: str, tokens: Tuple[int, ...]) -> Tuple[List[int], int]:
+ old_state = self._contexts.get(context_id)
+ old_block_ids = list(old_state.block_ids) if old_state else []
+
+ matched_block_ids, matched_len = self._find_longest_sealed_prefix(tokens)
+ new_block_ids = list(matched_block_ids)
+ created_block_ids: List[int] = []
+ incref_applied: List[int] = []
+
+ try:
+ # First, acquire refs to reused blocks.
+ for bid in matched_block_ids:
+ self._incref_block(bid)
+ incref_applied.append(bid)
+
+ parent_id = new_block_ids[-1] if new_block_ids else None
+ cursor = matched_len
+ current_prefix = tuple(tokens[:matched_len])
+ while cursor < len(tokens):
+ chunk = tuple(tokens[cursor : cursor + self.block_size])
+ sealed = len(chunk) == self.block_size
+ block_id = self._create_block(parent_id, chunk, sealed, current_prefix)
+ created_block_ids.append(block_id)
+ incref_applied.append(block_id)
+ new_block_ids.append(block_id)
+ parent_id = block_id
+ current_prefix = current_prefix + chunk
+ cursor += len(chunk)
+
+ # Commit context first, then release old refs.
+ self._contexts[context_id] = ContextState(
+ block_ids=new_block_ids,
+ tokens=tokens,
+ updated_at=time.time(),
+ )
+ self._decref_chain(old_block_ids)
+ self._evict_if_needed()
+ return new_block_ids, matched_len
+ except Exception:
+ # Rollback ref changes and newly created blocks.
+ self._safe_rollback(incref_applied, created_block_ids)
+ # Keep old state untouched.
+ if old_state is None:
+ self._contexts.pop(context_id, None)
+ else:
+ self._contexts[context_id] = old_state
+ raise
+
+ def _safe_rollback(self, incref_applied: List[int], created_block_ids: List[int]) -> None:
+ # Rollback refs idempotently.
+ seen = set()
+ for bid in reversed(incref_applied):
+ if bid in seen:
+ continue
+ seen.add(bid)
+ block = self._blocks.get(bid)
+ if not block:
+ continue
+ if block.ref_count > 0:
+ block.ref_count -= 1
+ block.last_access = time.time()
+ # Remove newly created zero-ref blocks.
+ for bid in created_block_ids:
+ block = self._blocks.get(bid)
+ if block and block.ref_count == 0:
+ self._remove_block(bid)
+
+ def _find_longest_sealed_prefix(self, tokens: Tuple[int, ...]) -> Tuple[List[int], int]:
+ matched_block_ids: List[int] = []
+ matched_len = 0
+ parent_id: Optional[int] = None
+ cursor = 0
+ prefix: Tuple[int, ...] = ()
+
+ while cursor + self.block_size <= len(tokens):
+ chunk = tuple(tokens[cursor : cursor + self.block_size])
+ prefix = prefix + chunk
+ indexed = self._prefix_index.get(prefix)
+ if not indexed:
+ break
+ bid, generation = indexed
+ block = self._blocks.get(bid)
+ if (
+ block is None
+ or block.generation != generation
+ or not block.sealed
+ or block.parent_id != parent_id
+ or block.tokens != chunk
+ ):
+ break
+ matched_block_ids.append(bid)
+ matched_len += self.block_size
+ parent_id = bid
+ cursor += self.block_size
+ return matched_block_ids, matched_len
+
+ def _create_block(
+ self,
+ parent_id: Optional[int],
+ tokens: Tuple[int, ...],
+ sealed: bool,
+ current_prefix: Tuple[int, ...],
+ ) -> int:
+ block_id = self._next_block_id
+ self._next_block_id += 1
+ generation = 1
+ prefix_key = current_prefix + tokens if sealed else None
+ block = KVBlock(
+ block_id=block_id,
+ generation=generation,
+ parent_id=parent_id,
+ tokens=tokens,
+ sealed=sealed,
+ ref_count=1,
+ last_access=time.time(),
+ prefix_key=prefix_key,
+ )
+ self._blocks[block_id] = block
+ self._total_bytes += block.size_bytes
+ if sealed and prefix_key is not None:
+ self._prefix_index[prefix_key] = (block_id, generation)
+ return block_id
+
+ def _incref_block(self, block_id: int) -> None:
+ block = self._blocks.get(block_id)
+ if not block:
+ raise RuntimeError(f"missing block {block_id}")
+ block.ref_count += 1
+ block.last_access = time.time()
+
+ def _decref_chain(self, block_ids: List[int]) -> None:
+ # Never decrement below zero, so repeated release is safe.
+ for bid in block_ids:
+ block = self._blocks.get(bid)
+ if not block:
+ continue
+ if block.ref_count > 0:
+ block.ref_count -= 1
+ block.last_access = time.time()
+
+ def _evict_if_needed(self) -> None:
+ while len(self._blocks) > self.max_blocks or self._total_bytes > self.max_bytes:
+ evict_candidates = [b for b in self._blocks.values() if b.ref_count == 0]
+ if not evict_candidates:
+ break
+ victim = min(evict_candidates, key=lambda b: b.last_access)
+ self._remove_block(victim.block_id)
+
+ def _remove_block(self, block_id: int) -> None:
+ block = self._blocks.pop(block_id, None)
+ if not block:
+ return
+ self._total_bytes = max(0, self._total_bytes - block.size_bytes)
+ if block.prefix_key is not None:
+ indexed = self._prefix_index.get(block.prefix_key)
+ if indexed and indexed[0] == block_id:
+ self._prefix_index.pop(block.prefix_key, None)
+
+ def memory_pressure(self) -> float:
+ """返回 KV 内存压力值 (0.0~1.0)"""
+ with self._lock:
+ block_ratio = len(self._blocks) / self.max_blocks if self.max_blocks > 0 else 0.0
+ byte_ratio = self._total_bytes / self.max_bytes if self.max_bytes > 0 else 0.0
+ return max(block_ratio, byte_ratio)
+
+ def snapshot_stats(self) -> Dict[str, float]:
+ """Return lightweight stats for verification and debugging."""
+ with self._lock:
+ zero_ref_blocks = sum(1 for b in self._blocks.values() if b.ref_count == 0)
+ shared_blocks = sum(1 for b in self._blocks.values() if b.ref_count > 1)
+ total_refs = sum(b.ref_count for b in self._blocks.values())
+ hit_rate = (
+ float(self._prefix_hit_count) / float(self._acquire_count)
+ if self._acquire_count > 0
+ else 0.0
+ )
+ avg_matched_tokens = (
+ float(self._matched_tokens_total) / float(self._acquire_count)
+ if self._acquire_count > 0
+ else 0.0
+ )
+ return {
+ "contexts": float(len(self._contexts)),
+ "blocks": float(len(self._blocks)),
+ "prefix_entries": float(len(self._prefix_index)),
+ "total_bytes": float(self._total_bytes),
+ "zero_ref_blocks": float(zero_ref_blocks),
+ "shared_blocks": float(shared_blocks),
+ "total_refs": float(total_refs),
+ "acquire_count": float(self._acquire_count),
+ "prefix_hit_count": float(self._prefix_hit_count),
+ "prefix_hit_rate": hit_rate,
+ "avg_matched_tokens": avg_matched_tokens,
+ }
+
+ def query_prefix_len(self, tokens: Sequence[int]) -> int:
+ """查询前缀命中长度(只读,不修改状态)
+
+ 调度器可以用此方法查询某个 token 序列在当前池中的命中情况,
+ 用于做 KV 感知的路由决策。
+
+ Args:
+ tokens: 待查询的 token 序列
+
+ Returns:
+ 命中的前缀长度(token 数量),0 表示无命中
+ """
+ token_tuple = tuple(int(t) for t in tokens)
+ with self._lock:
+ _, matched_len = self._find_longest_sealed_prefix(token_tuple)
+ return matched_len
+
+ def debug_context(self, context_id: str) -> Optional[Dict[str, object]]:
+ """Return context chain snapshot for tests and diagnostics."""
+ with self._lock:
+ state = self._contexts.get(context_id)
+ if state is None:
+ return None
+ return {
+ "context_id": context_id,
+ "tokens": list(state.tokens),
+ "block_ids": list(state.block_ids),
+ "updated_at": state.updated_at,
+ }
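The sealed-block prefix match in `_find_longest_sealed_prefix` above boils down to: chunk tokens into fixed-size blocks, index only full ("sealed") blocks by the cumulative prefix they close, and walk the index until the first miss. A standalone sketch of just that idea (block size and helper names are illustrative, not the pool's API):

```python
from typing import Dict, List, Tuple

BLOCK = 4  # illustrative block size


def build_index(tokens: List[int]) -> Dict[Tuple[int, ...], int]:
    """Index each sealed (full) block by the cumulative prefix it closes.
    A trailing partial block stays unindexed, like an unsealed block."""
    index: Dict[Tuple[int, ...], int] = {}
    prefix: Tuple[int, ...] = ()
    for i in range(0, len(tokens) - len(tokens) % BLOCK, BLOCK):
        prefix = prefix + tuple(tokens[i:i + BLOCK])
        index[prefix] = i // BLOCK  # block ordinal
    return index


def longest_sealed_prefix(index, tokens: List[int]) -> int:
    """Walk sealed blocks until the first index miss; return matched count."""
    matched = 0
    prefix: Tuple[int, ...] = ()
    while matched + BLOCK <= len(tokens):
        prefix = prefix + tuple(tokens[matched:matched + BLOCK])
        if prefix not in index:
            break
        matched += BLOCK
    return matched


idx = build_index([1, 2, 3, 4, 5, 6, 7, 8, 9])   # two sealed blocks; 9 unsealed
print(longest_sealed_prefix(idx, [1, 2, 3, 4, 5, 6, 7, 8, 42]))  # prints: 8
print(longest_sealed_prefix(idx, [1, 2, 3, 9]))                  # prints: 0
```

Keying the index on the full cumulative prefix (not just the chunk) is what makes a hit safe to reuse: it guarantees the whole chain of ancestor blocks matches, which the pool additionally verifies via `parent_id` and `generation`.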
diff --git a/python/llaisys/kv_runtime_bridge.py b/python/llaisys/kv_runtime_bridge.py
new file mode 100644
index 000000000..c6433cae2
--- /dev/null
+++ b/python/llaisys/kv_runtime_bridge.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import threading
+from typing import Any, Dict, List, Optional, Tuple
+
+
+class KVRuntimeBridge:
+ """Bridge for native C++ KV context lifecycle (bind, export, find, release, debug)."""
+
+ def __init__(self, model: Any, enabled: bool = False) -> None:
+ self._model = model
+ self._enabled = bool(enabled)
+ self._lock = threading.Lock()
+ self._native_kv_contexts: Dict[str, Any] = {}
+ self._native_kv_tokens: Dict[str, Tuple[int, ...]] = {}
+ self._last_kv_bind_debug: Dict[str, Dict[str, Any]] = {}
+
+ @property
+ def enabled(self) -> bool:
+ return self._enabled
+
+ def bind_for_request(
+ self,
+ context_id: str,
+ prompt_ids: List[int],
+ prefix_len: int,
+ ) -> None:
+ """Bind the best KV context to the model for the current request.
+
+ Search order:
+ 1. Same context_id native context
+ 2. Prefix-matching donor context
+ 3. No match -> set_kv_context(None)
+ """
+ debug: Dict[str, Any] = {
+ "enabled": self._enabled,
+ "session_id": context_id,
+ "prefix_len": int(prefix_len),
+ "bound": False,
+ "source_session_id": None,
+ "set_kv_context_rc": None,
+ }
+ if not self._enabled or prefix_len <= 0:
+ self._model.set_kv_context(None)
+ with self._lock:
+ self._last_kv_bind_debug[context_id] = debug
+ return
+ # Dict reads are atomic under the GIL, and _find_for_prefix takes
+ # self._lock itself; holding the (non-reentrant) lock around this
+ # lookup would deadlock, so the lock is not taken here.
+ ctx = self._native_kv_contexts.get(context_id)
+ source_session_id: Optional[str] = context_id if ctx else None
+ if not ctx:
+ source_session_id, ctx = self._find_for_prefix(prompt_ids, prefix_len)
+ if not ctx:
+ self._model.set_kv_context(None)
+ with self._lock:
+ self._last_kv_bind_debug[context_id] = debug
+ return
+ rc = self._model.set_kv_context(ctx)
+ debug["set_kv_context_rc"] = int(rc)
+ debug["source_session_id"] = source_session_id
+ if rc != 0:
+ self._model.set_kv_context(None)
+ else:
+ debug["bound"] = True
+ with self._lock:
+ self._last_kv_bind_debug[context_id] = debug
+
+ def export_after_request(
+ self,
+ context_id: str,
+ tokens: List[int],
+ block_size: int,
+ ) -> None:
+ """Export KV context after request completion for future reuse."""
+ if not self._enabled:
+ return
+ with self._lock:
+ ctx = self._native_kv_contexts.get(context_id)
+ if not ctx:
+ ctx = self._model.kv_context_create()
+ if not ctx:
+ return
+ with self._lock:
+ self._native_kv_contexts[context_id] = ctx
+ rc = self._model.export_kv_context(ctx, block_size)
+ if rc == 0:
+ with self._lock:
+ self._native_kv_tokens[context_id] = tuple(int(t) for t in tokens)
+
+ def release(self, context_id: str) -> None:
+ """Release native KV context for a given session."""
+ with self._lock:
+ ctx = self._native_kv_contexts.pop(context_id, None)
+ self._native_kv_tokens.pop(context_id, None)
+ self._last_kv_bind_debug.pop(context_id, None)
+ if ctx:
+ self._model.kv_context_release(ctx)
+
+ def debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]:
+ """Return KV runtime debug information."""
+ with self._lock:
+ if session_id:
+ last_bind = dict(self._last_kv_bind_debug.get(session_id, {}))
+ native_tokens = len(self._native_kv_tokens.get(session_id, ()))
+ has_native_ctx = session_id in self._native_kv_contexts
+ else:
+ last_bind = {}
+ native_tokens = 0
+ has_native_ctx = False
+ native_contexts = len(self._native_kv_contexts)
+ tracked_token_sessions = len(self._native_kv_tokens)
+ return {
+ "session_id": session_id,
+ "has_native_context": has_native_ctx,
+ "native_tokens": native_tokens,
+ "native_contexts": native_contexts,
+ "tracked_token_sessions": tracked_token_sessions,
+ "last_bind": last_bind,
+ }
+
+ def _find_for_prefix(
+ self, prompt_ids: List[int], prefix_len: int
+ ) -> Tuple[Optional[str], Any]:
+ """Find native KV context matching the given prefix."""
+ if prefix_len <= 0:
+ return None, None
+ prompt_prefix = tuple(prompt_ids[:prefix_len])
+ with self._lock:
+ best_sid: Optional[str] = None
+ best_ctx: Any = None
+ best_len = -1
+ for sid, ctx in self._native_kv_contexts.items():
+ tokens = self._native_kv_tokens.get(sid, ())
+ tlen = len(tokens)
+ if tlen < prefix_len:
+ continue
+ if tuple(tokens[:prefix_len]) != prompt_prefix:
+ continue
+ if tlen > best_len:
+ best_len = tlen
+ best_sid = sid
+ best_ctx = ctx
+ return best_sid, best_ctx
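`_find_for_prefix` above selects a donor session by longest total token history among those that share the required prefix. The same selection rule as a standalone sketch (the helper name and data are illustrative):

```python
def find_donor(sessions, prompt_ids, prefix_len):
    """Pick the stored session sharing the required prefix, preferring
    the one with the longest total token history (most reusable KV)."""
    if prefix_len <= 0:
        return None
    target = tuple(prompt_ids[:prefix_len])
    best_sid, best_len = None, -1
    for sid, tokens in sessions.items():
        if len(tokens) >= prefix_len and tuple(tokens[:prefix_len]) == target:
            if len(tokens) > best_len:
                best_sid, best_len = sid, len(tokens)
    return best_sid


sessions = {"a": (1, 2, 3, 4), "b": (1, 2, 3, 4, 5, 6)}
print(find_donor(sessions, [1, 2, 3, 9], 3))  # prints: b
```

Preferring the longest history maximizes the amount of exported KV state the bound context can serve from, even when several sessions match the prefix.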
diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py
index f536fb527..5d51a2b5d 100644
--- a/python/llaisys/libllaisys/__init__.py
+++ b/python/llaisys/libllaisys/__init__.py
@@ -12,6 +12,20 @@
from .tensor import llaisysTensor_t
from .tensor import load_tensor
from .ops import load_ops
+from .models import load_models, load_comm
+from .models import (
+ LlaisysQwen2Meta,
+ LlaisysQwen2Weights,
+ LlaisysQwen2Model,
+ LlaisysSamplingParams,
+ LlaisysQwen2KVBlockMeta,
+ LlaisysQwen2KVBlock,
+ LlaisysQwen2KVContext,
+ LlaisysCommAPI,
+ llaisysComm_t,
+ LLAISYS_COMM_UNIQUE_ID_MAX_SIZE,
+)
+from .tokenizer import load_tokenizer, LlaisysTokenizer
def load_shared_library():
@@ -38,6 +52,9 @@ def load_shared_library():
load_runtime(LIB_LLAISYS)
load_tensor(LIB_LLAISYS)
load_ops(LIB_LLAISYS)
+load_models(LIB_LLAISYS)
+load_comm(LIB_LLAISYS)
+load_tokenizer(LIB_LLAISYS)
__all__ = [
@@ -52,4 +69,15 @@ def load_shared_library():
"llaisysMemcpyKind_t",
"MemcpyKind",
"llaisysStream_t",
+ "LlaisysQwen2Meta",
+ "LlaisysQwen2Weights",
+ "LlaisysQwen2Model",
+ "LlaisysSamplingParams",
+ "LlaisysQwen2KVBlockMeta",
+ "LlaisysQwen2KVBlock",
+ "LlaisysQwen2KVContext",
+ "LlaisysCommAPI",
+ "llaisysComm_t",
+ "LLAISYS_COMM_UNIQUE_ID_MAX_SIZE",
+ "LlaisysTokenizer",
]
diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py
index c5a0b4679..84c761b73 100644
--- a/python/llaisys/libllaisys/llaisys_types.py
+++ b/python/llaisys/libllaisys/llaisys_types.py
@@ -6,7 +6,8 @@
class DeviceType(IntEnum):
CPU = 0
NVIDIA = 1
- COUNT = 2
+ ILUVATAR = 2
+ COUNT = 3
llaisysDeviceType_t = ctypes.c_int
diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py
new file mode 100644
index 000000000..419c0f3e9
--- /dev/null
+++ b/python/llaisys/libllaisys/models.py
@@ -0,0 +1,287 @@
+from ctypes import Structure, POINTER, CFUNCTYPE, c_size_t, c_int, c_float, c_int64, c_uint32, c_void_p, c_int32
+
+from .llaisys_types import llaisysDeviceType_t, llaisysDataType_t, llaisysStream_t
+from .tensor import llaisysTensor_t
+
+
+class LlaisysQwen2Meta(Structure):
+ _fields_ = [
+ ("dtype", llaisysDataType_t),
+ ("nlayer", c_size_t),
+ ("hs", c_size_t),
+ ("nh", c_size_t),
+ ("nkvh", c_size_t),
+ ("dh", c_size_t),
+ ("di", c_size_t),
+ ("maxseq", c_size_t),
+ ("voc", c_size_t),
+ ("epsilon", c_float),
+ ("theta", c_float),
+ ("end_token", c_int64),
+ ]
+
+
+class LlaisysQwen2Weights(Structure):
+ _fields_ = [
+ ("in_embed", llaisysTensor_t),
+ ("out_embed", llaisysTensor_t),
+ ("out_norm_w", llaisysTensor_t),
+ ("attn_norm_w", POINTER(llaisysTensor_t)),
+ ("attn_q_w", POINTER(llaisysTensor_t)),
+ ("attn_q_b", POINTER(llaisysTensor_t)),
+ ("attn_k_w", POINTER(llaisysTensor_t)),
+ ("attn_k_b", POINTER(llaisysTensor_t)),
+ ("attn_v_w", POINTER(llaisysTensor_t)),
+ ("attn_v_b", POINTER(llaisysTensor_t)),
+ ("attn_o_w", POINTER(llaisysTensor_t)),
+ ("mlp_norm_w", POINTER(llaisysTensor_t)),
+ ("mlp_gate_w", POINTER(llaisysTensor_t)),
+ ("mlp_up_w", POINTER(llaisysTensor_t)),
+ ("mlp_down_w", POINTER(llaisysTensor_t)),
+ ]
+
+class LlaisysSamplingParams(Structure):
+ _fields_ = [
+ ("top_k", c_int),
+ ("top_p", c_float),
+ ("temperature", c_float),
+ ("seed", c_uint32),
+ ]
+
+
+class LlaisysQwen2KVBlockMeta(Structure):
+ _fields_ = [
+ ("dtype", llaisysDataType_t),
+ ("nlayer", c_size_t),
+ ("nh", c_size_t),
+ ("nkvh", c_size_t),
+ ("dh", c_size_t),
+ ("max_tokens", c_size_t),
+ ]
+
+
+LlaisysQwen2Model = c_void_p
+LlaisysQwen2KVBlock = c_void_p
+LlaisysQwen2KVContext = c_void_p
+
+
+def load_models(lib):
+ lib.llaisysQwen2ModelCreate.argtypes = [
+ POINTER(LlaisysQwen2Meta),
+ llaisysDeviceType_t,
+ POINTER(c_int),
+ c_int,
+ ]
+ lib.llaisysQwen2ModelCreate.restype = LlaisysQwen2Model
+
+ lib.llaisysQwen2ModelDestroy.argtypes = [LlaisysQwen2Model]
+ lib.llaisysQwen2ModelDestroy.restype = None
+
+ lib.llaisysQwen2ModelWeights.argtypes = [LlaisysQwen2Model]
+ lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights)
+
+ lib.llaisysQwen2ModelInfer.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t]
+ lib.llaisysQwen2ModelInfer.restype = c_int64
+
+ lib.llaisysQwen2ModelPrefill.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t]
+ lib.llaisysQwen2ModelPrefill.restype = c_int64
+
+ lib.llaisysQwen2ModelStep.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t]
+ lib.llaisysQwen2ModelStep.restype = c_int64
+ if hasattr(lib, "llaisysQwen2ModelPrefillPacked"):
+ lib.llaisysQwen2ModelPrefillPacked.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(c_int64),
+ ]
+ lib.llaisysQwen2ModelPrefillPacked.restype = c_int32
+ if hasattr(lib, "llaisysQwen2ModelStepPacked"):
+ lib.llaisysQwen2ModelStepPacked.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(c_int64),
+ ]
+ lib.llaisysQwen2ModelStepPacked.restype = c_int32
+
+ if hasattr(lib, "llaisysQwen2ModelPrefillPackedSampling"):
+ lib.llaisysQwen2ModelPrefillPackedSampling.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(LlaisysSamplingParams),
+ POINTER(c_int64),
+ ]
+ lib.llaisysQwen2ModelPrefillPackedSampling.restype = c_int32
+
+ if hasattr(lib, "llaisysQwen2ModelStepPackedSampling"):
+ lib.llaisysQwen2ModelStepPackedSampling.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(LlaisysSamplingParams),
+ POINTER(c_int64),
+ ]
+ lib.llaisysQwen2ModelStepPackedSampling.restype = c_int32
+
+ lib.llaisysQwen2ModelPrefillSampling.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(LlaisysSamplingParams),
+ ]
+ lib.llaisysQwen2ModelPrefillSampling.restype = c_int64
+
+ lib.llaisysQwen2ModelStepSampling.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(LlaisysSamplingParams),
+ ]
+ lib.llaisysQwen2ModelStepSampling.restype = c_int64
+
+ lib.llaisysQwen2ModelInferSampling.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ c_size_t,
+ POINTER(LlaisysSamplingParams),
+ ]
+ lib.llaisysQwen2ModelInferSampling.restype = c_int64
+
+ lib.llaisysQwen2ModelInferSamplingEx.argtypes = [
+ LlaisysQwen2Model,
+ POINTER(c_int64),
+ c_size_t,
+ c_int,
+ c_float,
+ c_float,
+ c_uint32,
+ ]
+ lib.llaisysQwen2ModelInferSamplingEx.restype = c_int64
+
+ lib.llaisysQwen2ModelResetKVCache.argtypes = [LlaisysQwen2Model]
+ lib.llaisysQwen2ModelResetKVCache.restype = None
+
+ lib.llaisysQwen2ModelSetKVCacheEnabled.argtypes = [LlaisysQwen2Model, c_int]
+ lib.llaisysQwen2ModelSetKVCacheEnabled.restype = None
+
+ # Experimental KV block/context APIs
+ lib.llaisysQwen2KVBlockCreate.argtypes = [
+ POINTER(LlaisysQwen2KVBlockMeta),
+ llaisysDeviceType_t,
+ c_int,
+ ]
+ lib.llaisysQwen2KVBlockCreate.restype = LlaisysQwen2KVBlock
+ lib.llaisysQwen2KVBlockRetain.argtypes = [LlaisysQwen2KVBlock]
+ lib.llaisysQwen2KVBlockRetain.restype = None
+ lib.llaisysQwen2KVBlockRelease.argtypes = [LlaisysQwen2KVBlock]
+ lib.llaisysQwen2KVBlockRelease.restype = None
+ lib.llaisysQwen2KVBlockSetTokenCount.argtypes = [LlaisysQwen2KVBlock, c_size_t]
+ lib.llaisysQwen2KVBlockSetTokenCount.restype = c_int32
+ lib.llaisysQwen2KVBlockTokenCount.argtypes = [LlaisysQwen2KVBlock]
+ lib.llaisysQwen2KVBlockTokenCount.restype = c_size_t
+ lib.llaisysQwen2KVBlockKeyTensor.argtypes = [LlaisysQwen2KVBlock, c_size_t]
+ lib.llaisysQwen2KVBlockKeyTensor.restype = llaisysTensor_t
+ lib.llaisysQwen2KVBlockValueTensor.argtypes = [LlaisysQwen2KVBlock, c_size_t]
+ lib.llaisysQwen2KVBlockValueTensor.restype = llaisysTensor_t
+
+ lib.llaisysQwen2KVContextCreate.argtypes = [
+ llaisysDataType_t,
+ llaisysDeviceType_t,
+ c_int,
+ c_size_t,
+ c_size_t,
+ c_size_t,
+ c_size_t,
+ ]
+ lib.llaisysQwen2KVContextCreate.restype = LlaisysQwen2KVContext
+ lib.llaisysQwen2KVContextRetain.argtypes = [LlaisysQwen2KVContext]
+ lib.llaisysQwen2KVContextRetain.restype = None
+ lib.llaisysQwen2KVContextRelease.argtypes = [LlaisysQwen2KVContext]
+ lib.llaisysQwen2KVContextRelease.restype = None
+ lib.llaisysQwen2KVContextAttachBlock.argtypes = [LlaisysQwen2KVContext, LlaisysQwen2KVBlock]
+ lib.llaisysQwen2KVContextAttachBlock.restype = c_int32
+ lib.llaisysQwen2KVContextDetachAll.argtypes = [LlaisysQwen2KVContext]
+ lib.llaisysQwen2KVContextDetachAll.restype = None
+ lib.llaisysQwen2KVContextBlockCount.argtypes = [LlaisysQwen2KVContext]
+ lib.llaisysQwen2KVContextBlockCount.restype = c_size_t
+ lib.llaisysQwen2KVContextTokenCount.argtypes = [LlaisysQwen2KVContext]
+ lib.llaisysQwen2KVContextTokenCount.restype = c_size_t
+
+ lib.llaisysQwen2ModelSetKVContext.argtypes = [LlaisysQwen2Model, LlaisysQwen2KVContext]
+ lib.llaisysQwen2ModelSetKVContext.restype = c_int32
+ lib.llaisysQwen2ModelGetKVContext.argtypes = [LlaisysQwen2Model]
+ lib.llaisysQwen2ModelGetKVContext.restype = LlaisysQwen2KVContext
+ lib.llaisysQwen2ModelExportKVContext.argtypes = [LlaisysQwen2Model, LlaisysQwen2KVContext, c_size_t]
+ lib.llaisysQwen2ModelExportKVContext.restype = c_int32
+
+ if hasattr(lib, "llaisysQwen2ModelSetTensorParallel"):
+ lib.llaisysQwen2ModelSetTensorParallel.argtypes = [LlaisysQwen2Model, c_void_p, c_void_p, c_int]
+ lib.llaisysQwen2ModelSetTensorParallel.restype = c_int32
+
+
+# --- Comm API ctypes ---
+
+llaisysComm_t = c_void_p
+
+LLAISYS_COMM_UNIQUE_ID_MAX_SIZE = 128
+
+comm_init_api = CFUNCTYPE(c_int, POINTER(llaisysComm_t), c_int, c_int, c_void_p)
+comm_destroy_api = CFUNCTYPE(None, llaisysComm_t)
+comm_get_rank_api = CFUNCTYPE(c_int, llaisysComm_t)
+comm_get_size_api = CFUNCTYPE(c_int, llaisysComm_t)
+comm_allreduce_api = CFUNCTYPE(
+ None, c_void_p, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t,
+)
+comm_broadcast_api = CFUNCTYPE(
+ None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t,
+)
+comm_send_api = CFUNCTYPE(
+ None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t,
+)
+comm_recv_api = CFUNCTYPE(
+ None, c_void_p, c_size_t, c_int, c_int, llaisysComm_t, llaisysStream_t,
+)
+
+
+class LlaisysCommAPI(Structure):
+ _fields_ = [
+ ("init", comm_init_api),
+ ("destroy", comm_destroy_api),
+ ("get_rank", comm_get_rank_api),
+ ("get_size", comm_get_size_api),
+ ("allreduce", comm_allreduce_api),
+ ("broadcast", comm_broadcast_api),
+ ("send", comm_send_api),
+ ("recv", comm_recv_api),
+ ]
+
+
+def load_comm(lib):
+ if hasattr(lib, "llaisysGetCommAPI"):
+ lib.llaisysGetCommAPI.argtypes = [c_int]
+ lib.llaisysGetCommAPI.restype = POINTER(LlaisysCommAPI)
+ if hasattr(lib, "llaisysCommGenerateUniqueId"):
+ lib.llaisysCommGenerateUniqueId.argtypes = [c_int, c_void_p, POINTER(c_size_t)]
+ lib.llaisysCommGenerateUniqueId.restype = c_int
+
+
+__all__ = [
+ "LlaisysQwen2Meta",
+ "LlaisysQwen2Weights",
+ "LlaisysSamplingParams",
+ "LlaisysQwen2KVBlockMeta",
+ "LlaisysQwen2Model",
+ "LlaisysQwen2KVBlock",
+ "LlaisysQwen2KVContext",
+ "load_models",
+ "LlaisysCommAPI",
+ "llaisysComm_t",
+ "LLAISYS_COMM_UNIQUE_ID_MAX_SIZE",
+ "load_comm",
+]
diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py
index 5be095eff..a1ee3a8cb 100644
--- a/python/llaisys/libllaisys/ops.py
+++ b/python/llaisys/libllaisys/ops.py
@@ -1,5 +1,5 @@
from .tensor import llaisysTensor_t
-from ctypes import c_float
+from ctypes import c_float, c_int64, c_size_t, POINTER
def load_ops(lib):
lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
@@ -32,5 +32,18 @@ def load_ops(lib):
]
lib.llaisysSelfAttention.restype = None
+ if hasattr(lib, "llaisysSelfAttentionSegmented"):
+ lib.llaisysSelfAttentionSegmented.argtypes = [
+ llaisysTensor_t, # attn_val
+ llaisysTensor_t, # q
+ llaisysTensor_t, # k
+ llaisysTensor_t, # v
+ c_float, # scale
+ POINTER(c_int64), # q_offsets ptr
+ POINTER(c_int64), # kv_offsets ptr
+ c_size_t, # nseg
+ ]
+ lib.llaisysSelfAttentionSegmented.restype = None
+
lib.llaisysSwiGLU.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
lib.llaisysSwiGLU.restype = None
diff --git a/python/llaisys/libllaisys/tokenizer.py b/python/llaisys/libllaisys/tokenizer.py
new file mode 100644
index 000000000..27d6df499
--- /dev/null
+++ b/python/llaisys/libllaisys/tokenizer.py
@@ -0,0 +1,37 @@
+from ctypes import POINTER, c_char_p, c_int, c_int64, c_size_t, c_void_p
+
+
+LlaisysTokenizer = c_void_p
+
+
+def load_tokenizer(lib):
+ lib.llaisysTokenizerCreateSentencePiece.argtypes = [c_char_p]
+ lib.llaisysTokenizerCreateSentencePiece.restype = LlaisysTokenizer
+
+ lib.llaisysTokenizerDestroy.argtypes = [LlaisysTokenizer]
+ lib.llaisysTokenizerDestroy.restype = None
+
+ lib.llaisysTokenizerEncode.argtypes = [
+ LlaisysTokenizer,
+ c_char_p,
+ POINTER(c_int64),
+ c_size_t,
+ ]
+ lib.llaisysTokenizerEncode.restype = c_int
+
+ lib.llaisysTokenizerDecode.argtypes = [
+ LlaisysTokenizer,
+ POINTER(c_int64),
+ c_size_t,
+ c_char_p,
+ c_size_t,
+ ]
+ lib.llaisysTokenizerDecode.restype = c_int
+
+ lib.llaisysTokenizerTokenToId.argtypes = [LlaisysTokenizer, c_char_p]
+ lib.llaisysTokenizerTokenToId.restype = c_int64
+ lib.llaisysTokenizerIdToToken.argtypes = [LlaisysTokenizer, c_int64, c_char_p, c_size_t]
+ lib.llaisysTokenizerIdToToken.restype = c_int
+
+
+__all__ = ["LlaisysTokenizer", "load_tokenizer"]
diff --git a/python/llaisys/models/__init__.py b/python/llaisys/models/__init__.py
index af9918b0d..242720675 100644
--- a/python/llaisys/models/__init__.py
+++ b/python/llaisys/models/__init__.py
@@ -1 +1 @@
-from .qwen2 import Qwen2
+from .qwen2 import Qwen2, format_chat_prompt
diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py
index 0d07b0b21..b53b4c54c 100644
--- a/python/llaisys/models/qwen2.py
+++ b/python/llaisys/models/qwen2.py
@@ -1,33 +1,599 @@
-from typing import Sequence
-from ..libllaisys import LIB_LLAISYS
-from ..libllaisys import DeviceType
-
+from typing import Sequence, Iterable, Mapping, Optional
+import warnings
+from ctypes import byref, c_int, c_size_t, c_float, c_int64, c_uint32, c_void_p
+import json
from pathlib import Path
+
+import numpy as np
import safetensors
+from ..libllaisys import (
+ LIB_LLAISYS,
+ DeviceType,
+ DataType,
+ llaisysDeviceType_t,
+ llaisysDataType_t,
+ LlaisysQwen2Meta,
+ LlaisysSamplingParams,
+ LlaisysQwen2KVBlockMeta,
+)
+
+
+def format_chat_prompt(
+ messages: Iterable[Mapping[str, str]],
+ system_prompt: Optional[str] = None,
+ add_generation_prompt: bool = True,
+) -> str:
+ lines: list[str] = []
+ if system_prompt:
+ lines.append(f"System: {system_prompt}")
+ for msg in messages:
+ role = str(msg.get("role", "")).strip().lower()
+ content = str(msg.get("content", "")).strip()
+ if not role or content == "":
+ raise ValueError("Each message must have non-empty role and content")
+ if role == "system":
+ label = "System"
+ elif role == "assistant":
+ label = "Assistant"
+ else:
+ label = "User"
+ lines.append(f"{label}: {content}")
+ if add_generation_prompt:
+ if not lines or not lines[-1].startswith("Assistant:"):
+ lines.append("Assistant: ")
+ return "\n".join(lines)
+
class Qwen2:
+ @staticmethod
+ def build_prompt(
+ messages: Iterable[Mapping[str, str]],
+ system_prompt: Optional[str] = None,
+ add_generation_prompt: bool = True,
+ ) -> str:
+ return format_chat_prompt(messages, system_prompt, add_generation_prompt)
- def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
- # TODO: Implement model constructor
+ def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
model_path = Path(model_path)
+ self._device = device
+        # Resolve model metadata from config.json
+ config_path = model_path / "config.json"
+        # If config.json is not at the top level, search for it recursively
+ if not config_path.exists():
+ candidates = list(model_path.rglob("config.json"))
+ if not candidates:
+ raise FileNotFoundError("config.json not found under model_path")
+ config_path = candidates[0]
+        # Load the config file
+ with open(config_path, "r", encoding="utf-8") as f:
+ cfg = json.load(f)
+        # Parse the data type
+ torch_dtype = str(cfg.get("torch_dtype", "bfloat16")).lower()
+ if "float32" in torch_dtype or torch_dtype in {"fp32", "f32"}:
+ dtype = DataType.F32
+ elif "float16" in torch_dtype or torch_dtype in {"fp16", "f16"}:
+ dtype = DataType.F16
+ else:
+ dtype = DataType.BF16
+        # Load everything via torch, avoiding the mixed numpy->torch loading path,
+        # which has historically crashed on Windows (access violation in safetensors).
+ use_torch_loader = True
+ if dtype == DataType.BF16:
+ dtype = DataType.F16
+        # Parse model hyperparameters
+ nlayer = int(cfg.get("num_hidden_layers", 0))
+ hs = int(cfg.get("hidden_size", 0))
+ nh = int(cfg.get("num_attention_heads", 0))
+ nkvh = int(cfg.get("num_key_value_heads", nh))
+ di = int(cfg.get("intermediate_size", 0))
+ maxseq = int(cfg.get("max_position_embeddings", 0))
+ voc = int(cfg.get("vocab_size", 0))
+ epsilon = float(cfg.get("rms_norm_eps", 1e-6))
+ theta = float(cfg.get("rope_theta", 10000.0))
+ eos = cfg.get("eos_token_id", -1)
+        # Resolve the end-of-sequence token
+ if isinstance(eos, list):
+ end_token = int(eos[0]) if eos else -1
+ else:
+ end_token = int(eos)
+        # Resolve head_dim
+ dh = int(cfg.get("head_dim", hs // nh if nh else 0))
+        # Build the model meta struct
+ model_meta = LlaisysQwen2Meta(
+ llaisysDataType_t(dtype),
+ c_size_t(nlayer),
+ c_size_t(hs),
+ c_size_t(nh),
+ c_size_t(nkvh),
+ c_size_t(dh),
+ c_size_t(di),
+ c_size_t(maxseq),
+ c_size_t(voc),
+ c_float(epsilon),
+ c_float(theta),
+ c_int64(end_token),
+ )
+        # Create the model instance
+ device_ids = (c_int * 1)(0)
+ self._model = LIB_LLAISYS.llaisysQwen2ModelCreate(
+ byref(model_meta),
+ llaisysDeviceType_t(device),
+ device_ids,
+ 1,
+ )
+ if not self._model:
+ raise RuntimeError("llaisysQwen2ModelCreate failed")
+ self._model_weights = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model)
+ self._meta = model_meta
+
+        # Enable KV-cache by default
+ LIB_LLAISYS.llaisysQwen2ModelSetKVCacheEnabled(self._model, c_int(1))
+        # Helper: map a numpy dtype to the corresponding llaisys DataType
+ def _dtype_to_llaisys(dtype: np.dtype) -> DataType:
+ name = getattr(dtype, "name", str(dtype)).lower()
+ if name in {"float32", "f4"}:
+ return DataType.F32
+ if name in {"float16", "f2"}:
+ return DataType.F16
+ if name in {"bfloat16", "bf16"}:
+ return DataType.BF16
+ if name in {"int64", "i8"}:
+ return DataType.I64
+ if name in {"int32", "i4"}:
+ return DataType.I32
+ if name in {"int16", "i2"}:
+ return DataType.I16
+ if name in {"int8", "i1"}:
+ return DataType.I8
+ if name in {"uint8", "u1"}:
+ return DataType.U8
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+ def _create_tensor_from_numpy(arr: np.ndarray):
+ arr = np.ascontiguousarray(arr)
+ _shape = (c_size_t * arr.ndim)(*arr.shape)
+ _dtype = _dtype_to_llaisys(arr.dtype)
+ tensor = LIB_LLAISYS.tensorCreate(
+ _shape,
+ c_size_t(arr.ndim),
+ llaisysDataType_t(_dtype),
+ llaisysDeviceType_t(device),
+ c_int(0),
+ )
+ LIB_LLAISYS.tensorLoad(tensor, c_void_p(arr.ctypes.data))
+ return tensor
+
+        # Load model weights
for file in sorted(model_path.glob("*.safetensors")):
- data_ = safetensors.safe_open(file, framework="numpy", device="cpu")
+ import torch
+ data_ = safetensors.safe_open(file, framework="pt", device="cpu")
for name_ in data_.keys():
## TODO: load the model weights
- pass
+ arr = data_.get_tensor(name_)
+ if use_torch_loader:
+ if arr.dtype == torch.bfloat16:
+ arr = arr.to(torch.float16)
+ arr = arr.cpu().numpy()
+ tensor = _create_tensor_from_numpy(arr)
+ w = self._model_weights.contents
+
+ if name_ in {"model.embed_tokens.weight", "transformer.wte.weight"}:
+ w.in_embed = tensor
+ continue
+ if name_ in {"lm_head.weight", "model.lm_head.weight"}:
+ w.out_embed = tensor
+ continue
+ if name_ in {"model.norm.weight", "transformer.ln_f.weight"}:
+ w.out_norm_w = tensor
+ continue
+
+ if name_.startswith("model.layers."):
+ parts = name_.split(".")
+ if len(parts) < 4:
+ continue
+ layer = int(parts[2])
+ sub = ".".join(parts[3:])
+
+ if sub == "input_layernorm.weight":
+ w.attn_norm_w[layer] = tensor
+ elif sub == "self_attn.q_proj.weight":
+ w.attn_q_w[layer] = tensor
+ elif sub == "self_attn.q_proj.bias":
+ w.attn_q_b[layer] = tensor
+ elif sub == "self_attn.k_proj.weight":
+ w.attn_k_w[layer] = tensor
+ elif sub == "self_attn.k_proj.bias":
+ w.attn_k_b[layer] = tensor
+ elif sub == "self_attn.v_proj.weight":
+ w.attn_v_w[layer] = tensor
+ elif sub == "self_attn.v_proj.bias":
+ w.attn_v_b[layer] = tensor
+ elif sub == "self_attn.o_proj.weight":
+ w.attn_o_w[layer] = tensor
+ elif sub == "post_attention_layernorm.weight":
+ w.mlp_norm_w[layer] = tensor
+ elif sub == "mlp.gate_proj.weight":
+ w.mlp_gate_w[layer] = tensor
+ elif sub == "mlp.up_proj.weight":
+ w.mlp_up_w[layer] = tensor
+ elif sub == "mlp.down_proj.weight":
+ w.mlp_down_w[layer] = tensor
+ w = self._model_weights.contents
+ if not w.out_embed and w.in_embed:
+ w.out_embed = w.in_embed
+
+
def generate(
self,
+        # prompt token ids
inputs: Sequence[int],
+        # maximum number of new tokens to generate
max_new_tokens: int = None,
+        # top-k sampling; 1 means greedy
top_k: int = 1,
+        # top-p (nucleus) sampling threshold
top_p: float = 0.8,
+        # temperature; lower is more conservative
temperature: float = 0.8,
+        # random seed; 0 means nondeterministic
+ seed: int = 0,
):
+ tokens = list(inputs)
+ if max_new_tokens is None:
+ max_new_tokens = 128
+ use_sampling = temperature > 0 or top_k > 1 or top_p > 0
+
+ # prefill with full prompt
+ if use_sampling:
+ next_token = int(
+ self.prefill_sampling(
+ tokens,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ seed=seed,
+ )
+ )
+ else:
+ token_buf = (c_int64 * len(tokens))(*tokens)
+ next_token = int(
+ LIB_LLAISYS.llaisysQwen2ModelPrefill(
+ self._model,
+ token_buf,
+ c_size_t(len(tokens)),
+ )
+ )
+ if next_token < 0:
+ return tokens
+ tokens.append(next_token)
+ if self._meta.end_token >= 0 and next_token == self._meta.end_token:
+ return tokens
+
+ remaining = max_new_tokens - 1
+ if remaining <= 0:
+ return tokens
+
+ # step with newly generated tokens only
+ for _ in range(remaining):
+ if next_token < 0:
+ break
+ if self._meta.end_token >= 0 and next_token == self._meta.end_token:
+ break
+ if use_sampling:
+ next_token = int(
+ self.step_sampling(
+ [next_token],
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ seed=seed,
+ )
+ )
+ else:
+ token_buf = (c_int64 * 1)(next_token)
+ next_token = int(
+ LIB_LLAISYS.llaisysQwen2ModelStep(
+ self._model,
+ token_buf,
+ c_size_t(1),
+ )
+ )
+ if next_token < 0:
+ break
+ tokens.append(next_token)
+
+ return tokens
+
+ def prefill(self, inputs: Sequence[int]) -> int:
+ tokens = list(inputs)
+ token_buf = (c_int64 * len(tokens))(*tokens)
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelPrefill(
+ self._model,
+ token_buf,
+ c_size_t(len(tokens)),
+ )
+ )
+
+ def step(self, new_tokens: Sequence[int]) -> int:
+ tokens = list(new_tokens)
+ token_buf = (c_int64 * len(tokens))(*tokens)
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelStep(
+ self._model,
+ token_buf,
+ c_size_t(len(tokens)),
+ )
+ )
+
+ def prefill_packed(self, sequences: Sequence[Sequence[int]]) -> list[int]:
+ seqs = [list(s) for s in sequences]
+ if not seqs:
+ return []
+ if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelPrefillPacked"):
+            raise RuntimeError("llaisysQwen2ModelPrefillPacked is unavailable in the loaded llaisys library")
+ offsets = [0]
+ flat: list[int] = []
+ for s in seqs:
+ if not s:
+ raise ValueError("each packed sequence must be non-empty")
+ flat.extend(int(x) for x in s)
+ offsets.append(len(flat))
+ token_buf = (c_int64 * len(flat))(*flat)
+ off_buf = (c_int64 * len(offsets))(*offsets)
+ out_buf = (c_int64 * len(seqs))()
+ ret = int(
+ LIB_LLAISYS.llaisysQwen2ModelPrefillPacked(
+ self._model,
+ token_buf,
+ off_buf,
+ c_size_t(len(seqs)),
+ out_buf,
+ )
+ )
+ if ret != 0:
+ raise RuntimeError(f"llaisysQwen2ModelPrefillPacked failed with code {ret}")
+ return [int(out_buf[i]) for i in range(len(seqs))]
+
+ def step_packed(self, sequences: Sequence[Sequence[int]]) -> list[int]:
+ seqs = [list(s) for s in sequences]
+ if not seqs:
+ return []
+ if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelStepPacked"):
+            raise RuntimeError("llaisysQwen2ModelStepPacked is unavailable in the loaded llaisys library")
+ offsets = [0]
+ flat: list[int] = []
+ for s in seqs:
+ if not s:
+ raise ValueError("each packed sequence must be non-empty")
+ flat.extend(int(x) for x in s)
+ offsets.append(len(flat))
+ token_buf = (c_int64 * len(flat))(*flat)
+ off_buf = (c_int64 * len(offsets))(*offsets)
+ out_buf = (c_int64 * len(seqs))()
+ ret = int(
+ LIB_LLAISYS.llaisysQwen2ModelStepPacked(
+ self._model,
+ token_buf,
+ off_buf,
+ c_size_t(len(seqs)),
+ out_buf,
+ )
+ )
+ if ret != 0:
+ raise RuntimeError(f"llaisysQwen2ModelStepPacked failed with code {ret}")
+ return [int(out_buf[i]) for i in range(len(seqs))]
+
+ def prefill_packed_sampling(
+ self,
+ sequences: Sequence[Sequence[int]],
+ params_list: Sequence[LlaisysSamplingParams],
+ ) -> list[int]:
+ seqs = [list(s) for s in sequences]
+ if not seqs:
+ return []
+ if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelPrefillPackedSampling"):
+            raise RuntimeError("llaisysQwen2ModelPrefillPackedSampling is unavailable in the loaded llaisys library")
+ if len(params_list) != len(seqs):
+ raise ValueError("params_list length must match sequences length")
+ offsets = [0]
+ flat: list[int] = []
+ for s in seqs:
+ if not s:
+ raise ValueError("each packed sequence must be non-empty")
+ flat.extend(int(x) for x in s)
+ offsets.append(len(flat))
+ token_buf = (c_int64 * len(flat))(*flat)
+ off_buf = (c_int64 * len(offsets))(*offsets)
+ params_buf = (LlaisysSamplingParams * len(seqs))(*params_list)
+ out_buf = (c_int64 * len(seqs))()
+ ret = int(
+ LIB_LLAISYS.llaisysQwen2ModelPrefillPackedSampling(
+ self._model,
+ token_buf,
+ off_buf,
+ c_size_t(len(seqs)),
+ params_buf,
+ out_buf,
+ )
+ )
+ if ret != 0:
+ raise RuntimeError(f"llaisysQwen2ModelPrefillPackedSampling failed with code {ret}")
+ return [int(out_buf[i]) for i in range(len(seqs))]
+
+ def step_packed_sampling(
+ self,
+ sequences: Sequence[Sequence[int]],
+ params_list: Sequence[LlaisysSamplingParams],
+ ) -> list[int]:
+ seqs = [list(s) for s in sequences]
+ if not seqs:
+ return []
+ if not hasattr(LIB_LLAISYS, "llaisysQwen2ModelStepPackedSampling"):
+            raise RuntimeError("llaisysQwen2ModelStepPackedSampling is unavailable in the loaded llaisys library")
+ if len(params_list) != len(seqs):
+ raise ValueError("params_list length must match sequences length")
+ offsets = [0]
+ flat: list[int] = []
+ for s in seqs:
+ if not s:
+ raise ValueError("each packed sequence must be non-empty")
+ flat.extend(int(x) for x in s)
+ offsets.append(len(flat))
+ token_buf = (c_int64 * len(flat))(*flat)
+ off_buf = (c_int64 * len(offsets))(*offsets)
+ params_buf = (LlaisysSamplingParams * len(seqs))(*params_list)
+ out_buf = (c_int64 * len(seqs))()
+ ret = int(
+ LIB_LLAISYS.llaisysQwen2ModelStepPackedSampling(
+ self._model,
+ token_buf,
+ off_buf,
+ c_size_t(len(seqs)),
+ params_buf,
+ out_buf,
+ )
+ )
+ if ret != 0:
+ raise RuntimeError(f"llaisysQwen2ModelStepPackedSampling failed with code {ret}")
+ return [int(out_buf[i]) for i in range(len(seqs))]
+
+ def prefill_sampling(
+ self,
+ inputs: Sequence[int],
+ top_k: int = 1,
+ top_p: float = 0.0,
+ temperature: float = 0.0,
+ seed: int = 0,
+ ) -> int:
+ tokens = list(inputs)
+ token_buf = (c_int64 * len(tokens))(*tokens)
+ params = LlaisysSamplingParams(
+ c_int(top_k),
+ c_float(top_p),
+ c_float(temperature),
+ c_uint32(seed),
+ )
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelPrefillSampling(
+ self._model,
+ token_buf,
+ c_size_t(len(tokens)),
+ byref(params),
+ )
+ )
+
+ def step_sampling(
+ self,
+ new_tokens: Sequence[int],
+ top_k: int = 1,
+ top_p: float = 0.0,
+ temperature: float = 0.0,
+ seed: int = 0,
+ ) -> int:
+ tokens = list(new_tokens)
+ token_buf = (c_int64 * len(tokens))(*tokens)
+ params = LlaisysSamplingParams(
+ c_int(top_k),
+ c_float(top_p),
+ c_float(temperature),
+ c_uint32(seed),
+ )
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelStepSampling(
+ self._model,
+ token_buf,
+ c_size_t(len(tokens)),
+ byref(params),
+ )
+ )
+
+ def infer(self, inputs: Sequence[int]) -> int:
+ warnings.warn(
+ "Qwen2.infer is deprecated; use prefill()/step() instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.prefill(inputs)
+
+ def reset_kv_cache(self):
+ LIB_LLAISYS.llaisysQwen2ModelResetKVCache(self._model)
+
+ # ===== Experimental KV block/context wrappers =====
+ def kv_context_create(self):
+ return LIB_LLAISYS.llaisysQwen2KVContextCreate(
+ llaisysDataType_t(self._meta.dtype),
+ llaisysDeviceType_t(self._device),
+ c_int(0),
+ c_size_t(self._meta.nlayer),
+ c_size_t(self._meta.nh),
+ c_size_t(self._meta.nkvh),
+ c_size_t(self._meta.dh),
+ )
+
+ def kv_context_release(self, ctx):
+ LIB_LLAISYS.llaisysQwen2KVContextRelease(ctx)
+
+ def kv_context_attach_block(self, ctx, block):
+ return int(LIB_LLAISYS.llaisysQwen2KVContextAttachBlock(ctx, block))
+
+ def kv_context_detach_all(self, ctx):
+ LIB_LLAISYS.llaisysQwen2KVContextDetachAll(ctx)
+
+ def kv_context_block_count(self, ctx) -> int:
+ return int(LIB_LLAISYS.llaisysQwen2KVContextBlockCount(ctx))
+
+ def kv_context_token_count(self, ctx) -> int:
+ return int(LIB_LLAISYS.llaisysQwen2KVContextTokenCount(ctx))
+
+ def kv_block_create(self, max_tokens: int):
+ meta = LlaisysQwen2KVBlockMeta(
+ llaisysDataType_t(self._meta.dtype),
+ c_size_t(self._meta.nlayer),
+ c_size_t(self._meta.nh),
+ c_size_t(self._meta.nkvh),
+ c_size_t(self._meta.dh),
+ c_size_t(max_tokens),
+ )
+ return LIB_LLAISYS.llaisysQwen2KVBlockCreate(
+ byref(meta),
+ llaisysDeviceType_t(self._device),
+ c_int(0),
+ )
+
+ def kv_block_retain(self, block):
+ LIB_LLAISYS.llaisysQwen2KVBlockRetain(block)
+
+ def kv_block_release(self, block):
+ LIB_LLAISYS.llaisysQwen2KVBlockRelease(block)
+
+ def kv_block_token_count(self, block) -> int:
+ return int(LIB_LLAISYS.llaisysQwen2KVBlockTokenCount(block))
+
+ def kv_block_set_token_count(self, block, used_tokens: int) -> int:
+ return int(LIB_LLAISYS.llaisysQwen2KVBlockSetTokenCount(block, c_size_t(int(used_tokens))))
+
+ def kv_block_key_tensor(self, block, layer: int):
+ return LIB_LLAISYS.llaisysQwen2KVBlockKeyTensor(block, c_size_t(int(layer)))
+
+ def kv_block_value_tensor(self, block, layer: int):
+ return LIB_LLAISYS.llaisysQwen2KVBlockValueTensor(block, c_size_t(int(layer)))
+
+ def set_kv_context(self, ctx) -> int:
+ return int(LIB_LLAISYS.llaisysQwen2ModelSetKVContext(self._model, ctx))
- # TODO: Implement generate function
+ def get_kv_context(self):
+ return LIB_LLAISYS.llaisysQwen2ModelGetKVContext(self._model)
- return []
+ def export_kv_context(self, ctx, block_tokens: int) -> int:
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelExportKVContext(
+ self._model,
+ ctx,
+ c_size_t(int(block_tokens)),
+ )
+ )
diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py
index ed0180bc8..5921339b6 100644
--- a/python/llaisys/ops.py
+++ b/python/llaisys/ops.py
@@ -1,6 +1,6 @@
from .libllaisys import LIB_LLAISYS
from .tensor import Tensor
-from ctypes import c_float, c_int
+from ctypes import c_float, c_int, c_int64, c_size_t
class Ops:
@@ -50,6 +50,35 @@ def self_attention(attn_val: Tensor, q: Tensor, k: Tensor, v: Tensor, scale: flo
c_float(scale),
)
+ @staticmethod
+ def self_attention_segmented(
+ attn_val: Tensor,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ scale: float,
+ q_offsets: list[int],
+ kv_offsets: list[int],
+ ):
+ if len(q_offsets) != len(kv_offsets):
+ raise ValueError("q_offsets and kv_offsets must have same length")
+ if len(q_offsets) < 2:
+ raise ValueError("offsets must contain at least start/end")
+ if not hasattr(LIB_LLAISYS, "llaisysSelfAttentionSegmented"):
+            raise RuntimeError("llaisysSelfAttentionSegmented is unavailable in the loaded llaisys library")
+ q_buf = (c_int64 * len(q_offsets))(*[int(x) for x in q_offsets])
+ kv_buf = (c_int64 * len(kv_offsets))(*[int(x) for x in kv_offsets])
+ LIB_LLAISYS.llaisysSelfAttentionSegmented(
+ attn_val.lib_tensor(),
+ q.lib_tensor(),
+ k.lib_tensor(),
+ v.lib_tensor(),
+ c_float(scale),
+ q_buf,
+ kv_buf,
+ c_size_t(len(q_offsets) - 1),
+ )
+
@staticmethod
def swiglu(out: Tensor, gate: Tensor, up: Tensor):
LIB_LLAISYS.llaisysSwiGLU(out.lib_tensor(), gate.lib_tensor(), up.lib_tensor())
diff --git a/python/llaisys/scheduler.py b/python/llaisys/scheduler.py
new file mode 100644
index 000000000..399347081
--- /dev/null
+++ b/python/llaisys/scheduler.py
@@ -0,0 +1,910 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+import logging
+import queue
+import threading
+import time
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple
+from collections import OrderedDict, deque
+
+if TYPE_CHECKING:
+ from llaisys.interfaces import IInferenceService
+
+logger = logging.getLogger(__name__)
+
+_END = object()
+
+
+@dataclass
+class InferenceTask:
+ payload: Dict[str, Any]
+ stream: bool
+ output_queue: "queue.Queue[Any]"
+ deadline_at: Optional[float]
+
+
+@dataclass
+class _ActiveTask:
+ task: InferenceTask
+ iterator: Any
+ emitted_any: bool = False
+
+
+class SchedulerQueueFullError(RuntimeError):
+ pass
+
+
+class TaskTimeoutError(RuntimeError):
+ pass
+
+
+class TaskHandle:
+ def __init__(self, output_queue: "queue.Queue[Any]") -> None:
+ self._q = output_queue
+
+ def get_result(self, timeout: Optional[float] = None) -> Dict[str, Any]:
+ while True:
+ try:
+ item = self._q.get(timeout=timeout)
+ except queue.Empty as exc:
+ raise TaskTimeoutError("task result timeout") from exc
+ if item is _END:
+ raise RuntimeError("task ended without result")
+ if isinstance(item, dict):
+ return item
+ raise RuntimeError("unexpected task result type")
+
+ def iter_stream(self, timeout: Optional[float] = None) -> Iterable[Dict[str, Any]]:
+ while True:
+ try:
+ item = self._q.get(timeout=timeout)
+ except queue.Empty as exc:
+ raise TaskTimeoutError("task stream timeout") from exc
+ if item is _END:
+ break
+ if isinstance(item, dict):
+ yield item
+ else:
+ raise RuntimeError("unexpected stream item type")
+
+
+class InferenceScheduler:
+ """In-process scheduler with per-worker queues and session stickiness."""
+
+ def __init__(
+ self,
+ services: "List[IInferenceService]",
+ queue_size: int = 128,
+ request_timeout_ms: int = 120000,
+ continuous_batching: bool = False,
+ kv_aware_routing: bool = False,
+ max_sticky_sessions: int = 10000,
+ max_batch_size: int = 8,
+ kv_memory_threshold: float = 0.0,
+ ) -> None:
+ if not services:
+ raise ValueError("services must not be empty")
+ self._services: "List[IInferenceService]" = list(services)
+ self._queue_size = max(1, int(queue_size))
+ self._request_timeout_ms = max(0, int(request_timeout_ms))
+ self._continuous_batching = bool(continuous_batching)
+ self._kv_aware_routing = bool(kv_aware_routing)
+ self._max_sticky_sessions = max(100, int(max_sticky_sessions))
+ self._max_batch_size = max(1, int(max_batch_size))
+ self._kv_memory_threshold = float(kv_memory_threshold)
+ self._queues: List["queue.Queue[Optional[InferenceTask]]"] = [
+ queue.Queue(maxsize=self._queue_size) for _ in self._services
+ ]
+ self._threads: List[threading.Thread] = []
+ self._stop = threading.Event()
+ self._lock = threading.Lock()
+ self._session_worker: "OrderedDict[str, int]" = OrderedDict()
+ self._rr = 0
+ self._packed_prefill_last_error: str = ""
+ self._metrics: Dict[str, float] = {
+ "submitted": 0.0,
+ "completed": 0.0,
+ "cancelled": 0.0,
+ "failed": 0.0,
+ "timed_out": 0.0,
+ "queue_full": 0.0,
+ "stop_requests": 0.0,
+ "batch_rounds": 0.0,
+ "batch_active_sum": 0.0,
+ "batch_last_active": 0.0,
+ "prefill_rounds": 0.0,
+ "decode_rounds": 0.0,
+ "prefill_last_active": 0.0,
+ "decode_last_active": 0.0,
+ "packed_prefill_batches": 0.0,
+ "packed_prefill_tasks": 0.0,
+ "packed_prefill_attempts": 0.0,
+ "packed_prefill_candidate_tasks": 0.0,
+ "packed_prefill_none_returns": 0.0,
+ "packed_prefill_exceptions": 0.0,
+        # KV-aware routing metrics
+ "kv_aware_routing_attempts": 0.0,
+ "kv_aware_routing_hits": 0.0,
+ "kv_aware_routing_best_prefix_len_sum": 0.0,
+        # streaming batch metrics
+ "stream_batch_prefill_batches": 0.0,
+ "stream_batch_prefill_tasks": 0.0,
+ "stream_batch_decode_rounds": 0.0,
+ "stream_batch_decode_active_sum": 0.0,
+ "stream_batch_shrink_events": 0.0,
+ "stream_batch_fallback_tasks": 0.0,
+        # KV memory flow-control metrics
+ "kv_memory_rejected": 0.0,
+ }
+
+ def start(self) -> None:
+ if self._threads:
+ return
+ self._stop.clear()
+ for idx in range(len(self._services)):
+ t = threading.Thread(target=self._worker_loop, args=(idx,), daemon=True)
+ t.start()
+ self._threads.append(t)
+
+ def stop(self) -> None:
+ self._stop.set()
+ for q in self._queues:
+ try:
+ q.put_nowait(None)
+ except queue.Full:
+ pass
+ for t in self._threads:
+ t.join(timeout=1.0)
+ self._threads.clear()
+
+ def submit(self, payload: Dict[str, Any], stream: bool) -> TaskHandle:
+ payload = dict(payload) # shallow copy to avoid mutating caller's dict
+
+        # KV memory-aware flow control: reject new requests when pressure exceeds the threshold
+ if self._kv_memory_threshold > 0:
+ try:
+ pressure = max(
+ svc.kv_pool.memory_pressure() for svc in self._services
+ )
+ except Exception:
+ pressure = 0.0
+ if pressure > self._kv_memory_threshold:
+ with self._lock:
+ self._metrics["kv_memory_rejected"] += 1.0
+ raise SchedulerQueueFullError("KV memory pressure too high")
+
+        # auto-tokenize: KV-aware routing is enabled and the payload has no _prompt_tokens
+ if (
+ self._kv_aware_routing
+ and "_prompt_tokens" not in payload
+ and len(self._services) > 1
+ ):
+ try:
+            # tokenize with the first service (all services share the same tokenizer)
+ svc = self._services[0]
+ if hasattr(svc, "tokenize_for_routing"):
+ tokens = svc.tokenize_for_routing(payload)
+ if tokens:
+ payload["_prompt_tokens"] = tokens
+ except Exception:
+ logger.debug("tokenize_for_routing failed, falling back to default routing", exc_info=True)
+
+ worker_idx = self._choose_worker(payload)
+
+        # strip the routing-only internal field so it is not passed downstream
+ payload.pop("_prompt_tokens", None)
+
+ out_q: "queue.Queue[Any]" = queue.Queue()
+ deadline_at: Optional[float] = None
+ if self._request_timeout_ms > 0:
+ deadline_at = time.time() + self._request_timeout_ms / 1000.0
+ task = InferenceTask(payload=payload, stream=bool(stream), output_queue=out_q, deadline_at=deadline_at)
+ try:
+ self._queues[worker_idx].put_nowait(task)
+ except queue.Full:
+ with self._lock:
+ self._metrics["queue_full"] += 1.0
+ raise SchedulerQueueFullError("scheduler queue is full")
+ with self._lock:
+ self._metrics["submitted"] += 1.0
+ return TaskHandle(out_q)
+
+ def request_stop(self, session_id: str) -> bool:
+ sid = str(session_id or "").strip()
+ if not sid:
+ return False
+ with self._lock:
+ self._metrics["stop_requests"] += 1.0
+ idx = self._session_worker.get(sid)
+ if idx is not None:
+ return bool(self._services[idx].request_stop(sid))
+ ok = False
+ for svc in self._services:
+ ok = bool(svc.request_stop(sid)) or ok
+ return ok
+
+ def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]:
+ sid = str(session_id or "").strip()
+ if sid:
+ with self._lock:
+ idx = self._session_worker.get(sid)
+ if idx is not None:
+ snap = self._services[idx].kv_debug_snapshot(sid)
+ snap["worker"] = idx
+ return snap
+ for idx2, svc in enumerate(self._services):
+ snap = svc.kv_debug_snapshot(sid)
+ if snap.get("has_native_context") or snap.get("last_bind"):
+ snap["worker"] = idx2
+ return snap
+ return self._services[0].kv_debug_snapshot(sid)
+
+ merged = {
+ "session_id": None,
+ "workers": len(self._services),
+ "queue_size": self._queue_size,
+ "queues": [q.qsize() for q in self._queues],
+ "kv_pool": {
+ "contexts": 0.0,
+ "blocks": 0.0,
+ "prefix_entries": 0.0,
+ "total_bytes": 0.0,
+ "zero_ref_blocks": 0.0,
+ "shared_blocks": 0.0,
+ "total_refs": 0.0,
+ "acquire_count": 0.0,
+ "prefix_hit_count": 0.0,
+ "prefix_hit_rate": 0.0,
+ "avg_matched_tokens": 0.0,
+ },
+ }
+
+        # shared-KV-pool optimization: query only once when all workers share the same pool
+ first_pool = getattr(self._services[0], "kv_pool", None)
+ shared_pool = first_pool is not None and all(
+ getattr(svc, "kv_pool", None) is first_pool for svc in self._services[1:]
+ )
+
+ if shared_pool:
+ snap = self._services[0].kv_debug_snapshot(None)
+ pool = snap.get("kv_pool", {})
+ for k in merged["kv_pool"]:
+ merged["kv_pool"][k] = float(pool.get(k, 0.0))
+ else:
+ hit_rate_numer = 0.0
+ hit_rate_denom = 0.0
+ matched_numer = 0.0
+ matched_denom = 0.0
+ for svc in self._services:
+ snap = svc.kv_debug_snapshot(None)
+ pool = snap.get("kv_pool", {})
+ for k in ("contexts", "blocks", "prefix_entries", "total_bytes", "zero_ref_blocks", "shared_blocks", "total_refs", "acquire_count", "prefix_hit_count"):
+ merged["kv_pool"][k] += float(pool.get(k, 0.0))
+ hit_rate_numer += float(pool.get("prefix_hit_count", 0.0))
+ hit_rate_denom += float(pool.get("acquire_count", 0.0))
+ matched_numer += float(pool.get("avg_matched_tokens", 0.0)) * float(pool.get("acquire_count", 0.0))
+ matched_denom += float(pool.get("acquire_count", 0.0))
+ merged["kv_pool"]["prefix_hit_rate"] = hit_rate_numer / hit_rate_denom if hit_rate_denom > 0 else 0.0
+ merged["kv_pool"]["avg_matched_tokens"] = matched_numer / matched_denom if matched_denom > 0 else 0.0
+ return merged
+
+ def debug_snapshot(self) -> Dict[str, Any]:
+ with self._lock:
+ metrics = dict(self._metrics)
+ packed_prefill_last_error = self._packed_prefill_last_error
+ sticky_sessions = len(self._session_worker)
+ avg_batch_active = (
+ metrics.get("batch_active_sum", 0.0) / metrics.get("batch_rounds", 1.0)
+ if metrics.get("batch_rounds", 0.0) > 0
+ else 0.0
+ )
+ kv_routing_attempts = metrics.get("kv_aware_routing_attempts", 0.0)
+ kv_routing_hit_rate = (
+ metrics.get("kv_aware_routing_hits", 0.0) / kv_routing_attempts
+ if kv_routing_attempts > 0
+ else 0.0
+ )
+ kv_routing_avg_prefix_len = (
+ metrics.get("kv_aware_routing_best_prefix_len_sum", 0.0) / metrics.get("kv_aware_routing_hits", 1.0)
+ if metrics.get("kv_aware_routing_hits", 0.0) > 0
+ else 0.0
+ )
+ # KV memory pressure snapshot
+ try:
+ kv_memory_pressure = max(
+ svc.kv_pool.memory_pressure() for svc in self._services
+ )
+ except Exception:
+ kv_memory_pressure = 0.0
+ return {
+ "workers": len(self._services),
+ "queue_size": self._queue_size,
+ "queues": [q.qsize() for q in self._queues],
+ "request_timeout_ms": self._request_timeout_ms,
+ "continuous_batching": self._continuous_batching,
+ "kv_aware_routing": self._kv_aware_routing,
+ "kv_routing_hit_rate": kv_routing_hit_rate,
+ "kv_routing_avg_prefix_len": kv_routing_avg_prefix_len,
+ "kv_memory_threshold": self._kv_memory_threshold,
+ "kv_memory_pressure": kv_memory_pressure,
+ "max_batch_size": self._max_batch_size,
+ "avg_batch_active": avg_batch_active,
+ "sticky_sessions": sticky_sessions,
+ "packed_prefill_last_error": packed_prefill_last_error,
+ "metrics": metrics,
+ }
+
+ def request_timeout_seconds(self) -> Optional[float]:
+ if self._request_timeout_ms <= 0:
+ return None
+ return self._request_timeout_ms / 1000.0
+
+ def _touch_session(self, sid: str, worker_idx: int) -> None:
+ """Record/update session->worker mapping with LRU eviction. Caller must hold self._lock."""
+ if sid in self._session_worker:
+ self._session_worker.move_to_end(sid)
+ self._session_worker[sid] = worker_idx
+ while len(self._session_worker) > self._max_sticky_sessions:
+ self._session_worker.popitem(last=False)
+
+ def _choose_worker(self, payload: Dict[str, Any]) -> int:
+ sid = str(payload.get("session_id") or payload.get("edit_from_session_id") or "").strip()
+
+        # 1. Session stickiness: route an already-bound session back to its original worker
+ with self._lock:
+ if sid and sid in self._session_worker:
+ self._session_worker.move_to_end(sid)
+ return self._session_worker[sid]
+
+        # 2. KV-aware routing: query each worker for KV prefix hits.
+        # This is best-effort: KV state may change between the query and
+        # enqueue; the worst case is routing to a suboptimal worker, which
+        # does not affect correctness.
+ prompt_tokens: Optional[Sequence[int]] = payload.get("_prompt_tokens")
+ if self._kv_aware_routing and prompt_tokens and len(self._services) > 1:
+ best_worker = -1
+ best_prefix_len = -1
+
+            # shared-KV-pool optimization: query only once when all workers share the same pool
+ first_pool = getattr(self._services[0], "kv_pool", None)
+ shared_pool = first_pool is not None and all(
+ getattr(svc, "kv_pool", None) is first_pool for svc in self._services[1:]
+ )
+
+ if shared_pool:
+ try:
+ prefix_len = first_pool.query_prefix_len(prompt_tokens)
+ if prefix_len > 0:
+ best_prefix_len = prefix_len
+                    # with a shared pool, pick the least-loaded worker
+ best_worker = min(range(len(self._queues)), key=lambda i: self._queues[i].qsize())
+ except Exception:
+ pass
+ else:
+ for idx, svc in enumerate(self._services):
+ try:
+ kv_pool = getattr(svc, "kv_pool", None)
+ if kv_pool is None:
+ continue
+ prefix_len = kv_pool.query_prefix_len(prompt_tokens)
+ if prefix_len > best_prefix_len:
+ best_prefix_len = prefix_len
+ best_worker = idx
+ except Exception:
+                    # query failed; skip this worker
+ continue
+
+ with self._lock:
+ self._metrics["kv_aware_routing_attempts"] += 1.0
+ if best_prefix_len > 0:
+ self._metrics["kv_aware_routing_hits"] += 1.0
+ self._metrics["kv_aware_routing_best_prefix_len_sum"] += float(best_prefix_len)
+
+ if best_worker >= 0 and best_prefix_len > 0:
+ if sid:
+ with self._lock:
+ self._touch_session(sid, best_worker)
+ return best_worker
+
+        # 3. Fallback: hash the session id, else round-robin
+ with self._lock:
+ if sid:
+ idx = hash(sid) % len(self._services)
+ self._touch_session(sid, idx)
+ return idx
+ idx = self._rr % len(self._services)
+ self._rr = (self._rr + 1) % len(self._services)
+ return idx
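
The routing fallback above (sticky sessions with LRU-bounded bookkeeping, then hash or round-robin) can be sketched in isolation. The class below is a hypothetical standalone version for illustration only; it omits the lock the real scheduler holds and the KV-aware step:

```python
from collections import OrderedDict


class StickyRouter:
    """Minimal sketch of the session -> worker routing above (assumed
    standalone; the real scheduler guards this state with a lock)."""

    def __init__(self, num_workers: int, max_sessions: int = 3) -> None:
        self.num_workers = num_workers
        self.max_sessions = max_sessions
        self.session_worker: "OrderedDict[str, int]" = OrderedDict()
        self.rr = 0

    def choose(self, session_id: str = "") -> int:
        # 1. Stickiness: a known session returns to its worker (LRU refresh).
        if session_id in self.session_worker:
            self.session_worker.move_to_end(session_id)
            return self.session_worker[session_id]
        # 2. Fallback: hash for sessions, round-robin for anonymous requests.
        if session_id:
            idx = hash(session_id) % self.num_workers
            self.session_worker[session_id] = idx
            # LRU eviction keeps the mapping bounded.
            while len(self.session_worker) > self.max_sessions:
                self.session_worker.popitem(last=False)
            return idx
        idx = self.rr % self.num_workers
        self.rr = (self.rr + 1) % self.num_workers
        return idx
```

Hashing gives stateless-restart determinism within a process, while the `OrderedDict` bound prevents the mapping from growing without limit under many short-lived sessions.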
+
+ def _bind_session(self, session_id: Optional[str], worker_idx: int) -> None:
+ sid = str(session_id or "").strip()
+ if not sid:
+ return
+ with self._lock:
+ self._touch_session(sid, worker_idx)
+
+ def _worker_loop(self, idx: int) -> None:
+ if self._continuous_batching:
+ self._worker_loop_continuous(idx)
+ return
+
+ svc = self._services[idx]
+ q = self._queues[idx]
+ while not self._stop.is_set():
+ task = q.get()
+ if task is None:
+ q.task_done()
+ continue
+ try:
+ if task.deadline_at is not None and time.time() > task.deadline_at:
+ with self._lock:
+ self._metrics["timed_out"] += 1.0
+ if task.stream:
+ task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True})
+ else:
+ task.output_queue.put({"error": "request timeout", "code": "timeout"})
+ task.output_queue.put(_END)
+ continue
+ if task.stream:
+ try:
+ for item in svc.stream(task.payload):
+ if isinstance(item, dict):
+ self._bind_session(item.get("session_id"), idx)
+ if item.get("done") and item.get("stopped"):
+ with self._lock:
+ self._metrics["cancelled"] += 1.0
+ task.output_queue.put(item)
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ except Exception as exc:
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ task.output_queue.put({"error": str(exc), "done": True})
+ finally:
+ task.output_queue.put(_END)
+ else:
+ try:
+ result = svc.generate(task.payload)
+ if isinstance(result, dict):
+ self._bind_session(result.get("session_id"), idx)
+ task.output_queue.put(result)
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ if isinstance(result, dict) and result.get("stopped"):
+ self._metrics["cancelled"] += 1.0
+ except Exception as exc:
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ task.output_queue.put({"error": str(exc)})
+ finally:
+ task.output_queue.put(_END)
+ finally:
+ q.task_done()
+
+ def _worker_loop_continuous(self, idx: int) -> None:
+ svc = self._services[idx]
+ q = self._queues[idx]
+ # Raw tasks waiting for prefill (not yet started)
+ prefill_pending: "deque[InferenceTask]" = deque()
+ # Fallback: legacy iterator-based active tasks
+ fallback_prefill: "deque[_ActiveTask]" = deque()
+ fallback_decode: List[_ActiveTask] = []
+ # Batch-driven decode state (from prepare_batch)
+ batch_state: Optional[Any] = None
+ batch_tasks: List[InferenceTask] = [] # parallel to batch_state.sequences
+
+ def _append_from_queue(block: bool) -> None:
+ while True:
+ try:
+ task = q.get(block=block, timeout=0.1 if block else 0.0)
+ except queue.Empty:
+ return
+ if task is None:
+ q.task_done()
+ return
+ prefill_pending.append(task)
+ q.task_done()
+ block = False
+
+ def _emit_chunk(task: InferenceTask, chunk: dict) -> None:
+ """Send a stream chunk or accumulate for non-stream."""
+ task.output_queue.put(chunk)
+
+ def _emit_final_stream(task: InferenceTask, context_id: str,
+ finish_reason: str, prompt_len: int,
+ gen_len: int, stopped: bool) -> None:
+ usage = {
+ "prompt_tokens": prompt_len,
+ "completion_tokens": gen_len,
+ "total_tokens": prompt_len + gen_len,
+ }
+ from llaisys.server import _wrap_chunk
+ chunk = _wrap_chunk(context_id, None, finish_reason, usage=usage, stopped=stopped)
+ task.output_queue.put(chunk)
+ task.output_queue.put(_END)
+
+ def _emit_final_non_stream(task: InferenceTask, context_id: str,
+ content: str, finish_reason: str,
+ prompt_len: int, gen_len: int,
+ stopped: bool) -> None:
+ usage = {
+ "prompt_tokens": prompt_len,
+ "completion_tokens": gen_len,
+ "total_tokens": prompt_len + gen_len,
+ }
+ from llaisys.server import _wrap_completion
+ result = _wrap_completion(context_id, content, finish_reason, usage, stopped=stopped)
+ task.output_queue.put(result)
+ task.output_queue.put(_END)
+
+ # --- Fallback helpers (legacy iterator path) ---
+ def _step_once(state: _ActiveTask) -> str:
+ task = state.task
+ it = state.iterator
+ if task.deadline_at is not None and time.time() > task.deadline_at:
+ with self._lock:
+ self._metrics["timed_out"] += 1.0
+ if task.stream:
+ task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True})
+ else:
+ task.output_queue.put({"error": "request timeout", "code": "timeout"})
+ task.output_queue.put(_END)
+ return "done"
+ try:
+ item = next(it)
+ if isinstance(item, dict):
+ self._bind_session(item.get("session_id"), idx)
+
+ def _is_final(d: dict) -> bool:
+ if d.get("done"):
+ return True
+ choices = d.get("choices")
+ if choices and isinstance(choices, list) and len(choices) > 0:
+ if choices[0].get("finish_reason") is not None:
+ return True
+ return False
+
+ def _is_stopped(d: dict) -> bool:
+ if d.get("stopped"):
+ return True
+ choices = d.get("choices")
+ if choices and isinstance(choices, list) and len(choices) > 0:
+ if choices[0].get("finish_reason") == "stop":
+ return True
+ return False
+
+ if task.stream:
+ if not isinstance(item, dict):
+ raise RuntimeError("stream item must be dict")
+ task.output_queue.put(item)
+ state.emitted_any = True
+ if _is_final(item):
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ if _is_stopped(item):
+ self._metrics["cancelled"] += 1.0
+ task.output_queue.put(_END)
+ return "done"
+ return "keep"
+ if isinstance(item, dict) and _is_final(item):
+ if item.get("error"):
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ task.output_queue.put({"error": str(item.get("error"))})
+ else:
+ result = dict(item)
+ choices = result.get("choices")
+ if choices and isinstance(choices, list) and len(choices) > 0:
+ c = dict(choices[0])
+ acc = getattr(state, "accumulated_content", "")
+ delta = c.pop("delta", {})
+ final_content = acc + delta.get("content", "")
+ c["message"] = {"role": "assistant", "content": final_content}
+ result["choices"] = [c]
+ if result.get("object") == "chat.completion.chunk":
+ result["object"] = "chat.completion"
+ task.output_queue.put(result)
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ if _is_stopped(item):
+ self._metrics["cancelled"] += 1.0
+ task.output_queue.put(_END)
+ return "done"
+ choices = item.get("choices")
+ if choices and isinstance(choices, list) and len(choices) > 0:
+ delta = choices[0].get("delta", {})
+ content = delta.get("content", "")
+ if content:
+ state.accumulated_content = getattr(state, "accumulated_content", "") + content
+ return "keep"
+ except StopIteration:
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ if task.stream:
+ task.output_queue.put({"error": "stream ended unexpectedly", "done": True})
+ else:
+ task.output_queue.put({"error": "task ended unexpectedly"})
+ task.output_queue.put(_END)
+ return "done"
+ except Exception as exc:
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ if task.stream:
+ task.output_queue.put({"error": str(exc), "done": True})
+ else:
+ task.output_queue.put({"error": str(exc)})
+ task.output_queue.put(_END)
+ return "done"
+
+ while not self._stop.is_set():
+ has_work = (
+ prefill_pending or fallback_prefill or fallback_decode
+ or (batch_state is not None)
+ )
+ if not has_work:
+ _append_from_queue(block=True)
+ if not prefill_pending:
+ continue
+ else:
+ _append_from_queue(block=False)
+
+ with self._lock:
+ total_active = (
+ len(prefill_pending) + len(fallback_prefill) + len(fallback_decode)
+ + (len([s for s in batch_state.sequences if not s.finished]) if batch_state else 0)
+ )
+ self._metrics["batch_rounds"] += 1.0
+ self._metrics["batch_active_sum"] += float(total_active)
+ self._metrics["batch_last_active"] = float(total_active)
+ self._metrics["prefill_last_active"] = float(len(prefill_pending) + len(fallback_prefill))
+ decode_count = len(fallback_decode) + (
+ len([s for s in batch_state.sequences if not s.finished]) if batch_state else 0
+ )
+ self._metrics["decode_last_active"] = float(decode_count)
+
+ # ============================================================
+ # P stage: try batch prefill for pending tasks
+ # ============================================================
+ if prefill_pending and batch_state is None:
+ with self._lock:
+ self._metrics["prefill_rounds"] += 1.0
+
+ # Collect candidates up to max_batch_size
+ batch_candidates: List[InferenceTask] = []
+ remaining: "deque[InferenceTask]" = deque()
+ decode_active_count = len(fallback_decode)
+ slots = self._max_batch_size - decode_active_count
+
+ while prefill_pending and len(batch_candidates) < slots:
+ task = prefill_pending.popleft()
+ # Check deadline
+ if task.deadline_at is not None and time.time() > task.deadline_at:
+ with self._lock:
+ self._metrics["timed_out"] += 1.0
+ if task.stream:
+ task.output_queue.put({"error": "request timeout", "code": "timeout", "done": True})
+ else:
+ task.output_queue.put({"error": "request timeout", "code": "timeout"})
+ task.output_queue.put(_END)
+ continue
+ batch_candidates.append(task)
+
+ if not batch_candidates:
+ pass # all timed out
+ elif len(batch_candidates) >= 1 and hasattr(svc, "prepare_batch"):
+ # Try batch path
+ try:
+ payloads = [t.payload for t in batch_candidates]
+ result = svc.prepare_batch(payloads)
+ except Exception as exc:
+ logger.debug("prepare_batch failed: %s", exc, exc_info=True)
+ result = None
+
+ if result is not None:
+ batch_state = result
+ batch_tasks = list(batch_candidates)
+ with self._lock:
+ self._metrics["stream_batch_prefill_batches"] += 1.0
+ self._metrics["stream_batch_prefill_tasks"] += float(len(batch_candidates))
+
+ # Emit first token chunks for each sequence
+ from llaisys.server import _wrap_chunk
+ for i, seq in enumerate(batch_state.sequences):
+ task = batch_tasks[i]
+ self._bind_session(seq.context_id, idx)
+ if seq.filtered_text and task.stream:
+ chunk = _wrap_chunk(seq.context_id, seq.filtered_text, None)
+ task.output_queue.put(chunk)
+ # If already finished after prefill
+ if seq.finished:
+ if task.stream:
+ _emit_final_stream(
+ task, seq.context_id, seq.finish_reason or "stop",
+ len(seq.prompt_ids), len(seq.generated_ids),
+ stopped=bool(seq.cancel_event and seq.cancel_event.is_set()),
+ )
+ else:
+ _emit_final_non_stream(
+ task, seq.context_id, seq.filtered_text,
+ seq.finish_reason or "stop",
+ len(seq.prompt_ids), len(seq.generated_ids),
+ stopped=bool(seq.cancel_event and seq.cancel_event.is_set()),
+ )
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ svc.finalize_sequence(batch_state, i)
+ # If all finished after prefill, clear batch
+ if all(s.finished for s in batch_state.sequences):
+ batch_state = None
+ batch_tasks = []
+ else:
+ # Fallback: push to legacy iterator path
+ with self._lock:
+ self._metrics["stream_batch_fallback_tasks"] += float(len(batch_candidates))
+ for task in batch_candidates:
+ try:
+ it = svc.stream(task.payload)
+ fallback_prefill.append(_ActiveTask(task=task, iterator=it))
+ except Exception as exc:
+ if task.stream:
+ task.output_queue.put({"error": str(exc), "done": True})
+ else:
+ task.output_queue.put({"error": str(exc)})
+ task.output_queue.put(_END)
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ else:
+ # No prepare_batch available, use legacy path
+ with self._lock:
+ self._metrics["stream_batch_fallback_tasks"] += float(len(batch_candidates))
+ for task in batch_candidates:
+ try:
+ it = svc.stream(task.payload)
+ fallback_prefill.append(_ActiveTask(task=task, iterator=it))
+ except Exception as exc:
+ if task.stream:
+ task.output_queue.put({"error": str(exc), "done": True})
+ else:
+ task.output_queue.put({"error": str(exc)})
+ task.output_queue.put(_END)
+ with self._lock:
+ self._metrics["failed"] += 1.0
+
+ # Legacy P stage: step fallback prefill tasks one at a time
+ if fallback_prefill:
+ with self._lock:
+ if not prefill_pending:
+ self._metrics["prefill_rounds"] += 1.0
+
+ # Try packed prefill for non-stream fallback tasks
+ packed_candidates: List[_ActiveTask] = []
+ for state in fallback_prefill:
+ if state.task.stream:
+ continue
+ packed_candidates.append(state)
+ if len(packed_candidates) >= self._max_batch_size:
+ break
+ if len(packed_candidates) >= 2 and (
+ hasattr(svc, "generate_packed_non_stream") or hasattr(svc, "generate_packed_once")
+ ):
+ packed_exception = False
+ with self._lock:
+ self._metrics["packed_prefill_attempts"] += 1.0
+ self._metrics["packed_prefill_candidate_tasks"] += float(len(packed_candidates))
+ try:
+ packed_payloads = [st.task.payload for st in packed_candidates]
+ if hasattr(svc, "generate_packed_non_stream"):
+ packed_results = svc.generate_packed_non_stream(packed_payloads)
+ else:
+ packed_results = svc.generate_packed_once(packed_payloads)
+ except Exception as exc:
+ packed_exception = True
+ with self._lock:
+ self._metrics["packed_prefill_exceptions"] += 1.0
+ self._packed_prefill_last_error = str(exc)
+ packed_results = None
+ if isinstance(packed_results, list) and len(packed_results) == len(packed_candidates):
+ packed_ids = {id(st) for st in packed_candidates}
+ fallback_prefill = deque([st for st in fallback_prefill if id(st) not in packed_ids])
+ for st, result in zip(packed_candidates, packed_results):
+ st.task.output_queue.put(result)
+ st.task.output_queue.put(_END)
+ with self._lock:
+ self._metrics["completed"] += float(len(packed_candidates))
+ self._metrics["packed_prefill_batches"] += 1.0
+ self._metrics["packed_prefill_tasks"] += float(len(packed_candidates))
+ self._packed_prefill_last_error = ""
+ # Skip single step below if we consumed all
+ if not fallback_prefill:
+ pass
+ elif not packed_exception:
+ with self._lock:
+ self._metrics["packed_prefill_none_returns"] += 1.0
+
+ if fallback_prefill:
+ state = fallback_prefill.popleft()
+ status = _step_once(state)
+ if status == "keep":
+ fallback_decode.append(state)
+
+ # ============================================================
+ # D stage: batch decode
+ # ============================================================
+ if batch_state is not None:
+ active_before = len([s for s in batch_state.sequences if not s.finished])
+ with self._lock:
+ self._metrics["decode_rounds"] += 1.0
+ self._metrics["stream_batch_decode_rounds"] += 1.0
+ self._metrics["stream_batch_decode_active_sum"] += float(active_before)
+
+ try:
+ step_results = svc.step_batch(batch_state)
+ except Exception as exc:
+ logger.debug("step_batch failed: %s", exc, exc_info=True)
+ # Mark all active as failed
+ for i, seq in enumerate(batch_state.sequences):
+ if not seq.finished:
+ seq.finished = True
+ task = batch_tasks[i]
+ if task.stream:
+ task.output_queue.put({"error": str(exc), "done": True})
+ else:
+ task.output_queue.put({"error": str(exc)})
+ task.output_queue.put(_END)
+ with self._lock:
+ self._metrics["failed"] += 1.0
+ batch_state = None
+ batch_tasks = []
+ step_results = None
+
+ if step_results is not None:
+ from llaisys.server import _wrap_chunk
+ for sr in step_results:
+ task = batch_tasks[sr.seq_index]
+ seq = batch_state.sequences[sr.seq_index]
+
+ if sr.delta_text and task.stream:
+ chunk = _wrap_chunk(seq.context_id, sr.delta_text, None)
+ task.output_queue.put(chunk)
+
+ if sr.finished:
+ if task.stream:
+ _emit_final_stream(
+ task, seq.context_id, sr.finish_reason or "stop",
+ len(seq.prompt_ids), len(seq.generated_ids),
+ stopped=sr.stopped,
+ )
+ else:
+ _emit_final_non_stream(
+ task, seq.context_id, seq.filtered_text,
+ sr.finish_reason or "stop",
+ len(seq.prompt_ids), len(seq.generated_ids),
+ stopped=sr.stopped,
+ )
+ with self._lock:
+ self._metrics["completed"] += 1.0
+ if sr.stopped:
+ self._metrics["cancelled"] += 1.0
+ svc.finalize_sequence(batch_state, sr.seq_index)
+
+ # Check for shrink events
+ active_after = len([s for s in batch_state.sequences if not s.finished])
+ if active_after < active_before and active_after > 0:
+ with self._lock:
+ self._metrics["stream_batch_shrink_events"] += 1.0
+
+ # Clear batch if all done
+ if all(s.finished for s in batch_state.sequences):
+ batch_state = None
+ batch_tasks = []
+
+ # Legacy D stage: iterate fallback decode tasks
+ if fallback_decode:
+ with self._lock:
+ self._metrics["decode_rounds"] += 1.0
+ next_decode: List[_ActiveTask] = []
+ for state in fallback_decode:
+ status = _step_once(state)
+ if status == "keep":
+ next_decode.append(state)
+ fallback_decode = next_decode
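
The `_worker_loop` / `TaskHandle` handshake above rests on one pattern: a per-worker input queue, a per-task output queue, and an `_END` sentinel marking end of stream. A minimal sketch with a stub handler (not the real llaisys service) shows the shape:

```python
import queue
import threading

_END = object()  # end-of-stream sentinel, mirroring the scheduler above


def worker_loop(in_q: "queue.Queue", handler) -> None:
    # Drain (payload, out_q) tasks until a None shutdown sentinel arrives.
    while True:
        task = in_q.get()
        if task is None:
            return
        payload, out_q = task
        try:
            out_q.put(handler(payload))
        except Exception as exc:
            out_q.put({"error": str(exc)})
        finally:
            out_q.put(_END)  # always terminate the consumer's stream


in_q: "queue.Queue" = queue.Queue(maxsize=8)
t = threading.Thread(
    target=worker_loop,
    args=(in_q, lambda p: {"echo": p["msg"]}),
    daemon=True,
)
t.start()

out_q: "queue.Queue" = queue.Queue()
in_q.put(({"msg": "hi"}, out_q))
result = out_q.get(timeout=1.0)        # the handler's single result
assert out_q.get(timeout=1.0) is _END  # then the sentinel
in_q.put(None)                         # shut the worker down
t.join(timeout=1.0)
```

Putting `_END` in a `finally` block is what lets `TaskHandle.get_result` and `iter_stream` terminate even when the service raises.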
diff --git a/python/llaisys/server.py b/python/llaisys/server.py
new file mode 100644
index 000000000..6c73af966
--- /dev/null
+++ b/python/llaisys/server.py
@@ -0,0 +1,1102 @@
+from __future__ import annotations
+
+import argparse
+import json
+import re
+import threading
+from dataclasses import dataclass, field
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+from urllib.parse import parse_qs, urlparse
+
+import llaisys
+from llaisys.interfaces import IInferenceService
+from llaisys.kv_cache_pool import KVCachePool
+from llaisys.kv_runtime_bridge import KVRuntimeBridge
+from llaisys.libllaisys import LlaisysSamplingParams
+from llaisys.models import Qwen2
+from llaisys.scheduler import InferenceScheduler, SchedulerQueueFullError, TaskTimeoutError
+
+
+# ---------------------------------------------------------------------------
+# Streaming batch data structures
+# ---------------------------------------------------------------------------
+
+@dataclass
+class BatchSequenceState:
+ index: int
+ context_id: str
+ messages: List[Dict[str, str]]
+ prompt_ids: List[int]
+ generated_ids: List[int] = field(default_factory=list)
+ filtered_text: str = ""
+ max_new_tokens: int = 128
+ sampling: Dict[str, Any] = field(default_factory=dict)
+ sampling_params: Optional[LlaisysSamplingParams] = None
+ use_sampling: bool = False
+ cancel_event: Optional[threading.Event] = None
+ finished: bool = False
+ finish_reason: Optional[str] = None
+
+
+@dataclass
+class BatchState:
+ sequences: List[BatchSequenceState]
+ any_sampling: bool
+ eos_token: int
+
+
+@dataclass
+class StepResult:
+ seq_index: int
+ delta_text: str
+ finished: bool
+ finish_reason: Optional[str]
+ stopped: bool = False
+
+
+from llaisys.session_manager import SessionManager
+
+
+def _wrap_completion(
+ session_id: str,
+ content: str,
+ finish_reason: str,
+ usage: Dict[str, int],
+ stopped: bool = False,
+) -> Dict[str, Any]:
+ result: Dict[str, Any] = {
+ "id": f"chatcmpl-{session_id}",
+ "object": "chat.completion",
+ "model": "qwen2",
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": content},
+ "finish_reason": finish_reason,
+ }
+ ],
+ "usage": usage,
+ "session_id": session_id,
+ }
+ if stopped:
+ result["stopped"] = True
+ return result
+
+
+def _wrap_chunk(
+ session_id: str,
+ delta_content: Optional[str],
+ finish_reason: Optional[str],
+ usage: Optional[Dict[str, int]] = None,
+ stopped: bool = False,
+) -> Dict[str, Any]:
+ delta: Dict[str, str] = {}
+ if delta_content is not None:
+ delta["content"] = delta_content
+ chunk: Dict[str, Any] = {
+ "id": f"chatcmpl-{session_id}",
+ "object": "chat.completion.chunk",
+ "model": "qwen2",
+ "choices": [
+ {
+ "index": 0,
+ "delta": delta,
+ "finish_reason": finish_reason,
+ }
+ ],
+ "session_id": session_id,
+ }
+ if usage is not None:
+ chunk["usage"] = usage
+ if stopped:
+ chunk["stopped"] = True
+ return chunk
+
+
+def _wrap_error(message: str, error_type: str = "server_error", code: str = "") -> Dict[str, Any]:
+ err: Dict[str, Any] = {"error": {"message": message, "type": error_type}}
+ if code:
+ err["error"]["code"] = code
+ return err
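
Concretely, `_wrap_chunk` above produces an OpenAI-style `chat.completion.chunk` envelope. The standalone copy below reproduces its body for illustration:

```python
from typing import Any, Dict, Optional


def _wrap_chunk(session_id: str,
                delta_content: Optional[str],
                finish_reason: Optional[str],
                usage: Optional[Dict[str, int]] = None,
                stopped: bool = False) -> Dict[str, Any]:
    delta: Dict[str, str] = {}
    if delta_content is not None:
        delta["content"] = delta_content
    chunk: Dict[str, Any] = {
        "id": f"chatcmpl-{session_id}",
        "object": "chat.completion.chunk",
        "model": "qwen2",
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        "session_id": session_id,
    }
    if usage is not None:
        chunk["usage"] = usage
    if stopped:
        chunk["stopped"] = True
    return chunk


# A mid-stream chunk carries a content delta and no finish_reason;
# the final chunk carries usage, an empty delta, and possibly "stopped".
mid = _wrap_chunk("abc", "Hello", None)
final = _wrap_chunk(
    "abc", None, "stop",
    usage={"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    stopped=True,
)
```

The scheduler's batch paths import `_wrap_chunk` lazily from `llaisys.server` to avoid a circular import, so the envelope shape is shared between the single-task and batched code paths.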
+
+
+class ChatService(IInferenceService):
+ def __init__(
+ self,
+ model: Qwen2,
+ tokenizer: llaisys.Tokenizer,
+ model_path: Optional[str] = None,
+ enable_kv_runtime_reuse: bool = False,
+ block_size: int = 64,
+ max_blocks: int = 4096,
+ max_bytes: int = 256 * 1024 * 1024,
+ model_lock: Optional[threading.RLock] = None,
+ kv_pool: Optional[KVCachePool] = None,
+ kv_bridge: Optional[KVRuntimeBridge] = None,
+ ) -> None:
+ self.model = model
+ self.tokenizer = tokenizer
+ self._enable_kv_runtime_reuse = bool(enable_kv_runtime_reuse)
+ # RLock allows cooperative iterator-level scheduling in continuous-batching mode.
+ self._model_lock = model_lock if model_lock is not None else threading.RLock()
+
+ # Delegated components
+ self._session_mgr = SessionManager()
+ self._kv_bridge = kv_bridge if kv_bridge is not None else KVRuntimeBridge(model, enabled=enable_kv_runtime_reuse)
+ self._kv_pool = kv_pool if kv_pool is not None else KVCachePool(
+ block_size=block_size,
+ max_blocks=max_blocks,
+ max_bytes=max_bytes,
+ )
+ self._active_tokens: List[int] = []
+
+ # Text processing
+ self._chat_template_tokenizer = self._init_chat_template_tokenizer(model_path)
+ self._filter_tokens = ("<|end_of_sentence|>",)
+ self._filter_patterns = [
+ re.compile(r"<\s*\|\s*end_of_sentence\s*\|\s*>", re.IGNORECASE),
+ re.compile(r"<\s*\|[^>]*\|\s*>"),
+ re.compile(r"<\s*[\|\uFF5C][^>]*[\|\uFF5C]\s*>"),
+ re.compile(
+ r"<\s*[\|\uFF5C]\s*end[\s_\u2581]*of[\s_\u2581]*sentence\s*[\|\uFF5C]\s*>",
+ re.IGNORECASE,
+ ),
+ ]
+
+ @property
+ def kv_pool(self) -> KVCachePool:
+ """暴露 KVCache 池给调度器查询"""
+ return self._kv_pool
+
+ def tokenize_for_routing(self, payload: Dict[str, Any]) -> Optional[List[int]]:
+ """为 KV 感知路由进行轻量级 tokenize
+
+ 尝试从 payload 构建 prompt 并 tokenize,用于调度器查询 KV 命中。
+ 失败时返回 None,不影响正常请求处理。
+
+ Args:
+ payload: 请求参数
+
+ Returns:
+ token ids 列表,或 None(如果无法 tokenize)
+ """
+ try:
+            # Try to extract messages first
+ messages = payload.get("messages")
+ prompt_text = payload.get("prompt")
+ system_prompt = payload.get("system_prompt")
+
+ if messages is not None:
+ if not isinstance(messages, list):
+ return None
+ prompt = self._render_prompt(list(messages), str(system_prompt) if system_prompt else None)
+ elif prompt_text is not None:
+                # Plain prompt: pull in the session history
+ session_id = str(payload.get("session_id") or "").strip()
+ history = self._session_mgr.get_messages(session_id)
+ history.append({"role": "user", "content": str(prompt_text)})
+ prompt = self._render_prompt(history, str(system_prompt) if system_prompt else None)
+ else:
+ return None
+
+ return self.tokenizer.encode(prompt)
+ except Exception:
+ return None
+
+ @staticmethod
+ def _init_chat_template_tokenizer(model_path: Optional[str]):
+ if not model_path:
+ return None
+ try:
+ from transformers import AutoTokenizer
+ except Exception:
+ return None
+ try:
+ return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ except Exception:
+ return None
+
+ def _postprocess_text(self, text: str) -> str:
+ for token in self._filter_tokens:
+ text = text.replace(token, "")
+ for pattern in self._filter_patterns:
+ text = pattern.sub("", text)
+ return text
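The filter patterns strip `<|...|>`-style special-token markers (including fullwidth bars) from decoded text. A reduced sketch using two of the patterns above:

```python
import re

# Two of ChatService's filter patterns: the explicit end-of-sentence marker,
# and the catch-all for any <|...|> special-token span.
patterns = [
    re.compile(r"<\s*\|\s*end_of_sentence\s*\|\s*>", re.IGNORECASE),
    re.compile(r"<\s*\|[^>]*\|\s*>"),
]

def postprocess(text: str) -> str:
    for p in patterns:
        text = p.sub("", text)
    return text

print(postprocess("Hello<|end_of_sentence|> world<|im_end|>"))  # "Hello world"
```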
+
+ def request_stop(self, context_id: str) -> bool:
+ return self._session_mgr.request_stop(context_id)
+
+ def kv_debug_snapshot(self, session_id: Optional[str] = None) -> Dict[str, Any]:
+ snapshot = self._kv_bridge.debug_snapshot(session_id)
+ snapshot["kv_pool"] = self._kv_pool.snapshot_stats()
+ return snapshot
+
+ def _render_prompt(self, messages: List[Dict[str, str]], system_prompt: Optional[str]) -> str:
+ templated_messages: List[Dict[str, str]] = []
+ if system_prompt:
+ templated_messages.append({"role": "system", "content": str(system_prompt)})
+ templated_messages.extend(messages)
+
+ if self._chat_template_tokenizer is not None:
+ try:
+ return self._chat_template_tokenizer.apply_chat_template(
+ templated_messages,
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+ except Exception:
+ pass
+ return Qwen2.build_prompt(
+ messages,
+ system_prompt=str(system_prompt) if system_prompt else None,
+ add_generation_prompt=True,
+ )
+
+    def _eos_token(self) -> int:
+        meta = getattr(self.model, "_meta", None)
+        if meta is None:
+            return -1
+        end_token = getattr(meta, "end_token", -1)
+        return int(getattr(end_token, "value", end_token))
+
+ def _decode_next(
+ self,
+ token_ids: List[int],
+ use_sampling: bool,
+ sampling: Dict[str, Any],
+ ) -> int:
+ top_k = int(sampling.get("top_k", 1))
+ top_p = float(sampling.get("top_p", 0.0))
+ temperature = float(sampling.get("temperature", 0.0))
+ seed = int(sampling.get("seed", 0))
+ if use_sampling:
+ return int(
+ self.model.step_sampling(
+ token_ids,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ seed=seed,
+ )
+ )
+ return int(self.model.step(token_ids))
+
+ def _prefill_next(
+ self,
+ prompt_ids: List[int],
+ use_sampling: bool,
+ sampling: Dict[str, Any],
+ ) -> int:
+ top_k = int(sampling.get("top_k", 1))
+ top_p = float(sampling.get("top_p", 0.0))
+ temperature = float(sampling.get("temperature", 0.0))
+ seed = int(sampling.get("seed", 0))
+ if use_sampling:
+ return int(
+ self.model.prefill_sampling(
+ prompt_ids,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ seed=seed,
+ )
+ )
+ return int(self.model.prefill(prompt_ids))
+
+ def _iter_generate_ids(
+ self,
+ prompt_ids: List[int],
+ max_new_tokens: int,
+ sampling: Dict[str, Any],
+ prefix_len: int,
+ cancel_event: threading.Event,
+ ) -> Iterable[int]:
+ mode = str(sampling.get("mode", "")).strip().lower()
+ top_k = int(sampling.get("top_k", 1))
+ top_p = float(sampling.get("top_p", 0.0))
+ temperature = float(sampling.get("temperature", 0.0))
+ if mode == "argmax":
+ use_sampling = False
+ elif mode == "sample":
+ use_sampling = True
+ else:
+ use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0
+
+ if cancel_event.is_set():
+ return
+
+ can_reuse_active_prefix = (
+ self._enable_kv_runtime_reuse
+ and prefix_len > 0
+ and len(self._active_tokens) == prefix_len
+ and self._active_tokens[:prefix_len] == prompt_ids[:prefix_len]
+ and len(prompt_ids) > prefix_len
+ )
+ if can_reuse_active_prefix:
+ next_token = self._decode_next(prompt_ids[prefix_len:], use_sampling, sampling)
+ self._active_tokens = list(prompt_ids)
+ else:
+ self.model.reset_kv_cache()
+ next_token = self._prefill_next(prompt_ids, use_sampling, sampling)
+ self._active_tokens = list(prompt_ids)
+
+ if next_token < 0:
+ return
+
+ eos = self._eos_token()
+ yield next_token
+ self._active_tokens.append(next_token)
+ for _ in range(max_new_tokens - 1):
+ if cancel_event.is_set():
+ break
+ if eos >= 0 and next_token == eos:
+ break
+ next_token = self._decode_next([next_token], use_sampling, sampling)
+ if next_token < 0:
+ break
+ yield next_token
+ self._active_tokens.append(next_token)
+
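The generator above follows a prefill-then-decode shape: one call over the full prompt yields the first token, then each subsequent token is fed back one at a time until EOS or the budget runs out. A toy sketch with a fake model (all names hypothetical; the real code calls `model.prefill` / `model.step`):

```python
def iter_generate(prompt_ids, max_new_tokens, eos=-1):
    # Stand-in for one prefill over the whole prompt.
    next_token = max(prompt_ids) + 1
    yield next_token
    for _ in range(max_new_tokens - 1):
        if eos >= 0 and next_token == eos:
            break
        # Stand-in for one decode step on the last emitted token.
        next_token += 1
        yield next_token

print(list(iter_generate([1, 2, 3], 4)))  # [4, 5, 6, 7]
```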
+ def _prepare_request(self, payload: Dict[str, Any]) -> Tuple[str, List[Dict[str, str]], List[int], Dict[str, Any], int]:
+ system_prompt = payload.get("system_prompt")
+ # Accept OpenAI's max_tokens as alias; prefer it over max_new_tokens
+ if "max_tokens" in payload:
+ max_new_tokens = int(payload["max_tokens"])
+ else:
+ max_new_tokens = int(payload.get("max_new_tokens", 128))
+ # model field accepted and ignored
+ sampling = {
+ "mode": payload.get("sampling"),
+ "top_k": payload.get("top_k", 1),
+ "top_p": payload.get("top_p", 0.0),
+ "temperature": payload.get("temperature", 0.0),
+ "seed": payload.get("seed", 0),
+ }
+
+ context_id, messages = self._session_mgr.extract_messages(payload)
+ prompt = self._render_prompt(messages, str(system_prompt) if system_prompt else None)
+ prompt_ids = self.tokenizer.encode(prompt)
+ return context_id, messages, prompt_ids, sampling, max_new_tokens
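The sampling-mode resolution that recurs throughout `ChatService` can be factored as a small pure function, sketched here for reference:

```python
def resolve_use_sampling(mode: str, top_k: int, top_p: float, temperature: float) -> bool:
    # Explicit "argmax"/"sample" wins; otherwise sample whenever any knob
    # requests a non-greedy distribution.
    mode = mode.strip().lower()
    if mode == "argmax":
        return False
    if mode == "sample":
        return True
    return temperature > 0.0 or top_k > 1 or top_p > 0.0

print(resolve_use_sampling("", 1, 0.9, 0.0))  # True (top_p enabled)
```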
+
+ def generate_packed_non_stream(self, payloads: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
+ """Best-effort packed non-stream path (greedy + sampling).
+
+ Current safe scope:
+ - non-stream requests only
+ - greedy and sampling requests (mixed batches supported)
+ - no history-edit branching fields
+
+ When any request uses sampling, the batch is routed through the
+ packed-sampling C API. If that API is unavailable (old DLL), sampling
+ requests fall back to ``None`` so the scheduler handles them one by one.
+ Pure-greedy batches still use the original fast ``prefill_packed`` path.
+ """
+ if not payloads:
+ return []
+ if not hasattr(self.model, "prefill_packed") or not hasattr(self.model, "step_packed"):
+ return None
+
+ prepared: List[Tuple[str, List[Dict[str, str]], List[int], Dict[str, Any], int]] = []
+ any_sampling = False
+ sampling_params_list: List[LlaisysSamplingParams] = []
+ for payload in payloads:
+ if payload.get("stream", False):
+ return None
+ # History editing introduces branch semantics; keep packed path conservative for now.
+ if payload.get("edit_from_session_id"):
+ return None
+ try:
+ context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload)
+ except Exception:
+ return None
+ if max_new_tokens <= 0:
+ return None
+ mode = str(sampling.get("mode", "")).strip().lower()
+ top_k = int(sampling.get("top_k", 1))
+ top_p = float(sampling.get("top_p", 0.0))
+ temperature = float(sampling.get("temperature", 0.0))
+ seed = int(sampling.get("seed", 0))
+ if mode == "argmax":
+ use_sampling = False
+ elif mode == "sample":
+ use_sampling = True
+ else:
+ use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0
+ if use_sampling:
+ any_sampling = True
+ sampling_params_list.append(LlaisysSamplingParams(
+ top_k=top_k, top_p=top_p,
+ temperature=temperature, seed=seed,
+ ))
+ prepared.append((context_id, messages, prompt_ids, sampling, max_new_tokens))
+
+ # If any request needs sampling, check for the packed-sampling API.
+ # Fall back to None (single-request path) when the new DLL is absent.
+ if any_sampling:
+ if not hasattr(self.model, "prefill_packed_sampling") or not hasattr(self.model, "step_packed_sampling"):
+ return None
+
+ prompts = [it[2] for it in prepared]
+ generated_all: List[List[int]] = [[] for _ in prepared]
+ last_step_inputs: List[int] = [int(p[-1]) if p else 0 for p in prompts]
+ max_new_tokens_list = [int(it[4]) for it in prepared]
+ eos = self._eos_token()
+ with self._model_lock:
+ self.model.reset_kv_cache()
+ if any_sampling:
+ next_tokens = self.model.prefill_packed_sampling(prompts, sampling_params_list)
+ else:
+ next_tokens = self.model.prefill_packed(prompts)
+ if len(next_tokens) != len(prepared):
+ return None
+ for i, tok in enumerate(next_tokens):
+ t = int(tok)
+ if t >= 0:
+ generated_all[i].append(t)
+ last_step_inputs[i] = t
+ # Continue decode rounds for unfinished requests (dynamic shrinking).
+ while True:
+ active_indices: List[int] = []
+ decode_inputs: List[List[int]] = []
+ active_sp: List[LlaisysSamplingParams] = []
+ for i in range(len(generated_all)):
+ gen = generated_all[i]
+ if not gen:
+ continue
+ if len(gen) >= max_new_tokens_list[i]:
+ continue
+ if eos >= 0 and gen[-1] == eos:
+ continue
+ active_indices.append(i)
+ decode_inputs.append([int(last_step_inputs[i])])
+ active_sp.append(sampling_params_list[i])
+ if not active_indices:
+ break
+ if any_sampling:
+ step_tokens = self.model.step_packed_sampling(decode_inputs, active_sp)
+ else:
+ step_tokens = self.model.step_packed(decode_inputs)
+ if len(step_tokens) != len(active_indices):
+ return None
+ for j, tok in enumerate(step_tokens):
+ ai = active_indices[j]
+ t = int(tok)
+ if t >= 0:
+ generated_all[ai].append(t)
+ last_step_inputs[ai] = t
+
+ out: List[Dict[str, Any]] = []
+ for i, (context_id, messages, prompt_ids, _sampling, _max_new_tokens) in enumerate(prepared):
+ generated_ids = list(generated_all[i])
+ response_text = self._postprocess_text(self.tokenizer.decode(generated_ids))
+ messages2 = list(messages)
+ messages2.append({"role": "assistant", "content": response_text})
+ self._session_mgr.save_messages(context_id, messages2)
+ self._session_mgr.clear_stop(context_id)
+ usage = {
+ "prompt_tokens": len(prompt_ids),
+ "completion_tokens": len(generated_ids),
+ "total_tokens": len(prompt_ids) + len(generated_ids),
+ }
+ hit_limit = len(generated_ids) >= _max_new_tokens
+ finish_reason = "length" if hit_limit else "stop"
+ out.append(_wrap_completion(context_id, response_text, finish_reason, usage))
+ return out
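The "dynamic shrinking" loop above keeps stepping only unfinished sequences, so each decode round batches fewer and fewer requests. A toy sketch of that control flow, with a counter standing in for the model call (names hypothetical):

```python
def shrink_rounds(lengths, limits):
    # Each round batches only unfinished sequences; finished ones are
    # skipped, so the effective batch shrinks over time.
    rounds = 0
    while True:
        active = [i for i, (n, m) in enumerate(zip(lengths, limits)) if n < m]
        if not active:
            break
        for i in active:
            lengths[i] += 1  # stand-in for one packed decode step
        rounds += 1
    return rounds

print(shrink_rounds([0, 0], [2, 5]))  # 5: the short sequence drops out after round 2
```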
+
+ # Backward-compatible alias used by scheduler tests/mocks.
+ def generate_packed_once(self, payloads: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
+ return self.generate_packed_non_stream(payloads)
+
+ def generate(self, payload: Dict[str, Any]) -> Dict[str, Any]:
+ context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload)
+ cancel_event = self._session_mgr.get_cancel_event(context_id)
+ self._session_mgr.clear_stop(context_id)
+
+ with self._model_lock:
+ acquire = self._kv_pool.acquire_context(context_id, prompt_ids)
+ self._kv_bridge.bind_for_request(context_id, prompt_ids, acquire.prefix_len)
+ generated_ids: List[int] = []
+ try:
+ for token_id in self._iter_generate_ids(
+ prompt_ids=prompt_ids,
+ max_new_tokens=max_new_tokens,
+ sampling=sampling,
+ prefix_len=acquire.prefix_len,
+ cancel_event=cancel_event,
+ ):
+ generated_ids.append(int(token_id))
+ cancelled = cancel_event.is_set()
+ if cancelled:
+ self._active_tokens = list(prompt_ids)
+ self._kv_pool.update_context(context_id, prompt_ids)
+ else:
+ self._kv_pool.update_context(context_id, self._active_tokens)
+ self._kv_bridge.export_after_request(
+ context_id, self._active_tokens, self._kv_pool.block_size
+ )
+ except Exception:
+ self._kv_pool.release_context(context_id)
+ self._kv_bridge.release(context_id)
+ raise
+
+ response_text = self._postprocess_text(self.tokenizer.decode(generated_ids))
+ usage = {
+ "prompt_tokens": len(prompt_ids),
+ "completion_tokens": len(generated_ids),
+ "total_tokens": len(prompt_ids) + len(generated_ids),
+ }
+ if cancel_event.is_set():
+ self._session_mgr.clear_stop(context_id)
+ return _wrap_completion(context_id, response_text, "stop", usage, stopped=True)
+ hit_limit = len(generated_ids) >= max_new_tokens
+ finish_reason = "length" if hit_limit else "stop"
+ messages = list(messages)
+ messages.append({"role": "assistant", "content": response_text})
+ self._session_mgr.save_messages(context_id, messages)
+ self._session_mgr.clear_stop(context_id)
+ return _wrap_completion(context_id, response_text, finish_reason, usage)
+
+ def stream(self, payload: Dict[str, Any]) -> Iterable[Dict[str, Any]]:
+ context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload)
+ cancel_event = self._session_mgr.get_cancel_event(context_id)
+ self._session_mgr.clear_stop(context_id)
+
+ generated_ids: List[int] = []
+ filtered = ""
+ with self._model_lock:
+ acquire = self._kv_pool.acquire_context(context_id, prompt_ids)
+ self._kv_bridge.bind_for_request(context_id, prompt_ids, acquire.prefix_len)
+ try:
+ for token_id in self._iter_generate_ids(
+ prompt_ids=prompt_ids,
+ max_new_tokens=max_new_tokens,
+ sampling=sampling,
+ prefix_len=acquire.prefix_len,
+ cancel_event=cancel_event,
+ ):
+ generated_ids.append(int(token_id))
+ new_text = self.tokenizer.decode(generated_ids)
+ new_filtered = self._postprocess_text(new_text)
+ delta = new_filtered[len(filtered) :]
+ filtered = new_filtered
+ if delta:
+ yield _wrap_chunk(context_id, delta, None)
+ cancelled = cancel_event.is_set()
+ if cancelled:
+ self._active_tokens = list(prompt_ids)
+ self._kv_pool.update_context(context_id, prompt_ids)
+ else:
+ self._kv_pool.update_context(context_id, self._active_tokens)
+ self._kv_bridge.export_after_request(
+ context_id, self._active_tokens, self._kv_pool.block_size
+ )
+ except Exception:
+ self._kv_pool.release_context(context_id)
+ self._kv_bridge.release(context_id)
+ raise
+
+ if cancel_event.is_set():
+ self._session_mgr.clear_stop(context_id)
+ usage = {
+ "prompt_tokens": len(prompt_ids),
+ "completion_tokens": len(generated_ids),
+ "total_tokens": len(prompt_ids) + len(generated_ids),
+ }
+ yield _wrap_chunk(context_id, None, "stop", usage=usage, stopped=True)
+ return
+
+ messages = list(messages)
+ messages.append({"role": "assistant", "content": filtered})
+ self._session_mgr.save_messages(context_id, messages)
+ self._session_mgr.clear_stop(context_id)
+ hit_limit = len(generated_ids) >= max_new_tokens
+ finish_reason = "length" if hit_limit else "stop"
+ usage = {
+ "prompt_tokens": len(prompt_ids),
+ "completion_tokens": len(generated_ids),
+ "total_tokens": len(prompt_ids) + len(generated_ids),
+ }
+ yield _wrap_chunk(context_id, None, finish_reason, usage=usage)
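The streaming path re-decodes the full token sequence each step and emits only the newly appended suffix, which keeps deltas correct even when filtering changes earlier text. A sketch of that delta logic in isolation:

```python
def deltas(snapshots):
    # Re-decode the whole sequence each step, then emit only the suffix
    # that has not been sent before.
    emitted = ""
    out = []
    for snap in snapshots:
        delta = snap[len(emitted):]
        emitted = snap
        if delta:
            out.append(delta)
    return out

print(deltas(["He", "Hell", "Hello!"]))  # ['He', 'll', 'o!']
```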
+
+ # ------------------------------------------------------------------
+ # Streaming batch API (Phase 1)
+ # ------------------------------------------------------------------
+
+ def prepare_batch(self, payloads: List[Dict[str, Any]]) -> Optional[BatchState]:
+ """Prefill all sequences in a batch, return BatchState or None to fall back."""
+ if not payloads:
+ return None
+ if not hasattr(self.model, "prefill_packed") or not hasattr(self.model, "step_packed"):
+ return None
+
+ sequences: List[BatchSequenceState] = []
+ any_sampling = False
+ sampling_params_list: List[LlaisysSamplingParams] = []
+
+ for i, payload in enumerate(payloads):
+ # Edit-fork requests are not supported in batch path
+ if payload.get("edit_from_session_id"):
+ return None
+ try:
+ context_id, messages, prompt_ids, sampling, max_new_tokens = self._prepare_request(payload)
+ except Exception:
+ return None
+ if max_new_tokens <= 0:
+ return None
+
+ mode = str(sampling.get("mode", "")).strip().lower()
+ top_k = int(sampling.get("top_k", 1))
+ top_p = float(sampling.get("top_p", 0.0))
+ temperature = float(sampling.get("temperature", 0.0))
+ seed = int(sampling.get("seed", 0))
+ if mode == "argmax":
+ use_sampling = False
+ elif mode == "sample":
+ use_sampling = True
+ else:
+ use_sampling = temperature > 0.0 or top_k > 1 or top_p > 0.0
+ if use_sampling:
+ any_sampling = True
+
+ sp = LlaisysSamplingParams(top_k=top_k, top_p=top_p, temperature=temperature, seed=seed)
+ cancel_event = self._session_mgr.get_cancel_event(context_id)
+ self._session_mgr.clear_stop(context_id)
+
+ sequences.append(BatchSequenceState(
+ index=i,
+ context_id=context_id,
+ messages=messages,
+ prompt_ids=prompt_ids,
+ generated_ids=[],
+ filtered_text="",
+ max_new_tokens=max_new_tokens,
+ sampling=sampling,
+ sampling_params=sp,
+ use_sampling=use_sampling,
+ cancel_event=cancel_event,
+ finished=False,
+ finish_reason=None,
+ ))
+ sampling_params_list.append(sp)
+
+ # Check for packed-sampling API if needed
+ if any_sampling:
+ if not hasattr(self.model, "prefill_packed_sampling") or not hasattr(self.model, "step_packed_sampling"):
+ return None
+
+ prompts = [seq.prompt_ids for seq in sequences]
+ eos = self._eos_token()
+
+ with self._model_lock:
+ self.model.reset_kv_cache()
+ if any_sampling:
+ next_tokens = self.model.prefill_packed_sampling(prompts, sampling_params_list)
+ else:
+ next_tokens = self.model.prefill_packed(prompts)
+
+ if len(next_tokens) != len(sequences):
+ return None
+
+ for i, tok in enumerate(next_tokens):
+ t = int(tok)
+ if t >= 0:
+ sequences[i].generated_ids.append(t)
+ # Decode and compute initial filtered text
+ new_text = self.tokenizer.decode(sequences[i].generated_ids)
+ sequences[i].filtered_text = self._postprocess_text(new_text)
+ else:
+ sequences[i].finished = True
+ sequences[i].finish_reason = "stop"
+
+ # Check immediate termination
+ if not sequences[i].finished:
+ if eos >= 0 and t == eos:
+ sequences[i].finished = True
+ sequences[i].finish_reason = "stop"
+ elif len(sequences[i].generated_ids) >= sequences[i].max_new_tokens:
+ sequences[i].finished = True
+ sequences[i].finish_reason = "length"
+ elif sequences[i].cancel_event and sequences[i].cancel_event.is_set():
+ sequences[i].finished = True
+ sequences[i].finish_reason = "stop"
+
+ return BatchState(sequences=sequences, any_sampling=any_sampling, eos_token=eos)
+
+ def step_batch(self, state: BatchState) -> List[StepResult]:
+ """Execute one decode step for all active sequences. Dynamic shrinking: skip finished."""
+ results: List[StepResult] = []
+ active_indices: List[int] = []
+ decode_inputs: List[List[int]] = []
+ sampling_params_active: List[LlaisysSamplingParams] = []
+
+ for i, seq in enumerate(state.sequences):
+ if seq.finished:
+ continue
+ if seq.cancel_event and seq.cancel_event.is_set():
+ seq.finished = True
+ seq.finish_reason = "stop"
+ results.append(StepResult(
+ seq_index=i, delta_text="", finished=True,
+ finish_reason="stop", stopped=True,
+ ))
+ continue
+ active_indices.append(i)
+ last_tok = seq.generated_ids[-1] if seq.generated_ids else 0
+ decode_inputs.append([last_tok])
+ if seq.sampling_params is not None:
+ sampling_params_active.append(seq.sampling_params)
+
+ if not active_indices:
+ return results
+
+ with self._model_lock:
+ if state.any_sampling:
+ step_tokens = self.model.step_packed_sampling(decode_inputs, sampling_params_active)
+ else:
+ step_tokens = self.model.step_packed(decode_inputs)
+
+ if len(step_tokens) != len(active_indices):
+ # Model returned unexpected count; mark all active as finished
+ for ai in active_indices:
+ seq = state.sequences[ai]
+ seq.finished = True
+ seq.finish_reason = "stop"
+ results.append(StepResult(
+ seq_index=ai, delta_text="", finished=True,
+ finish_reason="stop", stopped=False,
+ ))
+ return results
+
+ for j, ai in enumerate(active_indices):
+ seq = state.sequences[ai]
+ t = int(step_tokens[j])
+
+ if t < 0:
+ seq.finished = True
+ seq.finish_reason = "stop"
+ results.append(StepResult(
+ seq_index=ai, delta_text="", finished=True,
+ finish_reason="stop", stopped=False,
+ ))
+ continue
+
+ seq.generated_ids.append(t)
+ new_text = self.tokenizer.decode(seq.generated_ids)
+ new_filtered = self._postprocess_text(new_text)
+ delta = new_filtered[len(seq.filtered_text):]
+ seq.filtered_text = new_filtered
+
+ # Check termination
+ finished = False
+ finish_reason = None
+ stopped = False
+
+ if state.eos_token >= 0 and t == state.eos_token:
+ finished = True
+ finish_reason = "stop"
+ elif len(seq.generated_ids) >= seq.max_new_tokens:
+ finished = True
+ finish_reason = "length"
+ elif seq.cancel_event and seq.cancel_event.is_set():
+ finished = True
+ finish_reason = "stop"
+ stopped = True
+
+ if finished:
+ seq.finished = True
+ seq.finish_reason = finish_reason
+
+ results.append(StepResult(
+ seq_index=ai, delta_text=delta, finished=finished,
+ finish_reason=finish_reason, stopped=stopped,
+ ))
+
+ return results
+
+ def finalize_sequence(self, state: BatchState, seq_index: int) -> None:
+ """Save session history and clean up for a completed sequence."""
+ seq = state.sequences[seq_index]
+ if seq.cancel_event and seq.cancel_event.is_set():
+ self._session_mgr.clear_stop(seq.context_id)
+ return
+ messages = list(seq.messages)
+ messages.append({"role": "assistant", "content": seq.filtered_text})
+ self._session_mgr.save_messages(seq.context_id, messages)
+ self._session_mgr.clear_stop(seq.context_id)
+
+
+class ChatHandler(BaseHTTPRequestHandler):
+ protocol_version = "HTTP/1.1"
+ scheduler: InferenceScheduler
+
+ def _set_cors_headers(self) -> None:
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ self.send_header("Access-Control-Allow-Headers", "Content-Type")
+
+ def _send_json(self, code: int, payload: Dict[str, Any]) -> None:
+ data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json; charset=utf-8")
+ self._set_cors_headers()
+ self.send_header("Content-Length", str(len(data)))
+ self.end_headers()
+ self.wfile.write(data)
+
+ def _write_chunk(self, data: bytes) -> bool:
+ try:
+ self.wfile.write(f"{len(data):X}\r\n".encode("ascii"))
+ self.wfile.write(data)
+ self.wfile.write(b"\r\n")
+ self.wfile.flush()
+ return True
+ except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError):
+ return False
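`_write_chunk` hand-rolls HTTP/1.1 chunked transfer encoding. Each chunk is framed as a hex length, CRLF, the payload, CRLF; a zero-length chunk terminates the body. A sketch of the framing:

```python
def frame_chunk(data: bytes) -> bytes:
    # HTTP/1.1 chunk framing: "<hex length>\r\n<payload>\r\n";
    # a zero-length chunk ("0\r\n\r\n") terminates the body.
    return f"{len(data):X}\r\n".encode("ascii") + data + b"\r\n"

body = frame_chunk(b"data: hello\n\n") + frame_chunk(b"")
print(body)
```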
+
+ def do_GET(self) -> None:
+ parsed = urlparse(self.path)
+ if parsed.path == "/health":
+ self._send_json(200, {"status": "ok"})
+ return
+ if parsed.path == "/debug/kv":
+ query = parse_qs(parsed.query)
+ session_id = str((query.get("session_id") or [""])[0]).strip() or None
+ payload = self.scheduler.kv_debug_snapshot(session_id)
+ self._send_json(200, payload)
+ return
+ if parsed.path == "/debug/scheduler":
+ self._send_json(200, self.scheduler.debug_snapshot())
+ return
+ self._send_json(404, _wrap_error("not found", "invalid_request_error", "not_found"))
+
+ def do_OPTIONS(self) -> None:
+ self.send_response(204)
+ self._set_cors_headers()
+ self.send_header("Content-Length", "0")
+ self.end_headers()
+
+ def do_POST(self) -> None:
+ if self.path not in ("/chat", "/v1/chat/completions", "/chat/stop"):
+ self._send_json(404, _wrap_error("not found", "invalid_request_error", "not_found"))
+ return
+
+ length = int(self.headers.get("Content-Length", "0"))
+ body = self.rfile.read(length) if length > 0 else b"{}"
+ try:
+ payload = json.loads(body.decode("utf-8"))
+ except Exception:
+ self._send_json(400, _wrap_error("invalid JSON", "invalid_request_error", "invalid_json"))
+ return
+
+ if self.path == "/chat/stop":
+ session_id = str(payload.get("session_id") or "").strip()
+ if not session_id:
+ self._send_json(400, _wrap_error("session_id is required", "invalid_request_error", "missing_field"))
+ return
+ self.scheduler.request_stop(session_id)
+ self._send_json(200, {"ok": True, "session_id": session_id})
+ return
+
+ stream = bool(payload.get("stream", False))
+ if not stream:
+ try:
+ handle = self.scheduler.submit(payload, stream=False)
+ result = handle.get_result(timeout=self.scheduler.request_timeout_seconds())
+ if isinstance(result, dict) and result.get("error"):
+ code = 504 if result.get("code") == "timeout" else 400
+ err = result.get("error")
+ err_code = str(result.get("code", "")) or "server_error"
+ self._send_json(code, _wrap_error(str(err), "server_error", err_code))
+ return
+ except SchedulerQueueFullError as exc:
+ self._send_json(429, _wrap_error(str(exc), "server_error", "queue_full"))
+ return
+ except TaskTimeoutError as exc:
+ self._send_json(504, _wrap_error(str(exc), "server_error", "timeout"))
+ return
+ except RuntimeError as exc:
+ self._send_json(400, _wrap_error(str(exc), "server_error"))
+ return
+ self._send_json(200, result)
+ return
+
+ self.send_response(200)
+ self.send_header("Content-Type", "text/event-stream; charset=utf-8")
+ self.send_header("Cache-Control", "no-cache")
+ self.send_header("Connection", "keep-alive")
+ self.send_header("Transfer-Encoding", "chunked")
+ self._set_cors_headers()
+ self.end_headers()
+
+ current_session_id = ""
+ try:
+ handle = self.scheduler.submit(payload, stream=True)
+ for item in handle.iter_stream(timeout=self.scheduler.request_timeout_seconds()):
+ current_session_id = str(item.get("session_id") or current_session_id)
+ data = json.dumps(item, ensure_ascii=False).encode("utf-8")
+ if not self._write_chunk(b"data: " + data + b"\n\n"):
+ if current_session_id:
+ self.scheduler.request_stop(current_session_id)
+ return
+ self._write_chunk(b"data: [DONE]\n\n")
+ except SchedulerQueueFullError as exc:
+ err = _wrap_error(str(exc), "server_error", "queue_full")
+ data = json.dumps(err, ensure_ascii=False).encode("utf-8")
+ self._write_chunk(b"data: " + data + b"\n\n")
+ except TaskTimeoutError as exc:
+ if current_session_id:
+ self.scheduler.request_stop(current_session_id)
+ err = _wrap_error(str(exc), "server_error", "timeout")
+ data = json.dumps(err, ensure_ascii=False).encode("utf-8")
+ self._write_chunk(b"data: " + data + b"\n\n")
+ except Exception as exc:
+ if current_session_id:
+ self.scheduler.request_stop(current_session_id)
+ err = _wrap_error(str(exc), "server_error")
+ data = json.dumps(err, ensure_ascii=False).encode("utf-8")
+ self._write_chunk(b"data: " + data + b"\n\n")
+ finally:
+ self._write_chunk(b"")
+
+
+def _resolve_tokenizer_path(model_path: str, tokenizer_path: Optional[str]) -> str:
+ if tokenizer_path:
+ return tokenizer_path
+ path = Path(model_path)
+ sp = path / "tokenizer.model"
+ if sp.exists():
+ return str(sp)
+ hf = path / "tokenizer.json"
+ if hf.exists():
+ return str(hf)
+ raise FileNotFoundError(f"No tokenizer.model or tokenizer.json found under: {path}")
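The resolution order above prefers an explicit path, then `tokenizer.model` (sentencepiece), then `tokenizer.json` (HF fast tokenizer). A self-contained sketch of the same logic, exercised against a temp directory:

```python
import tempfile
from pathlib import Path
from typing import Optional

def resolve_tokenizer(model_dir: str, explicit: Optional[str]) -> str:
    # Mirrors _resolve_tokenizer_path: explicit path wins, then
    # tokenizer.model, then tokenizer.json.
    if explicit:
        return explicit
    for name in ("tokenizer.model", "tokenizer.json"):
        candidate = Path(model_dir) / name
        if candidate.exists():
            return str(candidate)
    raise FileNotFoundError(f"no tokenizer file under {model_dir}")

with tempfile.TemporaryDirectory() as d:
    Path(d, "tokenizer.json").write_text("{}")
    print(resolve_tokenizer(d, None))  # picks tokenizer.json
```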
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True, type=str, help="model directory")
+ parser.add_argument("--tokenizer", required=False, type=str, help="tokenizer file path")
+ parser.add_argument("--host", default="127.0.0.1", type=str)
+ parser.add_argument("--port", default=8000, type=int)
+ parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"])
+ parser.add_argument("--pool-size", default=1, type=int, help="deprecated")
+ parser.add_argument(
+ "--kv-runtime-reuse",
+ action="store_true",
+ help="enable experimental runtime KV reuse fast-path",
+ )
+ parser.add_argument("--kv-block-size", default=64, type=int, help="kv block token size")
+ parser.add_argument("--kv-max-blocks", default=4096, type=int, help="kv max block count")
+ parser.add_argument("--kv-max-bytes", default=268435456, type=int, help="kv max bytes")
+ parser.add_argument("--workers", default=1, type=int, help="inference worker count")
+ parser.add_argument("--queue-size", default=128, type=int, help="max queued tasks per worker")
+ parser.add_argument("--request-timeout-ms", default=120000, type=int, help="scheduler request timeout in milliseconds")
+ parser.add_argument(
+ "--continuous-batching",
+ action="store_true",
+ help="enable minimal iteration-level continuous scheduling",
+ )
+ parser.add_argument(
+ "--kv-aware-routing",
+ action="store_true",
+ help="enable KV-aware worker routing (query KV pool before dispatching)",
+ )
+ parser.add_argument(
+ "--max-batch-size",
+ default=8,
+ type=int,
+ help="max sequences per streaming batch (default 8)",
+ )
+ parser.add_argument(
+ "--shared-model",
+ action="store_true",
+ help="share a single model instance and KV pool across all workers",
+ )
+ parser.add_argument(
+ "--kv-memory-threshold",
+ default=0.0,
+ type=float,
+ help="KV memory pressure threshold (0.0=disabled, 0.85=recommended)",
+ )
+ args = parser.parse_args()
+
+ tokenizer_path = _resolve_tokenizer_path(args.model, args.tokenizer)
+ worker_count = max(1, int(args.workers))
+ services: List[ChatService] = []
+ if args.shared_model:
+ # Shared mode: one model, one tokenizer, one KV pool, one KV bridge, one lock
+ model = Qwen2(
+ args.model,
+ llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA,
+ )
+ tokenizer = llaisys.Tokenizer(tokenizer_path)
+ shared_lock = threading.RLock()
+ shared_kv_pool = KVCachePool(
+ block_size=args.kv_block_size,
+ max_blocks=args.kv_max_blocks,
+ max_bytes=args.kv_max_bytes,
+ )
+ shared_kv_bridge = KVRuntimeBridge(model, enabled=args.kv_runtime_reuse)
+ for _ in range(worker_count):
+ services.append(
+ ChatService(
+ model,
+ tokenizer,
+ model_path=args.model,
+ enable_kv_runtime_reuse=args.kv_runtime_reuse,
+ block_size=args.kv_block_size,
+ max_blocks=args.kv_max_blocks,
+ max_bytes=args.kv_max_bytes,
+ model_lock=shared_lock,
+ kv_pool=shared_kv_pool,
+ kv_bridge=shared_kv_bridge,
+ )
+ )
+ else:
+ for _ in range(worker_count):
+ tokenizer = llaisys.Tokenizer(tokenizer_path)
+ model = Qwen2(
+ args.model,
+ llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA,
+ )
+ services.append(
+ ChatService(
+ model,
+ tokenizer,
+ model_path=args.model,
+ enable_kv_runtime_reuse=args.kv_runtime_reuse,
+ block_size=args.kv_block_size,
+ max_blocks=args.kv_max_blocks,
+ max_bytes=args.kv_max_bytes,
+ )
+ )
+ scheduler = InferenceScheduler(
+ services,
+ queue_size=max(1, int(args.queue_size)),
+ request_timeout_ms=max(0, int(args.request_timeout_ms)),
+ continuous_batching=bool(args.continuous_batching),
+ kv_aware_routing=bool(args.kv_aware_routing),
+ max_batch_size=max(1, int(args.max_batch_size)),
+ kv_memory_threshold=float(args.kv_memory_threshold),
+ )
+ scheduler.start()
+
+ handler = ChatHandler
+ handler.scheduler = scheduler
+ server = ThreadingHTTPServer((args.host, args.port), handler)
+ server.daemon_threads = True
+ kv_routing_str = ", kv_aware_routing=on" if args.kv_aware_routing else ""
+ shared_str = ", shared_model=on" if args.shared_model else ""
+ kv_mem_str = f", kv_memory_threshold={args.kv_memory_threshold}" if args.kv_memory_threshold > 0 else ""
+ print(
+ f"LLAISYS chat server listening on http://{args.host}:{args.port} "
+ f"(workers={worker_count}, queue_size={max(1, int(args.queue_size))}{kv_routing_str}{shared_str}{kv_mem_str})"
+ )
+ try:
+ server.serve_forever()
+ finally:
+ scheduler.stop()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/llaisys/session_manager.py b/python/llaisys/session_manager.py
new file mode 100644
index 000000000..c37aa8fbf
--- /dev/null
+++ b/python/llaisys/session_manager.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import threading
+import uuid
+from typing import Any, Dict, List, Tuple
+
+
+class SessionManager:
+ """Session message history and cancellation event management."""
+
+ def __init__(self) -> None:
+ self._lock = threading.Lock()
+ self._context_messages: Dict[str, List[Dict[str, str]]] = {}
+ self._cancel_events: Dict[str, threading.Event] = {}
+
+ def extract_messages(self, payload: Dict[str, Any]) -> Tuple[str, List[Dict[str, str]]]:
+ """Extract context_id and message list from payload.
+
+ Handles three input modes:
+ - edit_from_session_id: branch from history edit
+ - messages: direct message list
+ - prompt: append to existing session history
+ """
+ context_id = str(payload.get("session_id") or "").strip() or str(uuid.uuid4())
+ messages = payload.get("messages")
+ prompt = payload.get("prompt")
+ edit_from = str(payload.get("edit_from_session_id") or "").strip()
+ edit_index_raw = payload.get("edit_message_index")
+
+ if edit_from:
+ with self._lock:
+ source = list(self._context_messages.get(edit_from, []))
+ if not source:
+ raise ValueError("edit_from_session_id not found")
+ if prompt is None:
+ raise ValueError("prompt is required when editing history")
+ if edit_index_raw is None:
+ raise ValueError("edit_message_index is required when editing history")
+ edit_index = int(edit_index_raw)
+ if edit_index < 0 or edit_index >= len(source):
+ raise ValueError("edit_message_index out of range")
+ if source[edit_index].get("role") != "user":
+ raise ValueError("edit_message_index must point to a user message")
+ branched = source[: edit_index + 1]
+ branched[edit_index] = {"role": "user", "content": str(prompt)}
+ if not str(payload.get("session_id") or "").strip():
+ context_id = str(uuid.uuid4())
+ return context_id, branched
+
+ if messages is not None:
+ if not isinstance(messages, list):
+ raise ValueError("messages must be a list")
+ return context_id, list(messages)
+
+ if prompt is None:
+ raise ValueError("payload must include messages or prompt")
+
+ with self._lock:
+ history = list(self._context_messages.get(context_id, []))
+ history.append({"role": "user", "content": str(prompt)})
+ return context_id, history
+
+ def save_messages(self, context_id: str, messages: List[Dict[str, str]]) -> None:
+ """Save session message history."""
+ with self._lock:
+ self._context_messages[context_id] = list(messages)
+
+ def get_messages(self, context_id: str) -> List[Dict[str, str]]:
+ """Get session message history (returns a copy)."""
+ with self._lock:
+ return list(self._context_messages.get(context_id, []))
+
+ def get_cancel_event(self, context_id: str) -> threading.Event:
+ """Get or create a cancellation event for the given context."""
+ with self._lock:
+ event = self._cancel_events.get(context_id)
+ if event is None:
+ event = threading.Event()
+ self._cancel_events[context_id] = event
+ return event
+
+ def request_stop(self, context_id: str) -> bool:
+ """Set the cancellation event for the given context."""
+ with self._lock:
+ event = self._cancel_events.get(context_id)
+ if event is None:
+ event = threading.Event()
+ self._cancel_events[context_id] = event
+ event.set()
+ return True
+
+ def clear_stop(self, context_id: str) -> None:
+ """Clear the cancellation event for the given context."""
+ with self._lock:
+ event = self._cancel_events.get(context_id)
+ if event:
+ event.clear()
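The branch-on-edit path in `extract_messages` truncates the source history at the edited user message and swaps in the new prompt, leaving the original session untouched. A standalone sketch of just that rule (the `branch` helper below is illustrative, not part of the module):

```python
# Truncate history at the edited user message (inclusive) and replace
# its content, producing a new branch; the source list is not mutated.
history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "tell me a joke"},
    {"role": "assistant", "content": "..."},
]

def branch(source, edit_index, new_prompt):
    if source[edit_index].get("role") != "user":
        raise ValueError("edit_message_index must point to a user message")
    branched = source[: edit_index + 1]          # slicing copies the list
    branched[edit_index] = {"role": "user", "content": new_prompt}
    return branched

b = branch(history, 2, "tell me a story")
assert len(b) == 3
assert b[-1]["content"] == "tell me a story"
assert history[2]["content"] == "tell me a joke"  # original untouched
```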
diff --git a/python/llaisys/tensor_parallel.py b/python/llaisys/tensor_parallel.py
new file mode 100644
index 000000000..f21029dcd
--- /dev/null
+++ b/python/llaisys/tensor_parallel.py
@@ -0,0 +1,64 @@
+"""Tensor parallel weight splitting for Qwen2 models (Megatron-style)."""
+
+import numpy as np
+
+
+def split_column(tensor: np.ndarray, rank: int, world_size: int) -> np.ndarray:
+ """Split tensor along dim 0 (output features). For Q/K/V/gate/up weights and biases."""
+ chunk = tensor.shape[0] // world_size
+ return tensor[rank * chunk : (rank + 1) * chunk].copy()
+
+
+def split_row(tensor: np.ndarray, rank: int, world_size: int) -> np.ndarray:
+ """Split tensor along dim 1 (input features). For attn_o/down weights."""
+ chunk = tensor.shape[1] // world_size
+ return tensor[:, rank * chunk : (rank + 1) * chunk].copy()
+
+
+# Weight name patterns that get column-split (dim 0)
+_COLUMN_SPLIT = {
+ "self_attn.q_proj.weight",
+ "self_attn.k_proj.weight",
+ "self_attn.v_proj.weight",
+ "self_attn.q_proj.bias",
+ "self_attn.k_proj.bias",
+ "self_attn.v_proj.bias",
+ "mlp.gate_proj.weight",
+ "mlp.up_proj.weight",
+}
+
+# Weight name patterns that get row-split (dim 1)
+_ROW_SPLIT = {
+ "self_attn.o_proj.weight",
+ "mlp.down_proj.weight",
+}
+
+
+def shard_qwen2_weights(
+ weights_dict: dict[str, np.ndarray], rank: int, world_size: int
+) -> dict[str, np.ndarray]:
+ """Shard Qwen2 model weights for tensor parallelism.
+
+ Megatron-style: column split Q/K/V/gate/up, row split attn_o/down.
+ Replicate: embeddings, norms, everything else.
+ """
+ if world_size <= 1:
+ return weights_dict
+
+ out = {}
+ for name, tensor in weights_dict.items():
+ # Extract the sub-key for layer weights (e.g. "self_attn.q_proj.weight")
+ sub = None
+ if name.startswith("model.layers."):
+ parts = name.split(".")
+ if len(parts) >= 4:
+ sub = ".".join(parts[3:])
+
+ if sub in _COLUMN_SPLIT:
+ out[name] = split_column(tensor, rank, world_size)
+ elif sub in _ROW_SPLIT:
+ out[name] = split_row(tensor, rank, world_size)
+ else:
+ # Replicate: embeddings, norms, lm_head
+ out[name] = tensor
+ return out
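The column/row split convention above follows the standard Megatron argument: column-splitting an output projection lets each rank compute an independent slice of the output, while row-splitting an input projection produces partial sums that must be all-reduced. A pure-Python toy check of that identity (no llaisys APIs involved, toy sizes only):

```python
def matmul(a, b):
    # a: m x k, b: k x n -> m x n
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

x = [[1.0, 2.0, 3.0, 4.0]]                                    # 1 x 4 activation
w = [[float(i * 4 + j) for j in range(4)] for i in range(4)]  # 4 x 4 weight
world_size = 2
full = matmul(x, transpose(w))[0]                             # y = x @ w^T

# Column parallel (q/k/v/gate/up): each rank keeps a slice of output
# features (rows of w); per-rank outputs are concatenated, no reduction.
chunk = len(w) // world_size
col_parts = []
for rank in range(world_size):
    shard = w[rank * chunk:(rank + 1) * chunk]
    col_parts.extend(matmul(x, transpose(shard))[0])
assert col_parts == full

# Row parallel (o_proj/down_proj): each rank keeps a slice of input
# features (columns of w); partial outputs must be summed across ranks,
# which is the all-reduce in real tensor parallelism.
wt = transpose(w)
chunk_in = len(x[0]) // world_size
row_sum = [0.0] * len(full)
for rank in range(world_size):
    x_shard = [x[0][rank * chunk_in:(rank + 1) * chunk_in]]
    wt_shard = wt[rank * chunk_in:(rank + 1) * chunk_in]
    row_sum = [a + b for a, b in zip(row_sum, matmul(x_shard, wt_shard)[0])]
assert row_sum == full
```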
diff --git a/python/llaisys/tokenizer.py b/python/llaisys/tokenizer.py
new file mode 100644
index 000000000..12053bd12
--- /dev/null
+++ b/python/llaisys/tokenizer.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from ctypes import POINTER, c_char_p, c_int64, c_size_t, create_string_buffer
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from .libllaisys import LIB_LLAISYS, LlaisysTokenizer
+
+
+class Tokenizer:
+ def __init__(self, model_path: str):
+ self._backend: str = "sentencepiece"
+ self._tokenizer: Optional[LlaisysTokenizer] = None
+ self._hf_tokenizer = None
+
+ tokenizer_path = self._resolve_tokenizer_path(model_path)
+ if tokenizer_path.suffix.lower() == ".json":
+ self._backend = "hf"
+ self._hf_tokenizer = self._load_hf_tokenizer(tokenizer_path)
+ else:
+ self._tokenizer = LIB_LLAISYS.llaisysTokenizerCreateSentencePiece(
+ c_char_p(str(tokenizer_path).encode("utf-8"))
+ )
+ if not self._tokenizer:
+ raise RuntimeError("llaisysTokenizerCreateSentencePiece failed")
+
+ def encode(self, text: str) -> List[int]:
+ if self._backend == "hf":
+ return list(self._hf_tokenizer.encode(text).ids)
+ data = text.encode("utf-8")
+ n = int(
+ LIB_LLAISYS.llaisysTokenizerEncode(
+ self._tokenizer, c_char_p(data), None, c_size_t(0)
+ )
+ )
+ if n < 0:
+ raise RuntimeError("llaisysTokenizerEncode failed")
+ if n == 0:
+ return []
+ out_ids = (c_int64 * n)()
+ written = int(
+ LIB_LLAISYS.llaisysTokenizerEncode(
+ self._tokenizer, c_char_p(data), out_ids, c_size_t(n)
+ )
+ )
+ if written < 0:
+ raise RuntimeError("llaisysTokenizerEncode failed")
+ return [int(out_ids[i]) for i in range(written)]
+
+ def decode(self, ids: Iterable[int]) -> str:
+ ids_list = list(ids)
+ n = len(ids_list)
+ if n == 0:
+ return ""
+ if self._backend == "hf":
+ return self._hf_tokenizer.decode(ids_list, skip_special_tokens=False)
+ buf = (c_int64 * n)(*ids_list)
+ max_len = int(
+ LIB_LLAISYS.llaisysTokenizerDecode(
+ self._tokenizer, buf, c_size_t(n), None, c_size_t(0)
+ )
+ )
+ if max_len < 0:
+ raise RuntimeError("llaisysTokenizerDecode failed")
+ out = create_string_buffer(max_len)
+ written = int(
+ LIB_LLAISYS.llaisysTokenizerDecode(
+ self._tokenizer, buf, c_size_t(n), out, c_size_t(max_len)
+ )
+ )
+ if written < 0:
+ raise RuntimeError("llaisysTokenizerDecode failed")
+ return out.value.decode("utf-8")
+
+ def close(self) -> None:
+ if self._tokenizer:
+ LIB_LLAISYS.llaisysTokenizerDestroy(self._tokenizer)
+ self._tokenizer = None
+
+ def __del__(self) -> None:
+ self.close()
+
+ @staticmethod
+ def _resolve_tokenizer_path(model_path: str) -> Path:
+ path = Path(model_path)
+ if path.is_dir():
+ sp = path / "tokenizer.model"
+ if sp.exists():
+ return sp
+ hf = path / "tokenizer.json"
+ if hf.exists():
+ return hf
+ raise FileNotFoundError(
+ f"No tokenizer.model or tokenizer.json found under: {path}"
+ )
+ if not path.exists():
+ raise FileNotFoundError(f"Tokenizer file not found: {path}")
+ return path
+
+ @staticmethod
+ def _load_hf_tokenizer(path: Path):
+ try:
+ from tokenizers import Tokenizer as HFTokenizer
+ except Exception as exc:
+ raise RuntimeError(
+ "tokenizer.json requires the 'tokenizers' package. "
+ "Install with: pip install tokenizers"
+ ) from exc
+ return HFTokenizer.from_file(str(path))
+
+
+__all__ = ["Tokenizer"]
diff --git a/scripts/_tp_worker.py b/scripts/_tp_worker.py
new file mode 100644
index 000000000..40365f74b
--- /dev/null
+++ b/scripts/_tp_worker.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+"""TP worker process -- spawned by launch_tp.py.
+
+Reads env vars: RANK, WORLD_SIZE, CUDA_VISIBLE_DEVICES, TP_UID_FILE,
+TP_MODEL_PATH, TP_DEVICE, TP_PROMPT, TP_MAX_TOKENS.
+"""
+
+import os
+import sys
+import ctypes
+import time
+from pathlib import Path
+
+_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, _project_root)
+
+import numpy as np
+import safetensors
+
+from llaisys.libllaisys import (
+ LIB_LLAISYS,
+ LlaisysCommAPI,
+ llaisysComm_t,
+ LLAISYS_COMM_UNIQUE_ID_MAX_SIZE,
+ DeviceType,
+ DataType,
+ LlaisysQwen2Meta,
+ llaisysDeviceType_t,
+ llaisysDataType_t,
+ LlaisysSamplingParams,
+)
+from llaisys.tensor_parallel import shard_qwen2_weights
+
+
+def main():
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ uid_file = os.environ["TP_UID_FILE"]
+ model_path = Path(os.environ["TP_MODEL_PATH"])
+ device_name = os.environ.get("TP_DEVICE", "nvidia")
+ prompt = os.environ.get("TP_PROMPT", "Hello")
+ max_tokens = int(os.environ.get("TP_MAX_TOKENS", "64"))
+
+ device = DeviceType.NVIDIA if device_name == "nvidia" else DeviceType.ILUVATAR
+ backend = 0 # NCCL
+
+ # Read unique ID
+    # Wait (up to ~10s) for the launcher to write the unique ID
+    for _ in range(100):
+        if os.path.exists(uid_file) and os.path.getsize(uid_file) > 0:
+            break
+        time.sleep(0.1)
+    else:
+        raise RuntimeError(f"timed out waiting for unique ID file: {uid_file}")
+ with open(uid_file, "rb") as f:
+ uid_bytes = f.read()
+
+ # Init comm
+ api_ptr = LIB_LLAISYS.llaisysGetCommAPI(backend)
+ api = api_ptr.contents
+ comm = llaisysComm_t()
+ uid_buf = ctypes.create_string_buffer(uid_bytes, LLAISYS_COMM_UNIQUE_ID_MAX_SIZE)
+ ret = api.init(ctypes.byref(comm), rank, world_size, uid_buf)
+ if ret != 0:
+ raise RuntimeError(f"commInit failed: {ret}")
+
+ # Load tokenizer (use transformers for HF tokenizer.json)
+ from transformers import AutoTokenizer
+ tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)
+
+ # Tokenize prompt
+ import json
+ config_path = model_path / "config.json"
+ with open(config_path, "r", encoding="utf-8") as f:
+ cfg = json.load(f)
+
+ # Load and shard weights
+    import torch
+
+    weights = {}
+    for file in sorted(model_path.glob("*.safetensors")):
+        data_ = safetensors.safe_open(file, framework="pt", device="cpu")
+ for name in data_.keys():
+ arr = data_.get_tensor(name)
+ if arr.dtype == torch.bfloat16:
+ arr = arr.to(torch.float16)
+ weights[name] = arr.cpu().numpy()
+
+ weights = shard_qwen2_weights(weights, rank, world_size)
+
+ # Build model meta
+    dtype = DataType.F16  # bf16 weights were converted to f16 above
+ nlayer = int(cfg.get("num_hidden_layers", 0))
+ hs = int(cfg.get("hidden_size", 0))
+ nh = int(cfg.get("num_attention_heads", 0))
+ nkvh = int(cfg.get("num_key_value_heads", nh))
+ di = int(cfg.get("intermediate_size", 0))
+ maxseq = int(cfg.get("max_position_embeddings", 0))
+ voc = int(cfg.get("vocab_size", 0))
+ epsilon = float(cfg.get("rms_norm_eps", 1e-6))
+ theta = float(cfg.get("rope_theta", 10000.0))
+ eos = cfg.get("eos_token_id", -1)
+ end_token = int(eos[0]) if isinstance(eos, list) else int(eos)
+ dh = int(cfg.get("head_dim", hs // nh if nh else 0))
+
+ # Adjust nh/nkvh for TP
+ tp_nh = nh // world_size
+ tp_nkvh = nkvh // world_size
+
+ model_meta = LlaisysQwen2Meta(
+ llaisysDataType_t(dtype),
+ ctypes.c_size_t(nlayer),
+ ctypes.c_size_t(hs),
+ ctypes.c_size_t(tp_nh),
+ ctypes.c_size_t(tp_nkvh),
+ ctypes.c_size_t(dh),
+ ctypes.c_size_t(di // world_size),
+ ctypes.c_size_t(maxseq),
+ ctypes.c_size_t(voc),
+ ctypes.c_float(epsilon),
+ ctypes.c_float(theta),
+ ctypes.c_int64(end_token),
+ )
+
+ device_ids = (ctypes.c_int * 1)(0)
+ model = LIB_LLAISYS.llaisysQwen2ModelCreate(
+ ctypes.byref(model_meta), llaisysDeviceType_t(device), device_ids, 1
+ )
+ if not model:
+ raise RuntimeError("llaisysQwen2ModelCreate failed")
+
+ LIB_LLAISYS.llaisysQwen2ModelSetKVCacheEnabled(model, ctypes.c_int(1))
+ LIB_LLAISYS.llaisysQwen2ModelSetTensorParallel(model, comm, ctypes.c_void_p(None), world_size)
+ model_weights = LIB_LLAISYS.llaisysQwen2ModelWeights(model)
+
+ # Upload sharded weights
+ def upload_tensor(arr):
+ arr = np.ascontiguousarray(arr)
+ shape = (ctypes.c_size_t * arr.ndim)(*arr.shape)
+ dt = DataType.F16 if "float16" in arr.dtype.name else DataType.F32
+ tensor = LIB_LLAISYS.tensorCreate(
+ shape, ctypes.c_size_t(arr.ndim),
+ llaisysDataType_t(dt), llaisysDeviceType_t(device), ctypes.c_int(0),
+ )
+ LIB_LLAISYS.tensorLoad(tensor, ctypes.c_void_p(arr.ctypes.data))
+ return tensor
+
+ w = model_weights.contents
+ for name, arr in weights.items():
+ tensor = upload_tensor(arr)
+ if name in {"model.embed_tokens.weight", "transformer.wte.weight"}:
+ w.in_embed = tensor
+ elif name in {"lm_head.weight", "model.lm_head.weight"}:
+ w.out_embed = tensor
+ elif name in {"model.norm.weight", "transformer.ln_f.weight"}:
+ w.out_norm_w = tensor
+ elif name.startswith("model.layers."):
+ parts = name.split(".")
+ if len(parts) < 4:
+ continue
+ layer = int(parts[2])
+ sub = ".".join(parts[3:])
+ if sub == "input_layernorm.weight":
+ w.attn_norm_w[layer] = tensor
+ elif sub == "self_attn.q_proj.weight":
+ w.attn_q_w[layer] = tensor
+ elif sub == "self_attn.q_proj.bias":
+ w.attn_q_b[layer] = tensor
+ elif sub == "self_attn.k_proj.weight":
+ w.attn_k_w[layer] = tensor
+ elif sub == "self_attn.k_proj.bias":
+ w.attn_k_b[layer] = tensor
+ elif sub == "self_attn.v_proj.weight":
+ w.attn_v_w[layer] = tensor
+ elif sub == "self_attn.v_proj.bias":
+ w.attn_v_b[layer] = tensor
+ elif sub == "self_attn.o_proj.weight":
+ w.attn_o_w[layer] = tensor
+ elif sub == "post_attention_layernorm.weight":
+ w.mlp_norm_w[layer] = tensor
+ elif sub == "mlp.gate_proj.weight":
+ w.mlp_gate_w[layer] = tensor
+ elif sub == "mlp.up_proj.weight":
+ w.mlp_up_w[layer] = tensor
+ elif sub == "mlp.down_proj.weight":
+ w.mlp_down_w[layer] = tensor
+
+ if not w.out_embed and w.in_embed:
+ w.out_embed = w.in_embed
+
+ # Tokenize and run inference
+ input_ids = tok.encode(prompt, add_special_tokens=True)
+
+ # Prefill + decode
+ token_buf = (ctypes.c_int64 * len(input_ids))(*input_ids)
+ params = LlaisysSamplingParams(ctypes.c_int(1), ctypes.c_float(0.0), ctypes.c_float(0.0), ctypes.c_uint32(0))
+ next_token = int(LIB_LLAISYS.llaisysQwen2ModelPrefillSampling(
+ model, token_buf, ctypes.c_size_t(len(input_ids)), ctypes.byref(params),
+ ))
+
+ generated = list(input_ids)
+ for step in range(max_tokens):
+ if next_token < 0 or next_token == end_token:
+ break
+ generated.append(next_token)
+ tb = (ctypes.c_int64 * 1)(next_token)
+ next_token = int(LIB_LLAISYS.llaisysQwen2ModelStepSampling(
+ model, tb, ctypes.c_size_t(1), ctypes.byref(params),
+ ))
+
+ # Decode and print from rank 0
+ if rank == 0:
+ output_text = tok.decode(generated, skip_special_tokens=True)
+ print(output_text)
+
+ # Cleanup
+ LIB_LLAISYS.llaisysQwen2ModelDestroy(model)
+ api.destroy(comm)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/benchmark_chat_scheduler.py b/scripts/benchmark_chat_scheduler.py
new file mode 100644
index 000000000..d9d8af43e
--- /dev/null
+++ b/scripts/benchmark_chat_scheduler.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import json
+import math
+import statistics
+import time
+import urllib.error
+import urllib.request
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+
+def _post_json(url: str, payload: Dict[str, Any], timeout: float) -> Tuple[int, Dict[str, Any], str]:
+ body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
+ req = urllib.request.Request(
+ url,
+ data=body,
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+ try:
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
+ text = resp.read().decode("utf-8", errors="replace")
+ code = int(resp.status)
+ data = json.loads(text) if text else {}
+ return code, data, ""
+ except urllib.error.HTTPError as exc:
+ text = exc.read().decode("utf-8", errors="replace")
+ data = {}
+ try:
+ data = json.loads(text) if text else {}
+ except Exception:
+ pass
+ return int(exc.code), data, text or str(exc)
+ except Exception as exc:
+ return -1, {}, str(exc)
+
+
+def _get_json(url: str, timeout: float) -> Dict[str, Any]:
+ req = urllib.request.Request(url, method="GET")
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
+ text = resp.read().decode("utf-8", errors="replace")
+ return json.loads(text) if text else {}
+
+
+def _percentile(sorted_values: List[float], p: float) -> float:
+ if not sorted_values:
+ return 0.0
+ if len(sorted_values) == 1:
+ return float(sorted_values[0])
+ rank = p * (len(sorted_values) - 1)
+ low = int(math.floor(rank))
+ high = int(math.ceil(rank))
+ if low == high:
+ return float(sorted_values[low])
+ w = rank - low
+ return float(sorted_values[low] * (1 - w) + sorted_values[high] * w)
+
+
+def run_benchmark(args: argparse.Namespace) -> int:
+ endpoint = args.endpoint.rstrip("/")
+ chat_url = f"{endpoint}/chat"
+ scheduler_url = f"{endpoint}/debug/scheduler"
+ health_url = f"{endpoint}/health"
+
+ try:
+ health = _get_json(health_url, timeout=args.timeout)
+ except Exception as exc:
+ print(f"[ERROR] health check failed: {exc}")
+ return 2
+
+ print(f"[INFO] health: {health}")
+ before_debug: Dict[str, Any] = {}
+ try:
+ before_debug = _get_json(scheduler_url, timeout=args.timeout)
+ except Exception:
+ before_debug = {}
+
+ if args.warmup > 0:
+ print(f"[INFO] warmup requests: {args.warmup}")
+ for i in range(args.warmup):
+ payload: Dict[str, Any] = {
+ "prompt": f"{args.prompt} [warmup-{i}]",
+ "stream": False,
+ "max_new_tokens": args.max_new_tokens,
+ }
+ _post_json(chat_url, payload, timeout=args.timeout)
+
+ total = int(args.total_requests)
+ concurrency = int(args.concurrency)
+ print(f"[INFO] start benchmark: total={total}, concurrency={concurrency}, endpoint={chat_url}")
+
+ t0 = time.perf_counter()
+ latencies_ms: List[float] = []
+ errors: List[str] = []
+ status_count: Dict[int, int] = {}
+
+ def _one_request(i: int) -> Tuple[float, int, str]:
+ payload: Dict[str, Any] = {
+ "prompt": f"{args.prompt} [req-{i}]",
+ "stream": False,
+ "max_new_tokens": args.max_new_tokens,
+ }
+ if args.session_mode == "shared":
+ payload["session_id"] = args.shared_session_id
+ elif args.session_mode == "unique":
+ payload["session_id"] = f"{args.session_prefix}-{uuid.uuid4()}"
+ if args.sampling:
+ payload["sampling"] = args.sampling
+ if args.temperature is not None:
+ payload["temperature"] = args.temperature
+ if args.top_k is not None:
+ payload["top_k"] = args.top_k
+ if args.top_p is not None:
+ payload["top_p"] = args.top_p
+
+ s = time.perf_counter()
+ code, data, err = _post_json(chat_url, payload, timeout=args.timeout)
+ elapsed_ms = (time.perf_counter() - s) * 1000.0
+ if code == 200 and not data.get("error"):
+ return elapsed_ms, code, ""
+ detail = err or str(data.get("error") or f"HTTP {code}")
+ return elapsed_ms, code, detail
+
+ with ThreadPoolExecutor(max_workers=concurrency) as ex:
+ futures = [ex.submit(_one_request, i) for i in range(total)]
+ for fut in as_completed(futures):
+ elapsed_ms, code, detail = fut.result()
+ status_count[code] = status_count.get(code, 0) + 1
+ if code == 200 and not detail:
+ latencies_ms.append(elapsed_ms)
+ else:
+ errors.append(f"[{code}] {detail}")
+
+ total_elapsed_s = max(1e-9, time.perf_counter() - t0)
+ success = len(latencies_ms)
+ failed = len(errors)
+ throughput = total / total_elapsed_s
+
+ latencies_sorted = sorted(latencies_ms)
+ p50 = _percentile(latencies_sorted, 0.50)
+ p95 = _percentile(latencies_sorted, 0.95)
+ p99 = _percentile(latencies_sorted, 0.99)
+ avg = statistics.mean(latencies_ms) if latencies_ms else 0.0
+
+ after_debug: Dict[str, Any] = {}
+ try:
+ after_debug = _get_json(scheduler_url, timeout=args.timeout)
+ except Exception:
+ after_debug = {}
+
+ print("\n=== Benchmark Summary ===")
+ print(f"success: {success}/{total} ({(success / total) * 100:.1f}%)")
+ print(f"failed: {failed}")
+ print(f"elapsed_s: {total_elapsed_s:.3f}")
+ print(f"throughput_rps: {throughput:.2f}")
+ print(f"latency_ms: avg={avg:.1f}, p50={p50:.1f}, p95={p95:.1f}, p99={p99:.1f}")
+ print(f"status_count: {status_count}")
+
+ if before_debug:
+ print("\n=== /debug/scheduler (before) ===")
+ print(json.dumps(before_debug, ensure_ascii=False, indent=2))
+ if after_debug:
+ print("\n=== /debug/scheduler (after) ===")
+ print(json.dumps(after_debug, ensure_ascii=False, indent=2))
+
+ if errors:
+ print("\n=== Sample Errors (up to 10) ===")
+ for line in errors[:10]:
+ print(line)
+ return 0 if success > 0 else 1
+
+
+def build_parser() -> argparse.ArgumentParser:
+ p = argparse.ArgumentParser(description="LLAISYS scheduler benchmark (non-stream chat requests)")
+ p.add_argument("--endpoint", default="http://127.0.0.1:8000", type=str)
+ p.add_argument("--total-requests", default=20, type=int)
+ p.add_argument("--concurrency", default=5, type=int)
+ p.add_argument("--prompt", default="请用一句话介绍北京", type=str)
+ p.add_argument("--max-new-tokens", default=32, type=int)
+ p.add_argument("--timeout", default=60.0, type=float, help="per-request timeout in seconds")
+ p.add_argument("--warmup", default=1, type=int)
+ p.add_argument("--session-mode", choices=["none", "shared", "unique"], default="none")
+ p.add_argument("--shared-session-id", default="bench-shared-session", type=str)
+ p.add_argument("--session-prefix", default="bench-session", type=str)
+ p.add_argument("--sampling", default="", type=str)
+ p.add_argument("--temperature", default=None, type=float)
+ p.add_argument("--top-k", default=None, type=int)
+ p.add_argument("--top-p", default=None, type=float)
+ return p
+
+
+def main() -> None:
+ parser = build_parser()
+ args = parser.parse_args()
+ raise SystemExit(run_benchmark(args))
+
+
+if __name__ == "__main__":
+ main()
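`_percentile` above uses linear interpolation between order statistics (the same scheme as NumPy's default `linear` method). Reimplemented standalone so the arithmetic can be spot-checked by hand:

```python
import math

def percentile(sorted_values, p):
    # Linear interpolation between the two nearest order statistics.
    if not sorted_values:
        return 0.0
    rank = p * (len(sorted_values) - 1)
    low, high = int(math.floor(rank)), int(math.ceil(rank))
    if low == high:
        return float(sorted_values[low])
    w = rank - low
    return sorted_values[low] * (1 - w) + sorted_values[high] * w

vals = [10.0, 20.0, 30.0, 40.0]
assert percentile(vals, 0.50) == 25.0  # rank 1.5 -> midway between 20 and 30
assert percentile(vals, 0.0) == 10.0
assert percentile(vals, 1.0) == 40.0
```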
diff --git a/scripts/launch_tp.py b/scripts/launch_tp.py
new file mode 100644
index 000000000..a7ccf2788
--- /dev/null
+++ b/scripts/launch_tp.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+"""Tensor-parallel multi-process launcher for llaisys inference.
+
+Rank 0 generates a NCCL unique ID, writes it to a temp file, then spawns
+N subprocesses (one per GPU). Each subprocess loads sharded weights, inits
+the communicator with the shared unique ID, and runs inference. Output is
+printed from rank 0.
+
+Usage:
+ python scripts/launch_tp.py --model /path/to/qwen2 --nranks 2 --prompt "Hello"
+"""
+
+import argparse
+import os
+import sys
+import subprocess
+import tempfile
+import ctypes
+
+# Ensure project root is on sys.path
+_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, _project_root)
+
+
+def generate_unique_id(backend=0):
+ """Generate NCCL unique ID via llaisysCommGenerateUniqueId."""
+ from llaisys.libllaisys import LIB_LLAISYS, LLAISYS_COMM_UNIQUE_ID_MAX_SIZE
+ id_buf = ctypes.create_string_buffer(LLAISYS_COMM_UNIQUE_ID_MAX_SIZE)
+ id_size = ctypes.c_size_t(0)
+ ret = LIB_LLAISYS.llaisysCommGenerateUniqueId(
+ backend, id_buf, ctypes.byref(id_size)
+ )
+ if ret != 0:
+ raise RuntimeError(f"llaisysCommGenerateUniqueId failed: {ret}")
+ return id_buf.raw[: id_size.value]
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Tensor-parallel launcher")
+ parser.add_argument("--model", required=True, help="Path to model directory")
+ parser.add_argument("--nranks", type=int, default=2, help="Number of TP ranks")
+ parser.add_argument("--device", default="nvidia", choices=["nvidia", "iluvatar"])
+ parser.add_argument("--prompt", default="Hello", help="Input prompt")
+ parser.add_argument("--max-tokens", type=int, default=64)
+ args = parser.parse_args()
+
+ # Generate unique ID on rank 0 process
+ uid_bytes = generate_unique_id()
+
+ # Write unique ID to temp file for subprocesses
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".uid")
+ tmp.write(uid_bytes)
+ tmp.close()
+ uid_path = tmp.name
+
+ worker = os.path.join(os.path.dirname(__file__), "_tp_worker.py")
+
+ procs = []
+ for rank in range(args.nranks):
+ env = os.environ.copy()
+ env["RANK"] = str(rank)
+ env["WORLD_SIZE"] = str(args.nranks)
+ env["CUDA_VISIBLE_DEVICES"] = str(rank)
+ env["TP_UID_FILE"] = uid_path
+ env["TP_MODEL_PATH"] = args.model
+ env["TP_DEVICE"] = args.device
+ env["TP_PROMPT"] = args.prompt
+ env["TP_MAX_TOKENS"] = str(args.max_tokens)
+
+ proc = subprocess.Popen(
+ [sys.executable, worker],
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ procs.append((rank, proc))
+
+ # Wait for all and collect output
+ for rank, proc in procs:
+ stdout, stderr = proc.communicate()
+ if proc.returncode != 0:
+ print(f"[rank {rank}] FAILED (exit {proc.returncode})", file=sys.stderr)
+ if stderr:
+ print(stderr.decode(errors="replace"), file=sys.stderr)
+ elif rank == 0:
+ print(stdout.decode(errors="replace"), end="")
+
+ # Cleanup
+ try:
+ os.unlink(uid_path)
+ except OSError:
+ pass
+
+
+if __name__ == "__main__":
+ main()
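The unique-ID handoff between `launch_tp.py` and `_tp_worker.py` is a simple write-then-poll file handshake. A minimal stand-in for that protocol (random bytes in place of a real NCCL ID):

```python
import os
import tempfile
import time

uid_bytes = os.urandom(16)  # stand-in for the NCCL unique ID

# Launcher side: write the opaque ID bytes to a temp file.
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".uid")
tmp.write(uid_bytes)
tmp.close()

# Worker side: poll until the file exists and is non-empty, then read it.
for _ in range(100):
    if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
        break
    time.sleep(0.1)
with open(tmp.name, "rb") as f:
    got = f.read()

os.unlink(tmp.name)  # launcher-side cleanup
assert got == uid_bytes
```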
diff --git a/scripts/run_gpu.ps1 b/scripts/run_gpu.ps1
new file mode 100644
index 000000000..150f84486
--- /dev/null
+++ b/scripts/run_gpu.ps1
@@ -0,0 +1,130 @@
+param(
+ [ValidateSet("build", "test", "server", "all")]
+ [string]$Mode = "all",
+ [string]$Model = "",
+ [string]$Device = "nvidia",
+ [string]$CondaEnv = "llaisys-gpu",
+ [string]$ConfigPath = "",
+ [switch]$SkipTests,
+ [switch]$ActivateConda
+)
+
+$ErrorActionPreference = "Stop"
+
+function Write-Step([string]$Message) {
+ Write-Host "==> $Message"
+}
+
+$PythonExe = "python"
+function Resolve-PythonExe {
+ if ($env:CONDA_PREFIX) {
+ $candidate = Join-Path $env:CONDA_PREFIX "python.exe"
+ if (Test-Path $candidate) {
+ return $candidate
+ }
+ }
+ return "python"
+}
+
+$RepoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
+Set-Location $RepoRoot
+
+if ([string]::IsNullOrWhiteSpace($ConfigPath)) {
+ $ConfigPath = Join-Path $RepoRoot "scripts\run_gpu.config.json"
+}
+
+$Config = $null
+if (Test-Path $ConfigPath) {
+ try {
+ $Config = Get-Content $ConfigPath -Raw | ConvertFrom-Json
+ } catch {
+ throw "Failed to read config file: $ConfigPath"
+ }
+}
+
+if ($Config -ne $null) {
+ if (-not $PSBoundParameters.ContainsKey("Model") -and $Config.model) {
+ $Model = $Config.model
+ }
+ if (-not $PSBoundParameters.ContainsKey("Device") -and $Config.device) {
+ $Device = $Config.device
+ }
+ if (-not $PSBoundParameters.ContainsKey("CondaEnv") -and $Config.conda_env) {
+ $CondaEnv = $Config.conda_env
+ }
+}
+
+if ($ActivateConda) {
+ if (Get-Command conda -ErrorAction SilentlyContinue) {
+ Write-Step "Activating conda env: $CondaEnv"
+ conda activate $CondaEnv
+ } else {
+ throw "conda is not available in this shell. Run 'conda init powershell' and reopen PowerShell."
+ }
+}
+$PythonExe = Resolve-PythonExe
+
+function Build-Gpu {
+ Write-Step "Configuring xmake"
+ xmake f -m release --nv-gpu=y --vs=2022
+
+ Write-Step "Building"
+ xmake
+
+ $dllSrc = Join-Path $RepoRoot "build\windows\x64\release\llaisys.dll"
+ $dllDst = Join-Path $RepoRoot "python\llaisys\libllaisys\llaisys.dll"
+ if (!(Test-Path $dllSrc)) {
+ throw "Build output not found: $dllSrc"
+ }
+
+ Write-Step "Copying DLL to python package"
+ Copy-Item $dllSrc $dllDst -Force
+}
+
+function Ensure-Dll {
+ $dllDst = Join-Path $RepoRoot "python\llaisys\libllaisys\llaisys.dll"
+ if (Test-Path $dllDst) {
+ return
+ }
+ $dllCandidates = @(
+ (Join-Path $RepoRoot "bin\llaisys.dll"),
+ (Join-Path $RepoRoot "build\windows\x64\release\llaisys.dll")
+ )
+ foreach ($dllSrc in $dllCandidates) {
+ if (Test-Path $dllSrc) {
+ Write-Step "Copying DLL to python package"
+ Copy-Item $dllSrc $dllDst -Force
+ return
+ }
+ }
+ throw "Missing llaisys.dll. Run '-Mode build' or copy it to: $dllDst"
+}
+
+function Test-Gpu {
+ Ensure-Dll
+ Write-Step "Running GPU op tests"
+ & $PythonExe test/ops_gpu/run_all.py
+}
+
+function Run-Server {
+ if ([string]::IsNullOrWhiteSpace($Model)) {
+ throw "Model path is required. Provide -Model or set 'model' in $ConfigPath"
+ }
+ Ensure-Dll
+ Write-Step "Starting server on $Device"
+ & $PythonExe -m llaisys.server --model $Model --device $Device
+}
+
+switch ($Mode) {
+ "build" { Build-Gpu }
+ "test" { Test-Gpu }
+ "server" { Run-Server }
+ "all" {
+ Build-Gpu
+ if (-not $SkipTests) {
+ Test-Gpu
+ }
+ Run-Server
+ }
+}
+
diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp
index 44894b9e7..cbcf1dc6b 100644
--- a/src/core/context/context.cpp
+++ b/src/core/context/context.cpp
@@ -3,7 +3,8 @@
#include <vector>
namespace llaisys::core {
-
+
+// Constructor: initialize the runtimes.
Context::Context() {
// All device types, put CPU at the end
    std::vector<llaisysDeviceType_t> device_typs;
@@ -31,6 +32,7 @@ Context::Context() {
}
}
+// Destroy the context and the runtimes it owns.
Context::~Context() {
// Destroy current runtime first.
delete _current_runtime;
@@ -49,6 +51,7 @@ Context::~Context() {
_runtime_map.clear();
}
+// Set the current device.
void Context::setDevice(llaisysDeviceType_t device_type, int device_id) {
// If doest not match the current runtime.
if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) {
@@ -65,6 +68,7 @@ void Context::setDevice(llaisysDeviceType_t device_type, int device_id) {
}
}
+// Get the current runtime.
Runtime &Context::runtime() {
ASSERT(_current_runtime != nullptr, "No runtime is activated, please call setDevice() first.");
return *_current_runtime;
diff --git a/src/core/context/context.hpp b/src/core/context/context.hpp
index a3ebcdecf..bd9707263 100644
--- a/src/core/context/context.hpp
+++ b/src/core/context/context.hpp
@@ -27,6 +27,7 @@ class Context {
Context(Context &&) = delete;
Context &operator=(Context &&) = delete;
+    // Set the current device.
void setDevice(llaisysDeviceType_t device_type, int device_id);
Runtime &runtime();
diff --git a/src/device/comm_api.cpp b/src/device/comm_api.cpp
new file mode 100644
index 000000000..d70b4b0bf
--- /dev/null
+++ b/src/device/comm_api.cpp
@@ -0,0 +1,89 @@
+#include "comm_api.hpp"
+
+namespace llaisys::device {
+
+int commInit(llaisysComm_t *, int, int, const void *) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ return -1;
+}
+
+void commDestroy(llaisysComm_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+}
+
+int commGetRank(llaisysComm_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ return -1;
+}
+
+int commGetSize(llaisysComm_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ return -1;
+}
+
+void commAllreduce(const void *, void *, size_t, llaisysDataType_t, llaisysReduceOp_t, llaisysComm_t, llaisysStream_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+}
+
+void commBroadcast(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+}
+
+void commSend(const void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+}
+
+void commRecv(void *, size_t, llaisysDataType_t, int, llaisysComm_t, llaisysStream_t) {
+ EXCEPTION_UNSUPPORTED_DEVICE;
+}
+
+static const LlaisysCommAPI NOOP_COMM_API = {
+ &commInit,
+ &commDestroy,
+ &commGetRank,
+ &commGetSize,
+ &commAllreduce,
+ &commBroadcast,
+ &commSend,
+ &commRecv};
+
+const LlaisysCommAPI *getUnsupportedCommAPI() {
+ return &NOOP_COMM_API;
+}
+
+const LlaisysCommAPI *getCommAPI(llaisysCommBackend_t backend) {
+ switch (backend) {
+ case LLAISYS_COMM_NCCL:
+#ifdef ENABLE_NVIDIA_API
+ return llaisys::device::nccl::getCommAPI();
+#else
+ return getUnsupportedCommAPI();
+#endif
+ case LLAISYS_COMM_IXCCL:
+#ifdef ENABLE_ILUVATAR_API
+ return llaisys::device::ixccl::getCommAPI();
+#else
+ return getUnsupportedCommAPI();
+#endif
+ case LLAISYS_COMM_MPI:
+ return getUnsupportedCommAPI();
+ default:
+ return getUnsupportedCommAPI();
+ }
+}
+
+int commGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size) {
+ switch (backend) {
+ case LLAISYS_COMM_NCCL:
+#ifdef ENABLE_NVIDIA_API
+ return llaisys::device::nccl::commGenerateUniqueId(id_out, id_size);
+#else
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ return -1;
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ return -1;
+ }
+}
+} // namespace llaisys::device
diff --git a/src/device/comm_api.hpp b/src/device/comm_api.hpp
new file mode 100644
index 000000000..c0290e7d4
--- /dev/null
+++ b/src/device/comm_api.hpp
@@ -0,0 +1,25 @@
+#pragma once
+#include "llaisys/comm.h"
+
+#include "../utils.hpp"
+
+namespace llaisys::device {
+const LlaisysCommAPI *getCommAPI(llaisysCommBackend_t backend);
+int commGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size);
+
+const LlaisysCommAPI *getUnsupportedCommAPI();
+
+#ifdef ENABLE_NVIDIA_API
+namespace nccl {
+const LlaisysCommAPI *getCommAPI();
+int commGenerateUniqueId(void *id_out, size_t *id_size);
+}
+#endif
+
+#ifdef ENABLE_ILUVATAR_API
+namespace ixccl {
+const LlaisysCommAPI *getCommAPI();
+}
+#endif
+
+} // namespace llaisys::device
diff --git a/src/device/iluvatar/devlink_stub.cu b/src/device/iluvatar/devlink_stub.cu
new file mode 100644
index 000000000..b64d3641b
--- /dev/null
+++ b/src/device/iluvatar/devlink_stub.cu
@@ -0,0 +1,3 @@
+#include <cuda_runtime.h>
+
+__global__ void llaisys_devlink_stub() {}
diff --git a/src/device/iluvatar/iluvatar_resource.cu b/src/device/iluvatar/iluvatar_resource.cu
new file mode 100644
index 000000000..67850c218
--- /dev/null
+++ b/src/device/iluvatar/iluvatar_resource.cu
@@ -0,0 +1,7 @@
+#include "iluvatar_resource.cuh"
+
+namespace llaisys::device::iluvatar {
+
+Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_ILUVATAR, device_id) {}
+
+} // namespace llaisys::device::iluvatar
diff --git a/src/device/iluvatar/iluvatar_resource.cuh b/src/device/iluvatar/iluvatar_resource.cuh
new file mode 100644
index 000000000..d3e637c39
--- /dev/null
+++ b/src/device/iluvatar/iluvatar_resource.cuh
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "../device_resource.hpp"
+
+namespace llaisys::device::iluvatar {
+class Resource : public llaisys::device::DeviceResource {
+public:
+ Resource(int device_id);
+ ~Resource();
+};
+} // namespace llaisys::device::iluvatar
diff --git a/src/device/iluvatar/iluvatar_runtime_api.cu b/src/device/iluvatar/iluvatar_runtime_api.cu
new file mode 100644
index 000000000..24445f19d
--- /dev/null
+++ b/src/device/iluvatar/iluvatar_runtime_api.cu
@@ -0,0 +1,119 @@
+#include "../runtime_api.hpp"
+#include "iluvatar_utils.hpp"
+
+#include <cuda_runtime.h>
+
+namespace llaisys::device::iluvatar {
+
+namespace runtime_api {
+int getDeviceCount() {
+ int count = 0;
+ cuda_check(cudaGetDeviceCount(&count));
+ return count;
+}
+
+void setDevice(int device_id) {
+ cuda_check(cudaSetDevice(device_id));
+}
+
+void deviceSynchronize() {
+ cuda_check(cudaDeviceSynchronize());
+}
+
+llaisysStream_t createStream() {
+ cudaStream_t stream{};
+ cuda_check(cudaStreamCreate(&stream));
+ return reinterpret_cast<llaisysStream_t>(stream);
+}
+
+void destroyStream(llaisysStream_t stream) {
+ cuda_check(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(stream)));
+}
+void streamSynchronize(llaisysStream_t stream) {
+ cuda_check(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
+}
+
+void *mallocDevice(size_t size) {
+ void *ptr = nullptr;
+ cuda_check(cudaMalloc(&ptr, size));
+ return ptr;
+}
+
+void freeDevice(void *ptr) {
+ cuda_check(cudaFree(ptr));
+}
+
+void *mallocHost(size_t size) {
+ void *ptr = nullptr;
+ cuda_check(cudaMallocHost(&ptr, size));
+ return ptr;
+}
+
+void freeHost(void *ptr) {
+ cuda_check(cudaFreeHost(ptr));
+}
+
+void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
+ cudaMemcpyKind cuda_kind = cudaMemcpyDefault;
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ cuda_kind = cudaMemcpyHostToHost;
+ break;
+ case LLAISYS_MEMCPY_H2D:
+ cuda_kind = cudaMemcpyHostToDevice;
+ break;
+ case LLAISYS_MEMCPY_D2H:
+ cuda_kind = cudaMemcpyDeviceToHost;
+ break;
+ case LLAISYS_MEMCPY_D2D:
+ cuda_kind = cudaMemcpyDeviceToDevice;
+ break;
+ default:
+ cuda_kind = cudaMemcpyDefault;
+ break;
+ }
+ cuda_check(cudaMemcpy(dst, src, size, cuda_kind));
+}
+
+void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) {
+ cudaMemcpyKind cuda_kind = cudaMemcpyDefault;
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ cuda_kind = cudaMemcpyHostToHost;
+ break;
+ case LLAISYS_MEMCPY_H2D:
+ cuda_kind = cudaMemcpyHostToDevice;
+ break;
+ case LLAISYS_MEMCPY_D2H:
+ cuda_kind = cudaMemcpyDeviceToHost;
+ break;
+ case LLAISYS_MEMCPY_D2D:
+ cuda_kind = cudaMemcpyDeviceToDevice;
+ break;
+ default:
+ cuda_kind = cudaMemcpyDefault;
+ break;
+ }
+ cuda_check(cudaMemcpyAsync(dst, src, size, cuda_kind, reinterpret_cast<cudaStream_t>(stream)));
+}
+
+static const LlaisysRuntimeAPI RUNTIME_API = {
+ &getDeviceCount,
+ &setDevice,
+ &deviceSynchronize,
+ &createStream,
+ &destroyStream,
+ &streamSynchronize,
+ &mallocDevice,
+ &freeDevice,
+ &mallocHost,
+ &freeHost,
+ &memcpySync,
+ &memcpyAsync};
+
+} // namespace runtime_api
+
+const LlaisysRuntimeAPI *getRuntimeAPI() {
+ return &runtime_api::RUNTIME_API;
+}
+} // namespace llaisys::device::iluvatar
diff --git a/src/device/iluvatar/iluvatar_utils.hpp b/src/device/iluvatar/iluvatar_utils.hpp
new file mode 100644
index 000000000..5254b3c11
--- /dev/null
+++ b/src/device/iluvatar/iluvatar_utils.hpp
@@ -0,0 +1,54 @@
+#pragma once
+
+#include "../../utils/types.hpp"
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <stdexcept>
+
+namespace llaisys::device::iluvatar {
+inline void cuda_check(cudaError_t err) {
+ if (err == cudaSuccess) {
+ return;
+ }
+ if (err == cudaErrorCudartUnloading || err == cudaErrorContextIsDestroyed) {
+ return;
+ }
+ throw std::runtime_error(cudaGetErrorString(err));
+}
+
+template <typename T>
+struct ScalarOps;
+
+template <>
+struct ScalarOps<float> {
+ __device__ static inline float load(const float *ptr) {
+ return *ptr;
+ }
+ __device__ static inline void store(float *ptr, float v) {
+ *ptr = v;
+ }
+};
+
+template <>
+struct ScalarOps<llaisys::fp16_t> {
+ __device__ static inline float load(const llaisys::fp16_t *ptr) {
+ return __half2float(*reinterpret_cast<const __half *>(ptr));
+ }
+ __device__ static inline void store(llaisys::fp16_t *ptr, float v) {
+ *reinterpret_cast<__half *>(ptr) = __float2half(v);
+ }
+};
+
+template <>
+struct ScalarOps<llaisys::bf16_t> {
+ __device__ static inline float load(const llaisys::bf16_t *ptr) {
+ return __bfloat162float(*reinterpret_cast<const __nv_bfloat16 *>(ptr));
+ }
+ __device__ static inline void store(llaisys::bf16_t *ptr, float v) {
+ *reinterpret_cast<__nv_bfloat16 *>(ptr) = __float2bfloat16(v);
+ }
+};
+} // namespace llaisys::device::iluvatar
diff --git a/src/device/nvidia/cuda_utils.hpp b/src/device/nvidia/cuda_utils.hpp
new file mode 100644
index 000000000..20522193f
--- /dev/null
+++ b/src/device/nvidia/cuda_utils.hpp
@@ -0,0 +1,55 @@
+#pragma once
+
+#include "../../utils/types.hpp"
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <stdexcept>
+
+namespace llaisys::device::nvidia {
+inline void cuda_check(cudaError_t err) {
+ if (err == cudaSuccess) {
+ return;
+ }
+ // During process shutdown, CUDA may report unloading/destroyed context.
+ if (err == cudaErrorCudartUnloading || err == cudaErrorContextIsDestroyed) {
+ return;
+ }
+ throw std::runtime_error(cudaGetErrorString(err));
+}
+
+template <typename T>
+struct ScalarOps;
+
+template <>
+struct ScalarOps<float> {
+ __device__ static inline float load(const float *ptr) {
+ return *ptr;
+ }
+ __device__ static inline void store(float *ptr, float v) {
+ *ptr = v;
+ }
+};
+
+template <>
+struct ScalarOps<llaisys::fp16_t> {
+ __device__ static inline float load(const llaisys::fp16_t *ptr) {
+ return __half2float(*reinterpret_cast<const __half *>(ptr));
+ }
+ __device__ static inline void store(llaisys::fp16_t *ptr, float v) {
+ *reinterpret_cast<__half *>(ptr) = __float2half(v);
+ }
+};
+
+template <>
+struct ScalarOps<llaisys::bf16_t> {
+ __device__ static inline float load(const llaisys::bf16_t *ptr) {
+ return __bfloat162float(*reinterpret_cast<const __nv_bfloat16 *>(ptr));
+ }
+ __device__ static inline void store(llaisys::bf16_t *ptr, float v) {
+ *reinterpret_cast<__nv_bfloat16 *>(ptr) = __float2bfloat16(v);
+ }
+};
+} // namespace llaisys::device::nvidia
diff --git a/src/device/nvidia/devlink_stub.cu b/src/device/nvidia/devlink_stub.cu
new file mode 100644
index 000000000..6dc8ecc01
--- /dev/null
+++ b/src/device/nvidia/devlink_stub.cu
@@ -0,0 +1,4 @@
+#include <cuda_runtime.h>
+
+__global__ void llaisys_devlink_stub() {}
+
diff --git a/src/device/nvidia/nvidia_comm.cu b/src/device/nvidia/nvidia_comm.cu
new file mode 100644
index 000000000..f94691ed5
--- /dev/null
+++ b/src/device/nvidia/nvidia_comm.cu
@@ -0,0 +1,140 @@
+#include "../comm_api.hpp"
+#include "cuda_utils.hpp"
+
+#include <nccl.h>
+#include <cstring>
+#include <stdexcept>
+
+namespace llaisys::device::nvidia {
+
+inline void nccl_check(ncclResult_t result) {
+ if (result == ncclSuccess) {
+ return;
+ }
+ throw std::runtime_error(ncclGetErrorString(result));
+}
+
+inline ncclDataType_t to_nccl_dtype(llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32: return ncclFloat32;
+ case LLAISYS_DTYPE_F16: return ncclFloat16;
+ case LLAISYS_DTYPE_BF16: return ncclBfloat16;
+ case LLAISYS_DTYPE_I32: return ncclInt32;
+ case LLAISYS_DTYPE_I8: return ncclInt8;
+ default: throw std::runtime_error("Unsupported data type");
+ }
+}
+
+inline ncclRedOp_t to_nccl_op(llaisysReduceOp_t op) {
+ switch (op) {
+ case LLAISYS_REDUCE_SUM: return ncclSum;
+ case LLAISYS_REDUCE_PROD: return ncclProd;
+ case LLAISYS_REDUCE_MIN: return ncclMin;
+ case LLAISYS_REDUCE_MAX: return ncclMax;
+ default: throw std::runtime_error("Unsupported reduce op");
+ }
+}
+
+namespace nccl {
+
+int commInit(llaisysComm_t *comm, int rank, int size, const void *unique_id) {
+ ncclComm_t nccl_comm;
+ ncclUniqueId id;
+
+ if (unique_id) {
+ memcpy(&id, unique_id, sizeof(id));
+ } else {
+ // Without a caller-provided unique_id, generate one locally. This only
+ // works for single-process communicators: for size > 1, every rank must
+ // receive the id generated by rank 0 (see commGenerateUniqueId).
+ nccl_check(ncclGetUniqueId(&id));
+ }
+
+ nccl_check(ncclCommInitRank(&nccl_comm, size, id, rank));
+ *comm = reinterpret_cast<llaisysComm_t>(nccl_comm);
+ return 0;
+}
+
+int commGenerateUniqueId(void *id_out, size_t *id_size) {
+ ncclUniqueId id;
+ nccl_check(ncclGetUniqueId(&id));
+ memcpy(id_out, &id, sizeof(id));
+ *id_size = sizeof(id);
+ return 0;
+}
+
+void commDestroy(llaisysComm_t comm) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ nccl_check(ncclCommDestroy(nccl_comm));
+}
+
+int commGetRank(llaisysComm_t comm) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ int rank;
+ nccl_check(ncclCommUserRank(nccl_comm, &rank));
+ return rank;
+}
+
+int commGetSize(llaisysComm_t comm) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ int size;
+ nccl_check(ncclCommCount(nccl_comm, &size));
+ return size;
+}
+
+void commAllreduce(const void *sendbuf, void *recvbuf, size_t count,
+ llaisysDataType_t dtype, llaisysReduceOp_t op,
+ llaisysComm_t comm, llaisysStream_t stream) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+ nccl_check(ncclAllReduce(sendbuf, recvbuf, count, to_nccl_dtype(dtype),
+ to_nccl_op(op), nccl_comm, cuda_stream));
+}
+
+void commBroadcast(void *buf, size_t count, llaisysDataType_t dtype, int root,
+ llaisysComm_t comm, llaisysStream_t stream) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+ nccl_check(ncclBroadcast(buf, buf, count, to_nccl_dtype(dtype), root,
+ nccl_comm, cuda_stream));
+}
+
+void commSend(const void *buf, size_t count, llaisysDataType_t dtype, int peer,
+ llaisysComm_t comm, llaisysStream_t stream) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+ nccl_check(ncclSend(buf, count, to_nccl_dtype(dtype), peer, nccl_comm,
+ cuda_stream));
+}
+
+void commRecv(void *buf, size_t count, llaisysDataType_t dtype, int peer,
+ llaisysComm_t comm, llaisysStream_t stream) {
+ ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
+ cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+ nccl_check(ncclRecv(buf, count, to_nccl_dtype(dtype), peer, nccl_comm,
+ cuda_stream));
+}
+
+static const LlaisysCommAPI NCCL_COMM_API = {
+ &commInit,
+ &commDestroy,
+ &commGetRank,
+ &commGetSize,
+ &commAllreduce,
+ &commBroadcast,
+ &commSend,
+ &commRecv
+};
+
+const LlaisysCommAPI *getCommAPI() {
+ return &NCCL_COMM_API;
+}
+
+} // namespace nccl
+} // namespace llaisys::device::nvidia
+
+namespace llaisys::device::nccl {
+const LlaisysCommAPI *getCommAPI() {
+ return llaisys::device::nvidia::nccl::getCommAPI();
+}
+int commGenerateUniqueId(void *id_out, size_t *id_size) {
+ return llaisys::device::nvidia::nccl::commGenerateUniqueId(id_out, id_size);
+}
+}
diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu
index cab928261..2c8bf713e 100644
--- a/src/device/nvidia/nvidia_runtime_api.cu
+++ b/src/device/nvidia/nvidia_runtime_api.cu
@@ -1,56 +1,100 @@
#include "../runtime_api.hpp"
+#include "cuda_utils.hpp"
-#include
-#include
+#include <cuda_runtime.h>
namespace llaisys::device::nvidia {
namespace runtime_api {
int getDeviceCount() {
- TO_BE_IMPLEMENTED();
+ int count = 0;
+ cuda_check(cudaGetDeviceCount(&count));
+ return count;
}
-void setDevice(int) {
- TO_BE_IMPLEMENTED();
+void setDevice(int device_id) {
+ cuda_check(cudaSetDevice(device_id));
}
void deviceSynchronize() {
- TO_BE_IMPLEMENTED();
+ cuda_check(cudaDeviceSynchronize());
}
llaisysStream_t createStream() {
- TO_BE_IMPLEMENTED();
+ cudaStream_t stream{};
+ cuda_check(cudaStreamCreate(&stream));
+ return reinterpret_cast<llaisysStream_t>(stream);
}
void destroyStream(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ cuda_check(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(stream)));
}
void streamSynchronize(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ cuda_check(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
}
void *mallocDevice(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ cuda_check(cudaMalloc(&ptr, size));
+ return ptr;
}
void freeDevice(void *ptr) {
- TO_BE_IMPLEMENTED();
+ cuda_check(cudaFree(ptr));
}
void *mallocHost(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ cuda_check(cudaMallocHost(&ptr, size));
+ return ptr;
}
void freeHost(void *ptr) {
- TO_BE_IMPLEMENTED();
+ cuda_check(cudaFreeHost(ptr));
}
void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+ cudaMemcpyKind cuda_kind = cudaMemcpyDefault;
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ cuda_kind = cudaMemcpyHostToHost;
+ break;
+ case LLAISYS_MEMCPY_H2D:
+ cuda_kind = cudaMemcpyHostToDevice;
+ break;
+ case LLAISYS_MEMCPY_D2H:
+ cuda_kind = cudaMemcpyDeviceToHost;
+ break;
+ case LLAISYS_MEMCPY_D2D:
+ cuda_kind = cudaMemcpyDeviceToDevice;
+ break;
+ default:
+ cuda_kind = cudaMemcpyDefault;
+ break;
+ }
+ cuda_check(cudaMemcpy(dst, src, size, cuda_kind));
}
-void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) {
+ cudaMemcpyKind cuda_kind = cudaMemcpyDefault;
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ cuda_kind = cudaMemcpyHostToHost;
+ break;
+ case LLAISYS_MEMCPY_H2D:
+ cuda_kind = cudaMemcpyHostToDevice;
+ break;
+ case LLAISYS_MEMCPY_D2H:
+ cuda_kind = cudaMemcpyDeviceToHost;
+ break;
+ case LLAISYS_MEMCPY_D2D:
+ cuda_kind = cudaMemcpyDeviceToDevice;
+ break;
+ default:
+ cuda_kind = cudaMemcpyDefault;
+ break;
+ }
+ cuda_check(cudaMemcpyAsync(dst, src, size, cuda_kind, reinterpret_cast<cudaStream_t>(stream)));
}
static const LlaisysRuntimeAPI RUNTIME_API = {
diff --git a/src/device/runtime_api.cpp b/src/device/runtime_api.cpp
index 2de3eca02..1a8dfc6be 100644
--- a/src/device/runtime_api.cpp
+++ b/src/device/runtime_api.cpp
@@ -80,6 +80,12 @@ const LlaisysRuntimeAPI *getRuntimeAPI(llaisysDeviceType_t device_type) {
return llaisys::device::nvidia::getRuntimeAPI();
#else
return getUnsupportedRuntimeAPI();
+#endif
+ case LLAISYS_DEVICE_ILUVATAR:
+#ifdef ENABLE_ILUVATAR_API
+ return llaisys::device::iluvatar::getRuntimeAPI();
+#else
+ return getUnsupportedRuntimeAPI();
#endif
default:
EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/device/runtime_api.hpp b/src/device/runtime_api.hpp
index e6b9f80d6..12ebdc40f 100644
--- a/src/device/runtime_api.hpp
+++ b/src/device/runtime_api.hpp
@@ -17,4 +17,10 @@ namespace nvidia {
const LlaisysRuntimeAPI *getRuntimeAPI();
}
#endif
+
+#ifdef ENABLE_ILUVATAR_API
+namespace iluvatar {
+const LlaisysRuntimeAPI *getRuntimeAPI();
+}
+#endif
} // namespace llaisys::device
diff --git a/src/llaisys/comm.cc b/src/llaisys/comm.cc
new file mode 100644
index 000000000..d3b0c9c8c
--- /dev/null
+++ b/src/llaisys/comm.cc
@@ -0,0 +1,10 @@
+#include "llaisys/comm.h"
+#include "../device/comm_api.hpp"
+
+__C const LlaisysCommAPI *llaisysGetCommAPI(llaisysCommBackend_t backend) {
+ return llaisys::device::getCommAPI(backend);
+}
+
+__C int llaisysCommGenerateUniqueId(llaisysCommBackend_t backend, void *id_out, size_t *id_size) {
+ return llaisys::device::commGenerateUniqueId(backend, id_out, id_size);
+}
diff --git a/src/llaisys/models/qwen2.cpp b/src/llaisys/models/qwen2.cpp
new file mode 100644
index 000000000..5151c2537
--- /dev/null
+++ b/src/llaisys/models/qwen2.cpp
@@ -0,0 +1,479 @@
+// Qwen2 C API implementation (skeleton)
+#include "llaisys/models/qwen2.h"
+#include "../../models/qwen2/qwen2.hpp"
+#include "qwen2_kv_internal.hpp"
+
+#include <algorithm>
+#include <cstdint>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+struct LlaisysQwen2Model {
+ LlaisysQwen2Meta meta{};
+ LlaisysQwen2Weights weights{};
+ llaisysDeviceType_t device = LLAISYS_DEVICE_CPU;
+ std::vector<int> device_ids;
+ std::unique_ptr impl;
+ LlaisysQwen2KVContext *kv_ctx = nullptr; // experimental, non-decoder path
+};
+
+static void init_layer_arrays(LlaisysQwen2Weights &w, size_t nlayer) {
+ w.attn_norm_w = new llaisysTensor_t[nlayer]();
+ w.attn_q_w = new llaisysTensor_t[nlayer]();
+ w.attn_q_b = new llaisysTensor_t[nlayer]();
+ w.attn_k_w = new llaisysTensor_t[nlayer]();
+ w.attn_k_b = new llaisysTensor_t[nlayer]();
+ w.attn_v_w = new llaisysTensor_t[nlayer]();
+ w.attn_v_b = new llaisysTensor_t[nlayer]();
+ w.attn_o_w = new llaisysTensor_t[nlayer]();
+ w.mlp_norm_w = new llaisysTensor_t[nlayer]();
+ w.mlp_gate_w = new llaisysTensor_t[nlayer]();
+ w.mlp_up_w = new llaisysTensor_t[nlayer]();
+ w.mlp_down_w = new llaisysTensor_t[nlayer]();
+}
+
+static void destroy_layer_arrays(LlaisysQwen2Weights &w, size_t nlayer) {
+ auto destroy_array = [nlayer](llaisysTensor_t *arr) {
+ if (!arr) return;
+ for (size_t i = 0; i < nlayer; ++i) {
+ if (arr[i]) {
+ tensorDestroy(arr[i]);
+ arr[i] = nullptr;
+ }
+ }
+ delete[] arr;
+ };
+
+ destroy_array(w.attn_norm_w);
+ destroy_array(w.attn_q_w);
+ destroy_array(w.attn_q_b);
+ destroy_array(w.attn_k_w);
+ destroy_array(w.attn_k_b);
+ destroy_array(w.attn_v_w);
+ destroy_array(w.attn_v_b);
+ destroy_array(w.attn_o_w);
+ destroy_array(w.mlp_norm_w);
+ destroy_array(w.mlp_gate_w);
+ destroy_array(w.mlp_up_w);
+ destroy_array(w.mlp_down_w);
+
+ w.attn_norm_w = nullptr;
+ w.attn_q_w = nullptr;
+ w.attn_q_b = nullptr;
+ w.attn_k_w = nullptr;
+ w.attn_k_b = nullptr;
+ w.attn_v_w = nullptr;
+ w.attn_v_b = nullptr;
+ w.attn_o_w = nullptr;
+ w.mlp_norm_w = nullptr;
+ w.mlp_gate_w = nullptr;
+ w.mlp_up_w = nullptr;
+ w.mlp_down_w = nullptr;
+}
+
+__C {
+ __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(
+ const LlaisysQwen2Meta *meta,
+ llaisysDeviceType_t device,
+ int *device_ids,
+ int ndevice) {
+ if (!meta || ndevice <= 0) return nullptr;
+
+ auto *model = new LlaisysQwen2Model();
+ model->meta = *meta;
+ model->device = device;
+ model->device_ids.assign(device_ids, device_ids + ndevice);
+
+ init_layer_arrays(model->weights, model->meta.nlayer);
+ model->impl = std::make_unique(
+ model->meta,
+ model->weights,
+ model->device,
+ model->device_ids);
+
+ return model;
+ }
+
+ // Destroy a Qwen2 model instance.
+ __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) {
+ if (!model) return;
+
+ if (model->weights.in_embed) {
+ tensorDestroy(model->weights.in_embed);
+ model->weights.in_embed = nullptr;
+ }
+ if (model->weights.out_embed) {
+ tensorDestroy(model->weights.out_embed);
+ model->weights.out_embed = nullptr;
+ }
+ if (model->weights.out_norm_w) {
+ tensorDestroy(model->weights.out_norm_w);
+ model->weights.out_norm_w = nullptr;
+ }
+
+ destroy_layer_arrays(model->weights, model->meta.nlayer);
+ if (model->kv_ctx) {
+ llaisysQwen2KVContextRelease(model->kv_ctx);
+ model->kv_ctx = nullptr;
+ }
+
+ model->impl.reset();
+ delete model;
+ }
+
+
+ // Get the Qwen2 model weights.
+ __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) {
+ if (!model) return nullptr;
+ return &model->weights;
+ }
+
+ // Run Qwen2 model inference.
+ __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->infer(token_ids, ntoken);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 infer failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 infer failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelPrefill(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->prefill(token_ids, ntoken);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 prefill failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 prefill failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelStep(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->step(token_ids, ntoken);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 step failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 step failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int32_t llaisysQwen2ModelPrefillPacked(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens) {
+ if (!model || !model->impl || !token_ids || !token_offsets || !out_next_tokens || nseq == 0) return -1;
+ try {
+ const size_t ntoken = static_cast<size_t>(token_offsets[nseq]);
+ if (!model->impl->prefillPacked(token_ids, ntoken, token_offsets, nseq, out_next_tokens)) return -2;
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 prefill packed failed: " << e.what() << std::endl;
+ return -3;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 prefill packed failed: unknown exception" << std::endl;
+ return -4;
+ }
+ }
+
+ __export int32_t llaisysQwen2ModelStepPacked(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens) {
+ if (!model || !model->impl || !token_ids || !token_offsets || !out_next_tokens || nseq == 0) return -1;
+ try {
+ const size_t ntoken = static_cast<size_t>(token_offsets[nseq]);
+ if (!model->impl->stepPacked(token_ids, ntoken, token_offsets, nseq, out_next_tokens)) return -2;
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 step packed failed: " << e.what() << std::endl;
+ return -3;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 step packed failed: unknown exception" << std::endl;
+ return -4;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelPrefillSampling(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ size_t ntoken,
+ const LlaisysSamplingParams *params) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->prefillSampling(token_ids, ntoken, params);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 prefill sampling failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 prefill sampling failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelStepSampling(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ size_t ntoken,
+ const LlaisysSamplingParams *params) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->stepSampling(token_ids, ntoken, params);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 step sampling failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 step sampling failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelInferSampling(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ size_t ntoken,
+ const LlaisysSamplingParams *params) {
+ if (!model || !model->impl) return -1;
+ try {
+ return model->impl->prefillSampling(token_ids, ntoken, params);
+ } catch (const std::exception &e) {
+ std::cerr << "[ERROR] Qwen2 infer sampling failed: " << e.what() << std::endl;
+ return -1;
+ } catch (...) {
+ std::cerr << "[ERROR] Qwen2 infer sampling failed: unknown exception" << std::endl;
+ return -1;
+ }
+ }
+
+ __export int64_t llaisysQwen2ModelInferSamplingEx(struct LlaisysQwen2Model *model,
+ int64_t *token_ids,
+ size_t ntoken,
+ int32_t top_k,
+ float top_p,
+ float temperature,
+ uint32_t seed) {
+ if (!model || !model->impl) return -1;
+ LlaisysSamplingParams params{};
+ params.top_k = top_k;
+ params.top_p = top_p;
+ params.temperature = temperature;
+ params.seed = seed;
+ return llaisysQwen2ModelInferSampling(model, token_ids, ntoken, &params);
+ }
+
+ __export void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model *model) {
+ if (!model || !model->impl) return;
+ model->impl->resetKVCache();
+ }
+
+ __export void llaisysQwen2ModelSetKVCacheEnabled(struct LlaisysQwen2Model *model, uint8_t enabled) {
+ if (!model || !model->impl) return;
+ model->impl->setKVCacheEnabled(enabled != 0);
+ }
+
+ __export int32_t llaisysQwen2ModelSetTensorParallel(struct LlaisysQwen2Model *model,
+ llaisysComm_t comm,
+ llaisysStream_t stream,
+ int tp_size) {
+ if (!model || !model->impl) return -1;
+ model->impl->setTensorParallel(comm, stream, tp_size);
+ return 0;
+ }
+
+ __export struct LlaisysQwen2KVBlock *llaisysQwen2KVBlockCreate(
+ const struct LlaisysQwen2KVBlockMeta *meta,
+ llaisysDeviceType_t device,
+ int device_id) {
+ if (!meta || meta->nlayer == 0 || meta->max_tokens == 0) return nullptr;
+ auto *block = new LlaisysQwen2KVBlock();
+ block->meta = *meta;
+ block->device = device;
+ block->device_id = device_id;
+ block->k_layers.assign(meta->nlayer, nullptr);
+ block->v_layers.assign(meta->nlayer, nullptr);
+ size_t kv_shape[3] = {meta->max_tokens, meta->nkvh, meta->dh};
+ for (size_t layer = 0; layer < meta->nlayer; ++layer) {
+ block->k_layers[layer] = tensorCreate(kv_shape, 3, meta->dtype, device, device_id);
+ block->v_layers[layer] = tensorCreate(kv_shape, 3, meta->dtype, device, device_id);
+ if (!block->k_layers[layer] || !block->v_layers[layer]) {
+ for (auto *t : block->k_layers) {
+ if (t) tensorDestroy(t);
+ }
+ for (auto *t : block->v_layers) {
+ if (t) tensorDestroy(t);
+ }
+ delete block;
+ return nullptr;
+ }
+ }
+ return block;
+ }
+
+ __export void llaisysQwen2KVBlockRetain(struct LlaisysQwen2KVBlock *block) {
+ if (!block) return;
+ block->ref_count.fetch_add(1, std::memory_order_relaxed);
+ }
+
+ __export void llaisysQwen2KVBlockRelease(struct LlaisysQwen2KVBlock *block) {
+ if (!block) return;
+ if (block->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+ for (auto *t : block->k_layers) {
+ if (t) tensorDestroy(t);
+ }
+ for (auto *t : block->v_layers) {
+ if (t) tensorDestroy(t);
+ }
+ block->k_layers.clear();
+ block->v_layers.clear();
+ delete block;
+ }
+ }
+
+ __export int32_t llaisysQwen2KVBlockSetTokenCount(struct LlaisysQwen2KVBlock *block, size_t used_tokens) {
+ if (!block) return -1;
+ if (used_tokens > block->meta.max_tokens) return -2;
+ block->used_tokens = used_tokens;
+ return 0;
+ }
+
+ __export size_t llaisysQwen2KVBlockTokenCount(const struct LlaisysQwen2KVBlock *block) {
+ if (!block) return 0;
+ return block->used_tokens;
+ }
+
+ __export llaisysTensor_t llaisysQwen2KVBlockKeyTensor(struct LlaisysQwen2KVBlock *block, size_t layer) {
+ if (!block || layer >= block->k_layers.size()) return nullptr;
+ return block->k_layers[layer];
+ }
+
+ __export llaisysTensor_t llaisysQwen2KVBlockValueTensor(struct LlaisysQwen2KVBlock *block, size_t layer) {
+ if (!block || layer >= block->v_layers.size()) return nullptr;
+ return block->v_layers[layer];
+ }
+
+ __export struct LlaisysQwen2KVContext *llaisysQwen2KVContextCreate(
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id,
+ size_t nlayer,
+ size_t nh,
+ size_t nkvh,
+ size_t dh) {
+ if (nlayer == 0 || dh == 0) return nullptr;
+ auto *ctx = new LlaisysQwen2KVContext();
+ ctx->dtype = dtype;
+ ctx->device = device;
+ ctx->device_id = device_id;
+ ctx->nlayer = nlayer;
+ ctx->nh = nh;
+ ctx->nkvh = nkvh;
+ ctx->dh = dh;
+ return ctx;
+ }
+
+ __export void llaisysQwen2KVContextRetain(struct LlaisysQwen2KVContext *ctx) {
+ if (!ctx) return;
+ ctx->ref_count.fetch_add(1, std::memory_order_relaxed);
+ }
+
+ __export void llaisysQwen2KVContextRelease(struct LlaisysQwen2KVContext *ctx) {
+ if (!ctx) return;
+ if (ctx->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+ for (auto *blk : ctx->chain) {
+ llaisysQwen2KVBlockRelease(blk);
+ }
+ ctx->chain.clear();
+ delete ctx;
+ }
+ }
+
+ __export int32_t llaisysQwen2KVContextAttachBlock(
+ struct LlaisysQwen2KVContext *ctx,
+ struct LlaisysQwen2KVBlock *block) {
+ if (!ctx || !block) return -1;
+ if (ctx->device != block->device || ctx->device_id != block->device_id) return -2;
+ if (ctx->dtype != block->meta.dtype) return -3;
+ if (ctx->nlayer != block->meta.nlayer || ctx->dh != block->meta.dh) return -4;
+ if (ctx->nkvh != block->meta.nkvh || ctx->nh != block->meta.nh) return -5;
+ llaisysQwen2KVBlockRetain(block);
+ ctx->chain.push_back(block);
+ return 0;
+ }
+
+ __export void llaisysQwen2KVContextDetachAll(struct LlaisysQwen2KVContext *ctx) {
+ if (!ctx) return;
+ for (auto *blk : ctx->chain) {
+ llaisysQwen2KVBlockRelease(blk);
+ }
+ ctx->chain.clear();
+ }
+
+ __export size_t llaisysQwen2KVContextBlockCount(const struct LlaisysQwen2KVContext *ctx) {
+ if (!ctx) return 0;
+ return ctx->chain.size();
+ }
+
+ __export size_t llaisysQwen2KVContextTokenCount(const struct LlaisysQwen2KVContext *ctx) {
+ if (!ctx) return 0;
+ size_t total = 0;
+ for (auto *blk : ctx->chain) {
+ if (!blk) continue;
+ total += std::min(blk->used_tokens, blk->meta.max_tokens);
+ }
+ return total;
+ }
+
+ __export int32_t llaisysQwen2ModelSetKVContext(
+ struct LlaisysQwen2Model *model,
+ struct LlaisysQwen2KVContext *ctx) {
+ if (!model) return -1;
+ if (ctx) {
+ if (model->device != ctx->device) return -2;
+ const int model_device_id = model->device_ids.empty() ? 0 : model->device_ids[0];
+ if (model_device_id != ctx->device_id) return -3;
+ llaisysQwen2KVContextRetain(ctx);
+ }
+ if (model->kv_ctx) {
+ llaisysQwen2KVContextRelease(model->kv_ctx);
+ }
+ model->kv_ctx = ctx;
+ if (model->impl) {
+ const size_t past_len_tokens = llaisysQwen2KVContextTokenCount(ctx);
+ model->impl->setKVContext(ctx, past_len_tokens);
+ }
+ return 0;
+ }
+
+ __export struct LlaisysQwen2KVContext *llaisysQwen2ModelGetKVContext(
+ struct LlaisysQwen2Model *model) {
+ if (!model) return nullptr;
+ auto *ctx = model->kv_ctx;
+ if (model->impl) {
+ ctx = reinterpret_cast<LlaisysQwen2KVContext *>(model->impl->getKVContext());
+ }
+ if (!ctx) return nullptr;
+ llaisysQwen2KVContextRetain(ctx);
+ return ctx;
+ }
+
+ __export int32_t llaisysQwen2ModelExportKVContext(
+ struct LlaisysQwen2Model *model,
+ struct LlaisysQwen2KVContext *ctx,
+ size_t block_tokens) {
+ if (!model || !model->impl || !ctx) return -1;
+ if (model->device != ctx->device) return -2;
+ const int model_device_id = model->device_ids.empty() ? 0 : model->device_ids[0];
+ if (model_device_id != ctx->device_id) return -3;
+ return static_cast<int32_t>(model->impl->exportKVContext(ctx, block_tokens));
+ }
+}
diff --git a/src/llaisys/models/qwen2_kv_internal.hpp b/src/llaisys/models/qwen2_kv_internal.hpp
new file mode 100644
index 000000000..7c83bd071
--- /dev/null
+++ b/src/llaisys/models/qwen2_kv_internal.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "llaisys/models/qwen2.h"
+
+#include <atomic>
+#include <vector>
+
+struct LlaisysQwen2KVBlock {
+ LlaisysQwen2KVBlockMeta meta{};
+ llaisysDeviceType_t device = LLAISYS_DEVICE_CPU;
+ int device_id = 0;
+ size_t used_tokens = 0;
+ std::vector<llaisysTensor_t> k_layers;
+ std::vector<llaisysTensor_t> v_layers;
+ std::atomic<int> ref_count{1};
+};
+
+struct LlaisysQwen2KVContext {
+ llaisysDataType_t dtype = LLAISYS_DTYPE_F32;
+ llaisysDeviceType_t device = LLAISYS_DEVICE_CPU;
+ int device_id = 0;
+ size_t nlayer = 0;
+ size_t nh = 0;
+ size_t nkvh = 0;
+ size_t dh = 0;
+ std::vector<LlaisysQwen2KVBlock *> chain;
+ std::atomic<int> ref_count{1};
+};
diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc
index c99fbc32f..625887cac 100644
--- a/src/llaisys/ops.cc
+++ b/src/llaisys/ops.cc
@@ -23,7 +23,10 @@ __C {
llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor);
}
void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) {
- llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor);
+ llaisys::ops::linear(out->tensor,
+ in->tensor,
+ weight->tensor,
+ bias ? bias->tensor : nullptr);
}
void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) {
llaisys::ops::rearrange(out->tensor, in->tensor);
@@ -37,6 +40,24 @@ __C {
void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale) {
llaisys::ops::self_attention(attn_val->tensor, q->tensor, k->tensor, v->tensor, scale);
}
+ void llaisysSelfAttentionSegmented(llaisysTensor_t attn_val,
+ llaisysTensor_t q,
+ llaisysTensor_t k,
+ llaisysTensor_t v,
+ float scale,
+ const int64_t *q_offsets,
+ const int64_t *kv_offsets,
+ size_t nseg) {
+ llaisys::ops::self_attention_segmented(
+ attn_val->tensor,
+ q->tensor,
+ k->tensor,
+ v->tensor,
+ scale,
+ q_offsets,
+ kv_offsets,
+ nseg);
+ }
void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up) {
llaisys::ops::swiglu(out->tensor, gate->tensor, up->tensor);
}
diff --git a/src/llaisys/tokenizer.cc b/src/llaisys/tokenizer.cc
new file mode 100644
index 000000000..c15abd11e
--- /dev/null
+++ b/src/llaisys/tokenizer.cc
@@ -0,0 +1,84 @@
+#include "llaisys/tokenizer.h"
+
+#include "../tokenizer/sentencepiece/sentencepiece.hpp"
+
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+struct LlaisysTokenizer {
+ std::unique_ptr impl;
+};
+
+__C {
+__export struct LlaisysTokenizer *llaisysTokenizerCreateSentencePiece(const char *model_path) {
+ if (!model_path || model_path[0] == '\0') return nullptr;
+ auto tokenizer = std::make_unique<LlaisysTokenizer>();
+ tokenizer->impl = std::make_unique(model_path);
+ if (!tokenizer->impl || !tokenizer->impl->isLoaded()) {
+ return nullptr;
+ }
+ return tokenizer.release();
+}
+
+__export void llaisysTokenizerDestroy(struct LlaisysTokenizer *tokenizer) {
+ delete tokenizer;
+}
+
+__export int llaisysTokenizerEncode(struct LlaisysTokenizer *tokenizer,
+ const char *text,
+ int64_t *out_ids,
+ size_t max_ids) {
+ if (!tokenizer || !tokenizer->impl || !text) return -1;
+ std::vector<int64_t> ids;
+ if (!tokenizer->impl->encode(text, ids)) return -1;
+ if (!out_ids || max_ids == 0) {
+ return static_cast<int>(ids.size());
+ }
+ const size_t n = ids.size() < max_ids ? ids.size() : max_ids;
+ for (size_t i = 0; i < n; ++i) out_ids[i] = ids[i];
+ return static_cast<int>(n);
+}
+
+__export int llaisysTokenizerDecode(struct LlaisysTokenizer *tokenizer,
+ const int64_t *ids,
+ size_t len,
+ char *out_text,
+ size_t max_len) {
+ if (!tokenizer || !tokenizer->impl) return -1;
+ std::string text;
+ if (!tokenizer->impl->decode(ids, len, text)) return -1;
+ if (!out_text || max_len == 0) {
+ return static_cast<int>(text.size() + 1);
+ }
+ const size_t n = text.size() < (max_len - 1) ? text.size() : (max_len - 1);
+ std::memcpy(out_text, text.data(), n);
+ out_text[n] = '\0';
+ return static_cast<int>(n);
+}
+
+__export int64_t llaisysTokenizerTokenToId(struct LlaisysTokenizer *tokenizer, const char *token) {
+ if (!tokenizer || !tokenizer->impl || !token) return -1;
+ int64_t id = -1;
+ if (!tokenizer->impl->pieceToId(token, id)) return -1;
+ return id;
+}
+
+__export int llaisysTokenizerIdToToken(struct LlaisysTokenizer *tokenizer,
+ int64_t id,
+ char *out_token,
+ size_t max_len) {
+ if (!tokenizer || !tokenizer->impl) return -1;
+ std::string piece;
+ if (!tokenizer->impl->idToPiece(id, piece)) return -1;
+ if (!out_token || max_len == 0) {
+ return static_cast<int>(piece.size() + 1);
+ }
+ const size_t n = piece.size() < (max_len - 1) ? piece.size() : (max_len - 1);
+ std::memcpy(out_token, piece.data(), n);
+ out_token[n] = '\0';
+ return static_cast<int>(n);
+}
+
+}
diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp
new file mode 100644
index 000000000..0082aba8f
--- /dev/null
+++ b/src/models/qwen2/qwen2.cpp
@@ -0,0 +1,500 @@
+#include "qwen2.hpp"
+
+#include "llaisys/ops.h"
+
+#include "../../utils.hpp"
+#include "../../core/context/context.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <numeric>
+#include <random>
+#include <vector>
+
+namespace llaisys::models {
+Qwen2::Qwen2(const LlaisysQwen2Meta &meta,
+ const LlaisysQwen2Weights &weights,
+ llaisysDeviceType_t device,
+ const std::vector<int> &device_ids)
+ : _meta(meta),
+ _weights(&weights),
+ _device(device),
+ _device_ids(device_ids),
+ _decoder(transformer::DecoderConfig{
+ meta.dtype,
+ meta.nlayer,
+ meta.hs,
+ meta.nh,
+ meta.nkvh,
+ meta.dh,
+ meta.di,
+ meta.maxseq,
+ meta.voc,
+ meta.epsilon,
+ meta.theta},
+ &weights,
+ device,
+ device_ids) {}
+
+Qwen2::~Qwen2() {
+ clearPackedState();
+}
+
+void Qwen2::resetKVCache() {
+ clearPackedState();
+ _decoder.resetKVCache();
+}
+
+void Qwen2::setKVCacheEnabled(bool enabled) {
+ _decoder.setKVCacheEnabled(enabled);
+}
+
+void Qwen2::setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size) {
+ _decoder.setTensorParallel(comm, stream, tp_size);
+}
+
+void Qwen2::setKVContext(void *ctx, size_t past_len_tokens) {
+ clearPackedState();
+ _kv_ctx = ctx;
+ if (ctx) {
+ _decoder.bindExternalKVContext(ctx, past_len_tokens);
+ } else {
+ _decoder.clearExternalKVContext();
+ }
+}
+
+void *Qwen2::getKVContext() const {
+ return _kv_ctx;
+}
+
+int Qwen2::exportKVContext(void *ctx, size_t block_tokens) {
+ return _decoder.exportKVContext(ctx, block_tokens);
+}
+
+void Qwen2::clearPackedState() {
+ for (auto *ctx : _packed_kv_contexts) {
+ if (ctx) {
+ ::llaisysQwen2KVContextRelease(ctx);
+ }
+ }
+ _packed_kv_contexts.clear();
+ _packed_prompts.clear();
+}
+
+// Run Qwen2 model inference.
+static int64_t argmax_from_logits(llaisysTensor_t logits,
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id) {
+ int64_t next_token = -1;
+ size_t one_shape[1] = {1};
+ llaisysTensor_t max_idx = tensorCreate(one_shape, 1, LLAISYS_DTYPE_I64, device, device_id);
+ llaisysTensor_t max_val = tensorCreate(one_shape, 1, dtype, device, device_id);
+ if (!max_idx || !max_val) {
+ if (max_idx) tensorDestroy(max_idx);
+ if (max_val) tensorDestroy(max_val);
+ return -1;
+ }
+ ::llaisysArgmax(max_idx, max_val, logits);
+ if (tensorGetDeviceType(max_idx) == LLAISYS_DEVICE_CPU) {
+ next_token = *reinterpret_cast<const int64_t *>(tensorGetData(max_idx));
+ } else {
+ int64_t host_val = -1;
+ llaisys::core::context().setDevice(device, device_id);
+ llaisys::core::context().runtime().api()->memcpy_sync(
+ &host_val,
+ tensorGetData(max_idx),
+ sizeof(int64_t),
+ LLAISYS_MEMCPY_D2H);
+ next_token = host_val;
+ }
+ tensorDestroy(max_idx);
+ tensorDestroy(max_val);
+ return next_token;
+}
+
+static std::vector<float> logits_to_host(llaisysTensor_t logits,
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id,
+ size_t vocab) {
+ std::vector<float> host(vocab, 0.0f);
+ const size_t bytes = vocab * utils::dsize(dtype);
+ if (device == LLAISYS_DEVICE_CPU) {
+ const std::byte *src = reinterpret_cast<const std::byte *>(tensorGetData(logits));
+ if (dtype == LLAISYS_DTYPE_F32) {
+ const float *vals = reinterpret_cast<const float *>(src);
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = vals[i];
+ }
+ } else if (dtype == LLAISYS_DTYPE_F16) {
+ const fp16_t *vals = reinterpret_cast<const fp16_t *>(src);
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = utils::cast<float>(vals[i]);
+ }
+ } else if (dtype == LLAISYS_DTYPE_BF16) {
+ const bf16_t *vals = reinterpret_cast<const bf16_t *>(src);
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = utils::cast<float>(vals[i]);
+ }
+ } else {
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+ return host;
+ }
+
+ std::vector<std::byte> tmp(bytes);
+ llaisys::core::context().setDevice(device, device_id);
+ llaisys::core::context().runtime().api()->memcpy_sync(
+ tmp.data(), tensorGetData(logits), bytes, LLAISYS_MEMCPY_D2H);
+
+ if (dtype == LLAISYS_DTYPE_F32) {
+ const float *vals = reinterpret_cast<const float *>(tmp.data());
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = vals[i];
+ }
+ } else if (dtype == LLAISYS_DTYPE_F16) {
+ const fp16_t *vals = reinterpret_cast<const fp16_t *>(tmp.data());
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = utils::cast<float>(vals[i]);
+ }
+ } else if (dtype == LLAISYS_DTYPE_BF16) {
+ const bf16_t *vals = reinterpret_cast<const bf16_t *>(tmp.data());
+ for (size_t i = 0; i < vocab; ++i) {
+ host[i] = utils::cast<float>(vals[i]);
+ }
+ } else {
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+ return host;
+}
+
+static int64_t sample_from_logits(const std::vector<float> &logits,
+ const LlaisysSamplingParams *params) {
+ const size_t vocab = logits.size();
+ if (vocab == 0) {
+ return -1;
+ }
+
+ int top_k = params ? params->top_k : 1;
+ float top_p = params ? params->top_p : 0.0f;
+ float temperature = params ? params->temperature : 0.0f;
+ uint32_t seed = params ? params->seed : 0u;
+
+ if (temperature <= 0.0f && top_k <= 1 && top_p <= 0.0f) {
+ return static_cast<int64_t>(std::distance(logits.begin(),
+ std::max_element(logits.begin(), logits.end())));
+ }
+
+ std::vector<int> indices(vocab);
+ std::iota(indices.begin(), indices.end(), 0);
+
+ if (top_k > 0 && static_cast<size_t>(top_k) < vocab) {
+ std::partial_sort(indices.begin(), indices.begin() + top_k, indices.end(),
+ [&](int a, int b) { return logits[a] > logits[b]; });
+ indices.resize(top_k);
+ }
+
+ const float temp = temperature > 0.0f ? temperature : 1.0f;
+ std::vector<float> filtered_logits;
+ filtered_logits.reserve(indices.size());
+ for (int idx : indices) {
+ filtered_logits.push_back(logits[idx] / std::max(temp, 1e-6f));
+ }
+
+ float max_logit = *std::max_element(filtered_logits.begin(), filtered_logits.end());
+ std::vector<float> probs(filtered_logits.size());
+ float sum = 0.0f;
+ for (size_t i = 0; i < filtered_logits.size(); ++i) {
+ probs[i] = std::exp(filtered_logits[i] - max_logit);
+ sum += probs[i];
+ }
+ if (sum <= 0.0f) {
+ return indices.front();
+ }
+ for (float &p : probs) {
+ p /= sum;
+ }
+
+ if (top_p > 0.0f && top_p < 1.0f) {
+ std::vector<size_t> order(probs.size());
+ std::iota(order.begin(), order.end(), 0);
+ std::sort(order.begin(), order.end(), [&](size_t a, size_t b) { return probs[a] > probs[b]; });
+ float cumulative = 0.0f;
+ size_t keep = 0;
+ for (size_t idx : order) {
+ cumulative += probs[idx];
+ keep++;
+ if (cumulative >= top_p) {
+ break;
+ }
+ }
+ std::vector<int> new_indices;
+ std::vector<float> new_probs;
+ new_indices.reserve(keep);
+ new_probs.reserve(keep);
+ for (size_t i = 0; i < keep; ++i) {
+ size_t idx = order[i];
+ new_indices.push_back(indices[idx]);
+ new_probs.push_back(probs[idx]);
+ }
+ indices.swap(new_indices);
+ probs.swap(new_probs);
+ float new_sum = std::accumulate(probs.begin(), probs.end(), 0.0f);
+ if (new_sum > 0.0f) {
+ for (float &p : probs) {
+ p /= new_sum;
+ }
+ }
+ }
+
+ std::mt19937 rng(seed == 0 ? std::random_device{}() : seed);
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+ float r = dist(rng);
+ float cumulative = 0.0f;
+ for (size_t i = 0; i < probs.size(); ++i) {
+ cumulative += probs[i];
+ if (r <= cumulative) {
+ return indices[i];
+ }
+ }
+ return indices.back();
+}
+
+static int64_t next_token_from_logits(llaisysTensor_t logits,
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id,
+ size_t vocab,
+ const LlaisysSamplingParams *params) {
+ if (!params) {
+ return argmax_from_logits(logits, dtype, device, device_id);
+ }
+ auto host_logits = logits_to_host(logits, dtype, device, device_id, vocab);
+ return sample_from_logits(host_logits, params);
+}
+
+int64_t Qwen2::infer(const int64_t *token_ids, size_t ntoken) {
+ return prefill(token_ids, ntoken);
+}
+
+int64_t Qwen2::prefill(const int64_t *token_ids, size_t ntoken) {
+ if (!token_ids || ntoken == 0) return -1;
+ clearPackedState();
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ size_t logits_shape[2] = {1, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return -1;
+ if (!_decoder.prefill(token_ids, ntoken, logits)) {
+ tensorDestroy(logits);
+ return -1;
+ }
+
+ int64_t next_token = argmax_from_logits(logits, _meta.dtype, _device, device_id);
+ tensorDestroy(logits);
+
+ return next_token;
+}
+
+int64_t Qwen2::step(const int64_t *token_ids, size_t ntoken) {
+ if (!token_ids || ntoken == 0) return -1;
+ clearPackedState();
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ size_t logits_shape[2] = {1, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return -1;
+ if (!_decoder.decodeStep(token_ids, ntoken, logits)) {
+ tensorDestroy(logits);
+ return -1;
+ }
+
+ int64_t next_token = argmax_from_logits(logits, _meta.dtype, _device, device_id);
+ tensorDestroy(logits);
+ return next_token;
+}
+
+bool Qwen2::prefillPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens) {
+ if (!token_ids || !token_offsets || nseq == 0 || ntoken == 0 || !out_next_tokens) return false;
+ clearPackedState();
+ if (token_offsets[0] != 0 || static_cast<size_t>(token_offsets[nseq]) != ntoken) return false;
+ for (size_t i = 0; i < nseq; ++i) {
+ if (token_offsets[i] >= token_offsets[i + 1]) return false;
+ }
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ size_t logits_shape[2] = {nseq, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return false;
+ if (!_decoder.prefillPacked(token_ids, ntoken, token_offsets, nseq, logits)) {
+ tensorDestroy(logits);
+ return false;
+ }
+ for (size_t i = 0; i < nseq; ++i) {
+ llaisysTensor_t row = tensorSlice(logits, 0, i, i + 1);
+ if (!row) {
+ tensorDestroy(logits);
+ return false;
+ }
+ out_next_tokens[i] = argmax_from_logits(row, _meta.dtype, _device, device_id);
+ tensorDestroy(row);
+ }
+ tensorDestroy(logits);
+
+ _packed_prompts.resize(nseq);
+ for (size_t i = 0; i < nseq; ++i) {
+ const size_t begin = static_cast<size_t>(token_offsets[i]);
+ const size_t end = static_cast<size_t>(token_offsets[i + 1]);
+ _packed_prompts[i].assign(token_ids + begin, token_ids + end);
+ }
+
+ // Build per-sequence KV snapshots once after packed prefill.
+ constexpr size_t kPackedBlockTokens = 64;
+ _packed_kv_contexts.assign(nseq, nullptr);
+ size_t single_logits_shape[2] = {1, _meta.voc};
+ llaisysTensor_t single_logits = tensorCreate(single_logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!single_logits) {
+ clearPackedState();
+ return false;
+ }
+ for (size_t i = 0; i < nseq; ++i) {
+ _decoder.resetKVCache();
+ _decoder.clearExternalKVContext();
+ const auto &prompt = _packed_prompts[i];
+ if (prompt.empty()) {
+ tensorDestroy(single_logits);
+ clearPackedState();
+ return false;
+ }
+ if (!_decoder.prefill(prompt.data(), prompt.size(), single_logits)) {
+ tensorDestroy(single_logits);
+ clearPackedState();
+ return false;
+ }
+ auto *ctx = ::llaisysQwen2KVContextCreate(
+ _meta.dtype,
+ _device,
+ device_id,
+ _meta.nlayer,
+ _meta.nh,
+ _meta.nkvh,
+ _meta.dh);
+ if (!ctx) {
+ tensorDestroy(single_logits);
+ clearPackedState();
+ return false;
+ }
+ if (_decoder.exportKVContext(ctx, kPackedBlockTokens) != 0) {
+ ::llaisysQwen2KVContextRelease(ctx);
+ tensorDestroy(single_logits);
+ clearPackedState();
+ return false;
+ }
+ _packed_kv_contexts[i] = ctx;
+ }
+ _decoder.clearExternalKVContext();
+ _decoder.resetKVCache();
+ tensorDestroy(single_logits);
+ return true;
+}
+
+bool Qwen2::stepPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens) {
+ if (!token_ids || !token_offsets || nseq == 0 || !out_next_tokens) return false;
+ if (token_offsets[0] != 0 || static_cast<size_t>(token_offsets[nseq]) != ntoken) return false;
+ for (size_t i = 0; i < nseq; ++i) {
+ if (token_offsets[i] >= token_offsets[i + 1]) return false;
+ }
+ if (_packed_prompts.size() != nseq || _packed_kv_contexts.size() != nseq) return false;
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ constexpr size_t kPackedBlockTokens = 64;
+ std::vector<int64_t> step_tokens(nseq, 0);
+ for (size_t i = 0; i < nseq; ++i) {
+ const size_t begin = static_cast<size_t>(token_offsets[i]);
+ const size_t end = static_cast<size_t>(token_offsets[i + 1]);
+ const size_t step_len = end - begin;
+ if (step_len != 1) return false;
+ step_tokens[i] = token_ids[begin];
+ }
+ std::vector<LlaisysQwen2KVContext *> contexts(nseq, nullptr);
+ for (size_t i = 0; i < nseq; ++i) {
+ contexts[i] = _packed_kv_contexts[i];
+ if (!contexts[i]) return false;
+ }
+
+ size_t logits_shape[2] = {nseq, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return false;
+ if (!_decoder.decodePacked(step_tokens.data(), nseq, contexts, logits, kPackedBlockTokens)) {
+ tensorDestroy(logits);
+ clearPackedState();
+ return false;
+ }
+ for (size_t i = 0; i < nseq; ++i) {
+ llaisysTensor_t row = tensorSlice(logits, 0, i, i + 1);
+ if (!row) {
+ tensorDestroy(logits);
+ clearPackedState();
+ return false;
+ }
+ out_next_tokens[i] = argmax_from_logits(row, _meta.dtype, _device, device_id);
+ tensorDestroy(row);
+ if (out_next_tokens[i] < 0) {
+ tensorDestroy(logits);
+ clearPackedState();
+ return false;
+ }
+ _packed_prompts[i].push_back(step_tokens[i]);
+ _packed_prompts[i].push_back(out_next_tokens[i]);
+ }
+ tensorDestroy(logits);
+ return true;
+}
+
+int64_t Qwen2::prefillSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params) {
+ if (!token_ids || ntoken == 0) return -1;
+ clearPackedState();
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ size_t logits_shape[2] = {1, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return -1;
+ if (!_decoder.prefill(token_ids, ntoken, logits)) {
+ tensorDestroy(logits);
+ return -1;
+ }
+
+ int64_t next_token = next_token_from_logits(logits, _meta.dtype, _device, device_id, _meta.voc, params);
+ tensorDestroy(logits);
+ return next_token;
+}
+
+int64_t Qwen2::stepSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params) {
+ if (!token_ids || ntoken == 0) return -1;
+ clearPackedState();
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ size_t logits_shape[2] = {1, _meta.voc};
+ llaisysTensor_t logits = tensorCreate(logits_shape, 2, _meta.dtype, _device, device_id);
+ if (!logits) return -1;
+ if (!_decoder.decodeStep(token_ids, ntoken, logits)) {
+ tensorDestroy(logits);
+ return -1;
+ }
+
+ int64_t next_token = next_token_from_logits(logits, _meta.dtype, _device, device_id, _meta.voc, params);
+ tensorDestroy(logits);
+ return next_token;
+}
+} // namespace llaisys::models
diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp
new file mode 100644
index 000000000..5f437356c
--- /dev/null
+++ b/src/models/qwen2/qwen2.hpp
@@ -0,0 +1,54 @@
+#pragma once
+
+#include "llaisys/models/qwen2.h"
+#include "llaisys/tensor.h"
+#include "../transformer/decoder/decoder.hpp"
+
+#include <cstdint>
+#include <vector>
+
+namespace llaisys::models {
+class Qwen2 {
+public:
+ Qwen2(const LlaisysQwen2Meta &meta,
+ const LlaisysQwen2Weights &weights,
+ llaisysDeviceType_t device,
+ const std::vector<int> &device_ids);
+ ~Qwen2();
+
+ // Compatibility entrypoint; prefer prefill/step for streaming.
+ int64_t infer(const int64_t *token_ids, size_t ntoken);
+ int64_t prefill(const int64_t *token_ids, size_t ntoken);
+ int64_t step(const int64_t *token_ids, size_t ntoken);
+ bool prefillPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens);
+ bool stepPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ int64_t *out_next_tokens);
+ int64_t prefillSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params);
+ int64_t stepSampling(const int64_t *token_ids, size_t ntoken, const LlaisysSamplingParams *params);
+ void resetKVCache();
+ void setKVCacheEnabled(bool enabled);
+ void setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size);
+ void setKVContext(void *ctx, size_t past_len_tokens = 0);
+ void *getKVContext() const;
+ int exportKVContext(void *ctx, size_t block_tokens);
+
+private:
+ void clearPackedState();
+
+ LlaisysQwen2Meta _meta{};
+ const LlaisysQwen2Weights *_weights{nullptr};
+ llaisysDeviceType_t _device{LLAISYS_DEVICE_CPU};
+ std::vector<int> _device_ids;
+ transformer::Decoder _decoder;
+ void *_kv_ctx{nullptr};
+ std::vector<LlaisysQwen2KVContext *> _packed_kv_contexts;
+ std::vector<std::vector<int64_t>> _packed_prompts;
+};
+} // namespace llaisys::models
diff --git a/src/models/transformer/decoder/decoder.cpp b/src/models/transformer/decoder/decoder.cpp
new file mode 100644
index 000000000..ce97a9dd1
--- /dev/null
+++ b/src/models/transformer/decoder/decoder.cpp
@@ -0,0 +1,1235 @@
+#include "decoder.hpp"
+#include "../../../llaisys/models/qwen2_kv_internal.hpp"
+#include "../../../device/comm_api.hpp"
+
+#include "llaisys/ops.h"
+
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+
+namespace llaisys::models::transformer {
+namespace {
+bool trace_enabled() {
+ static bool enabled = false;
+ static bool inited = false;
+ if (!inited) {
+#if defined(_WIN32)
+ char *value = nullptr;
+ size_t len = 0;
+ if (_dupenv_s(&value, &len, "LLAISYS_QWEN2_TRACE") == 0 && value) {
+ if (value[0] != '\0' && value[0] != '0') enabled = true;
+ free(value);
+ }
+#else
+ const char *value = std::getenv("LLAISYS_QWEN2_TRACE");
+ if (value && value[0] != '\0' && value[0] != '0') enabled = true;
+#endif
+ inited = true;
+ }
+ return enabled;
+}
+
+void trace(const char *stage) {
+ if (trace_enabled()) {
+ std::cerr << "[TRACE] Decoder forward: " << stage << std::endl;
+ }
+}
+
+bool require_tensor(llaisysTensor_t t, const char *stage) {
+ if (t) return true;
+ std::cerr << "[ERROR] Decoder: tensorCreate failed at " << stage << std::endl;
+ return false;
+}
+
+bool ensure_data(llaisysTensor_t t, const char *stage) {
+ if (!t) {
+ std::cerr << "[ERROR] Decoder: null tensor at " << stage << std::endl;
+ return false;
+ }
+ if (!tensorGetData(t)) {
+ std::cerr << "[ERROR] Decoder: null data at " << stage << std::endl;
+ return false;
+ }
+ return true;
+}
+
+void destroy_if_not_null(llaisysTensor_t t) {
+ if (t) tensorDestroy(t);
+}
+} // namespace
+
+Decoder::Decoder(const DecoderConfig &config,
+ const LlaisysQwen2Weights *weights,
+ llaisysDeviceType_t device,
+ const std::vector<int> &device_ids)
+ : _config(config),
+ _weights(weights),
+ _device(device),
+ _device_ids(device_ids) {}
+
+void Decoder::setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size) {
+ _comm = comm;
+ _comm_stream = stream;
+ _tp_size = tp_size > 0 ? tp_size : 1;
+}
+
+Decoder::~Decoder() {
+ releaseCache();
+}
+
+void Decoder::ensureCache() {
+ if (!_kv_cache_enabled || _cache_inited || _config.maxseq == 0 || _config.nlayer == 0) return;
+ _k_cache.assign(_config.nlayer, nullptr);
+ _v_cache.assign(_config.nlayer, nullptr);
+
+ size_t kv_shape[3] = {_config.maxseq, _config.nkvh, _config.dh};
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ for (size_t i = 0; i < _config.nlayer; ++i) {
+ _k_cache[i] = tensorCreate(kv_shape, 3, _config.dtype, _device, device_id);
+ _v_cache[i] = tensorCreate(kv_shape, 3, _config.dtype, _device, device_id);
+ }
+ _past_len = 0;
+ _cache_inited = true;
+}
+
+void Decoder::releaseCache() {
+ for (auto &t : _k_cache) {
+ if (t) tensorDestroy(t);
+ t = nullptr;
+ }
+ for (auto &t : _v_cache) {
+ if (t) tensorDestroy(t);
+ t = nullptr;
+ }
+ _k_cache.clear();
+ _v_cache.clear();
+ _past_len = 0;
+ _cache_inited = false;
+}
+
+void Decoder::resetKVCache() {
+ if (!_cache_inited) return;
+ _past_len = 0;
+}
+
+void Decoder::setKVCacheEnabled(bool enabled) {
+ if (_kv_cache_enabled == enabled) return;
+ _kv_cache_enabled = enabled;
+ if (!enabled) {
+ releaseCache();
+ }
+}
+
+void Decoder::bindExternalKVContext(void *ctx, size_t past_len_tokens) {
+ _external_kv_ctx = ctx;
+ _external_past_len = past_len_tokens;
+ _external_cache_ready = false;
+ if (ctx) {
+ releaseCache();
+ }
+}
+
+void Decoder::clearExternalKVContext() {
+ _external_kv_ctx = nullptr;
+ _external_past_len = 0;
+ _external_cache_ready = false;
+}
+
+bool Decoder::hasExternalKVContext() const {
+ return _external_kv_ctx != nullptr;
+}
+
+int Decoder::exportKVContext(void *ctx_ptr, size_t block_tokens) {
+ if (!ctx_ptr) return -1;
+ if (!_kv_cache_enabled) return -2;
+ ensureCache();
+ if (!_cache_inited) return -3;
+
+ auto *ctx = reinterpret_cast<LlaisysQwen2KVContext *>(ctx_ptr);
+ if (!ctx) return -4;
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return -5;
+ if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return -6;
+
+ llaisysQwen2KVContextDetachAll(ctx);
+ if (_past_len == 0) return 0;
+
+ const size_t chunk_size = block_tokens > 0 ? block_tokens : _past_len;
+ size_t offset = 0;
+ while (offset < _past_len) {
+ const size_t used = std::min(chunk_size, _past_len - offset);
+ LlaisysQwen2KVBlockMeta meta{};
+ meta.dtype = _config.dtype;
+ meta.nlayer = _config.nlayer;
+ meta.nh = _config.nh;
+ meta.nkvh = _config.nkvh;
+ meta.dh = _config.dh;
+ meta.max_tokens = used;
+ auto *block = llaisysQwen2KVBlockCreate(&meta, _device, device_id);
+ if (!block) {
+ llaisysQwen2KVContextDetachAll(ctx);
+ return -7;
+ }
+ if (llaisysQwen2KVBlockSetTokenCount(block, used) != 0) {
+ llaisysQwen2KVBlockRelease(block);
+ llaisysQwen2KVContextDetachAll(ctx);
+ return -8;
+ }
+
+ bool copy_ok = true;
+ for (size_t layer = 0; layer < _config.nlayer && copy_ok; ++layer) {
+ llaisysTensor_t src_k = tensorSlice(_k_cache[layer], 0, offset, offset + used);
+ llaisysTensor_t src_v = tensorSlice(_v_cache[layer], 0, offset, offset + used);
+ llaisysTensor_t dst_k_full = llaisysQwen2KVBlockKeyTensor(block, layer);
+ llaisysTensor_t dst_v_full = llaisysQwen2KVBlockValueTensor(block, layer);
+ llaisysTensor_t dst_k = dst_k_full ? tensorSlice(dst_k_full, 0, 0, used) : nullptr;
+ llaisysTensor_t dst_v = dst_v_full ? tensorSlice(dst_v_full, 0, 0, used) : nullptr;
+ if (!src_k || !src_v || !dst_k || !dst_v) {
+ copy_ok = false;
+ } else {
+ ::llaisysRearrange(dst_k, src_k);
+ ::llaisysRearrange(dst_v, src_v);
+ }
+ destroy_if_not_null(src_k);
+ destroy_if_not_null(src_v);
+ destroy_if_not_null(dst_k);
+ destroy_if_not_null(dst_v);
+ }
+
+ if (!copy_ok || llaisysQwen2KVContextAttachBlock(ctx, block) != 0) {
+ llaisysQwen2KVBlockRelease(block);
+ llaisysQwen2KVContextDetachAll(ctx);
+ return -9;
+ }
+ llaisysQwen2KVBlockRelease(block);
+ offset += used;
+ }
+ return 0;
+}
+
+bool Decoder::recoverExternalCache() {
+ if (!_external_kv_ctx || _external_cache_ready) return true;
+ if (!_kv_cache_enabled) return false;
+ ensureCache();
+ if (!_cache_inited) return false;
+
+ auto *ctx = reinterpret_cast<LlaisysQwen2KVContext *>(_external_kv_ctx);
+ if (!ctx) return false;
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return false;
+ if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return false;
+
+ size_t total_tokens = 0;
+ for (auto *blk : ctx->chain) {
+ if (!blk) return false;
+ if (blk->meta.dtype != _config.dtype || blk->device != _device || blk->device_id != device_id) return false;
+ if (blk->meta.nlayer != _config.nlayer || blk->meta.nkvh != _config.nkvh || blk->meta.dh != _config.dh) return false;
+ if (blk->used_tokens > blk->meta.max_tokens) return false;
+ total_tokens += blk->used_tokens;
+ if (total_tokens > _config.maxseq) return false;
+ }
+
+ _past_len = 0;
+ for (size_t layer = 0; layer < _config.nlayer; ++layer) {
+ size_t offset = 0;
+ for (auto *blk : ctx->chain) {
+ const size_t used = blk->used_tokens;
+ if (used == 0) continue;
+ if (layer >= blk->k_layers.size() || layer >= blk->v_layers.size()) return false;
+ auto *k_block = blk->k_layers[layer];
+ auto *v_block = blk->v_layers[layer];
+ if (!k_block || !v_block) return false;
+
+ llaisysTensor_t src_k = tensorSlice(k_block, 0, 0, used);
+ llaisysTensor_t src_v = tensorSlice(v_block, 0, 0, used);
+ llaisysTensor_t dst_k = tensorSlice(_k_cache[layer], 0, offset, offset + used);
+ llaisysTensor_t dst_v = tensorSlice(_v_cache[layer], 0, offset, offset + used);
+ if (!src_k || !src_v || !dst_k || !dst_v) {
+ destroy_if_not_null(src_k);
+ destroy_if_not_null(src_v);
+ destroy_if_not_null(dst_k);
+ destroy_if_not_null(dst_v);
+ return false;
+ }
+ ::llaisysRearrange(dst_k, src_k);
+ ::llaisysRearrange(dst_v, src_v);
+ tensorDestroy(src_k);
+ tensorDestroy(src_v);
+ tensorDestroy(dst_k);
+ tensorDestroy(dst_v);
+ offset += used;
+ }
+ }
+
+ _past_len = total_tokens;
+ _external_past_len = total_tokens;
+ _external_cache_ready = true;
+ return true;
+}
+
+bool Decoder::runHidden(const int64_t *token_ids,
+ size_t ntoken,
+ bool append_only,
+ const int64_t *segment_offsets,
+ size_t nseg,
+ size_t &past_len,
+ size_t &cur_len,
+ llaisysTensor_t &idx,
+ llaisysTensor_t &pos_ids,
+ llaisysTensor_t &hidden) {
+ idx = nullptr;
+ pos_ids = nullptr;
+ hidden = nullptr;
+ if (!token_ids || ntoken == 0) return false;
+ if (!_weights || !_weights->in_embed) return false;
+ const bool segmented = (segment_offsets != nullptr && nseg > 0);
+ if (segmented && append_only) return false;
+
+ if (!segmented) {
+ ensureCache();
+ if (_external_kv_ctx && !_external_cache_ready) {
+ if (!recoverExternalCache()) {
+ clearExternalKVContext();
+ _past_len = 0;
+ }
+ }
+ }
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ // Segmented packed prefill treats each call as an independent packed forward.
+ // Reusing decoder KV cache here breaks offset domains (past_len vs packed offsets).
+ const bool can_cache = (!segmented) && _cache_inited && _config.maxseq > 0;
+ if (can_cache && ntoken > _config.maxseq) return false;
+ past_len = can_cache ? _past_len : 0;
+ if (append_only && !can_cache) {
+ return false;
+ }
+ if (!append_only) {
+ if (!can_cache || ntoken <= past_len) {
+ past_len = 0;
+ if (can_cache) _past_len = 0;
+ }
+ cur_len = ntoken - past_len;
+ } else {
+ cur_len = ntoken;
+ }
+ if (cur_len == 0) return false;
+ if (trace_enabled()) {
+ std::cerr << "[TRACE] Decoder cache: enabled=" << (_kv_cache_enabled ? 1 : 0)
+ << " inited=" << (_cache_inited ? 1 : 0)
+ << " can_cache=" << (can_cache ? 1 : 0)
+ << " past_len=" << past_len
+ << " cur_len=" << cur_len
+ << " ntoken=" << ntoken << std::endl;
+ }
+ const int64_t *new_tokens = append_only ? token_ids : (token_ids + past_len);
+ if (can_cache) {
+ if (_k_cache.size() != _config.nlayer || _v_cache.size() != _config.nlayer) return false;
+ if (past_len + cur_len > _config.maxseq) return false;
+ }
+
+ trace("begin");
+ // 1) token ids -> embedding
+ size_t idx_shape[1] = {cur_len};
+ idx = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id);
+ if (!require_tensor(idx, "idx")) return false;
+ tensorLoad(idx, new_tokens);
+
+ size_t hidden_shape[2] = {cur_len, _config.hs};
+ hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(hidden, "hidden")) {
+ tensorDestroy(idx);
+ idx = nullptr;
+ return false;
+ }
+
+ trace("embedding");
+ ::llaisysEmbedding(hidden, idx, _weights->in_embed);
+
+ // 2) position ids for RoPE
+ std::vector<int64_t> pos_buf(cur_len);
+ for (size_t i = 0; i < cur_len; ++i) pos_buf[i] = static_cast<int64_t>(past_len + i);
+ trace("pos_ids");
+ pos_ids = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id);
+ if (!require_tensor(pos_ids, "pos_ids")) {
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ tensorLoad(pos_ids, pos_buf.data());
+
+ // 3) Attention + MLP blocks
+ const float scale = 1.0f / std::sqrt(static_cast<float>(_config.dh));
+ for (size_t layer = 0; layer < _config.nlayer; ++layer) {
+ trace("attn.weights.check");
+ if (!_weights->attn_norm_w || !_weights->attn_q_w || !_weights->attn_k_w || !_weights->attn_v_w ||
+ !_weights->attn_o_w || !_weights->mlp_norm_w || !_weights->mlp_gate_w || !_weights->mlp_up_w ||
+ !_weights->mlp_down_w) {
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ if (!_weights->attn_norm_w[layer] || !_weights->attn_q_w[layer] || !_weights->attn_k_w[layer] ||
+ !_weights->attn_v_w[layer] || !_weights->attn_o_w[layer] || !_weights->mlp_norm_w[layer] ||
+ !_weights->mlp_gate_w[layer] || !_weights->mlp_up_w[layer] || !_weights->mlp_down_w[layer]) {
+ std::cerr << "[ERROR] Decoder: missing weights at layer " << layer << std::endl;
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+
+ trace("attn.norm");
+ llaisysTensor_t norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(norm, "attn.norm")) {
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysRmsNorm(norm, hidden, _weights->attn_norm_w[layer], _config.epsilon);
+
+ trace("attn.qkv");
+ size_t q2d_shape[2] = {cur_len, _config.nh * _config.dh};
+ size_t kv2d_shape[2] = {cur_len, _config.nkvh * _config.dh};
+ llaisysTensor_t q2d = tensorCreate(q2d_shape, 2, _config.dtype, _device, device_id);
+ llaisysTensor_t k2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id);
+ llaisysTensor_t v2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(q2d, "attn.q2d") || !require_tensor(k2d, "attn.k2d") ||
+ !require_tensor(v2d, "attn.v2d")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ if (q2d) tensorDestroy(q2d);
+ if (k2d) tensorDestroy(k2d);
+ if (v2d) tensorDestroy(v2d);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+
+ llaisysTensor_t q_bias = (_weights->attn_q_b && _weights->attn_q_b[layer]) ? _weights->attn_q_b[layer] : nullptr;
+ llaisysTensor_t k_bias = (_weights->attn_k_b && _weights->attn_k_b[layer]) ? _weights->attn_k_b[layer] : nullptr;
+ llaisysTensor_t v_bias = (_weights->attn_v_b && _weights->attn_v_b[layer]) ? _weights->attn_v_b[layer] : nullptr;
+
+ ::llaisysLinear(q2d, norm, _weights->attn_q_w[layer], q_bias);
+ ::llaisysLinear(k2d, norm, _weights->attn_k_w[layer], k_bias);
+ ::llaisysLinear(v2d, norm, _weights->attn_v_w[layer], v_bias);
+
+ trace("attn.view");
+ size_t q3d_shape[3] = {cur_len, _config.nh, _config.dh};
+ size_t k3d_shape[3] = {cur_len, _config.nkvh, _config.dh};
+ llaisysTensor_t q3d = tensorView(q2d, q3d_shape, 3);
+ llaisysTensor_t k3d = tensorView(k2d, k3d_shape, 3);
+ llaisysTensor_t v3d = tensorView(v2d, k3d_shape, 3);
+ if (!require_tensor(q3d, "attn.q3d") || !require_tensor(k3d, "attn.k3d") ||
+ !require_tensor(v3d, "attn.v3d")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ if (q3d) tensorDestroy(q3d);
+ if (k3d) tensorDestroy(k3d);
+ if (v3d) tensorDestroy(v3d);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+
+ trace("attn.rope");
+ llaisysTensor_t q_rope = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id);
+ llaisysTensor_t k_rope = tensorCreate(k3d_shape, 3, _config.dtype, _device, device_id);
+ if (!require_tensor(q_rope, "attn.q_rope") || !require_tensor(k_rope, "attn.k_rope")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ if (q_rope) tensorDestroy(q_rope);
+ if (k_rope) tensorDestroy(k_rope);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysROPE(q_rope, q3d, pos_ids, _config.theta);
+ ::llaisysROPE(k_rope, k3d, pos_ids, _config.theta);
+
+ if (can_cache) {
+ trace("attn.cache.write");
+ llaisysTensor_t k_slot = tensorSlice(_k_cache[layer], 0, past_len, past_len + cur_len);
+ llaisysTensor_t v_slot = tensorSlice(_v_cache[layer], 0, past_len, past_len + cur_len);
+ ::llaisysRearrange(k_slot, k_rope);
+ ::llaisysRearrange(v_slot, v3d);
+ tensorDestroy(k_slot);
+ tensorDestroy(v_slot);
+ }
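+ // Cache bookkeeping example (illustrative numbers, not from a real run): with
+ // past_len = 8 and cur_len = 2, k_slot/v_slot above cover cache rows [8, 10),
+ // so only the newly computed K/V rows are written; the cache read that follows
+ // then views rows [0, 10) as the full key/value history for attention.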
+
+ llaisysTensor_t k_attn = k_rope;
+ llaisysTensor_t v_attn = v3d;
+ llaisysTensor_t k_cache_view = nullptr;
+ llaisysTensor_t v_cache_view = nullptr;
+ if (can_cache) {
+ trace("attn.cache.read");
+ size_t total_len = past_len + cur_len;
+ k_cache_view = tensorSlice(_k_cache[layer], 0, 0, total_len);
+ v_cache_view = tensorSlice(_v_cache[layer], 0, 0, total_len);
+ k_attn = k_cache_view;
+ v_attn = v_cache_view;
+ }
+
+ trace("attn.softmax");
+ llaisysTensor_t attn_out3d = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id);
+ if (!require_tensor(attn_out3d, "attn.out3d")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ tensorDestroy(q_rope);
+ tensorDestroy(k_rope);
+ if (k_cache_view) tensorDestroy(k_cache_view);
+ if (v_cache_view) tensorDestroy(v_cache_view);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ if (segmented) {
+ ::llaisysSelfAttentionSegmented(
+ attn_out3d,
+ q_rope,
+ k_attn,
+ v_attn,
+ scale,
+ segment_offsets,
+ segment_offsets,
+ nseg);
+ } else {
+ ::llaisysSelfAttention(attn_out3d, q_rope, k_attn, v_attn, scale);
+ }
+ if (k_cache_view) tensorDestroy(k_cache_view);
+ if (v_cache_view) tensorDestroy(v_cache_view);
+
+ trace("attn.proj");
+ llaisysTensor_t attn_out2d = tensorView(attn_out3d, q2d_shape, 2);
+ llaisysTensor_t proj_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(attn_out2d, "attn.out2d") || !require_tensor(proj_out, "attn.proj_out")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ tensorDestroy(q_rope);
+ tensorDestroy(k_rope);
+ tensorDestroy(attn_out3d);
+ if (attn_out2d) tensorDestroy(attn_out2d);
+ if (proj_out) tensorDestroy(proj_out);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ if (!ensure_data(attn_out2d, "attn.proj.in") || !ensure_data(proj_out, "attn.proj.out") ||
+ !ensure_data(_weights->attn_o_w[layer], "attn.proj.w")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ tensorDestroy(q_rope);
+ tensorDestroy(k_rope);
+ tensorDestroy(attn_out3d);
+ tensorDestroy(attn_out2d);
+ tensorDestroy(proj_out);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysLinear(proj_out, attn_out2d, _weights->attn_o_w[layer], nullptr);
+
+ // Tensor parallel: allreduce after attn_o projection
+ if (_tp_size > 1 && _comm) {
+ size_t ndim = tensorGetNdim(proj_out);
+ size_t shape[4];
+ tensorGetShape(proj_out, shape);
+ size_t count = 1;
+ for (size_t d = 0; d < ndim; ++d) count *= shape[d];
+ auto backend = (_device == LLAISYS_DEVICE_ILUVATAR) ? LLAISYS_COMM_IXCCL : LLAISYS_COMM_NCCL;
+ auto *api = llaisys::device::getCommAPI(backend);
+ api->allreduce(tensorGetData(proj_out), tensorGetData(proj_out),
+ count, _config.dtype, LLAISYS_REDUCE_SUM, _comm, _comm_stream);
+ }
+
+ trace("attn.residual");
+ llaisysTensor_t new_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(new_hidden, "attn.residual")) {
+ tensorDestroy(norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ tensorDestroy(q_rope);
+ tensorDestroy(k_rope);
+ tensorDestroy(attn_out3d);
+ tensorDestroy(attn_out2d);
+ tensorDestroy(proj_out);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysAdd(new_hidden, hidden, proj_out);
+
+ tensorDestroy(hidden);
+ hidden = new_hidden;
+
+ tensorDestroy(norm);
+ tensorDestroy(q2d);
+ tensorDestroy(k2d);
+ tensorDestroy(v2d);
+ tensorDestroy(q3d);
+ tensorDestroy(k3d);
+ tensorDestroy(v3d);
+ tensorDestroy(q_rope);
+ tensorDestroy(k_rope);
+ tensorDestroy(attn_out3d);
+ tensorDestroy(attn_out2d);
+ tensorDestroy(proj_out);
+
+ // 4) MLP
+ trace("mlp.norm");
+ llaisysTensor_t mlp_norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(mlp_norm, "mlp.norm")) {
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysRmsNorm(mlp_norm, hidden, _weights->mlp_norm_w[layer], _config.epsilon);
+
+ trace("mlp.gate_up");
+ size_t mlp_shape[2] = {cur_len, _config.di};
+ llaisysTensor_t gate = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ llaisysTensor_t up = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(gate, "mlp.gate") || !require_tensor(up, "mlp.up")) {
+ tensorDestroy(mlp_norm);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ if (gate) tensorDestroy(gate);
+ if (up) tensorDestroy(up);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysLinear(gate, mlp_norm, _weights->mlp_gate_w[layer], nullptr);
+ ::llaisysLinear(up, mlp_norm, _weights->mlp_up_w[layer], nullptr);
+
+ trace("mlp.swiglu");
+ llaisysTensor_t swiglu = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(swiglu, "mlp.swiglu")) {
+ tensorDestroy(mlp_norm);
+ tensorDestroy(gate);
+ tensorDestroy(up);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysSwiGLU(swiglu, gate, up);
+
+ trace("mlp.down");
+ llaisysTensor_t mlp_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(mlp_out, "mlp.down")) {
+ tensorDestroy(mlp_norm);
+ tensorDestroy(gate);
+ tensorDestroy(up);
+ tensorDestroy(swiglu);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysLinear(mlp_out, swiglu, _weights->mlp_down_w[layer], nullptr);
+
+ // Tensor parallel: allreduce after mlp_down projection
+ if (_tp_size > 1 && _comm) {
+ size_t ndim = tensorGetNdim(mlp_out);
+ size_t shape[4];
+ tensorGetShape(mlp_out, shape);
+ size_t count = 1;
+ for (size_t d = 0; d < ndim; ++d) count *= shape[d];
+ auto backend = (_device == LLAISYS_DEVICE_ILUVATAR) ? LLAISYS_COMM_IXCCL : LLAISYS_COMM_NCCL;
+ auto *api = llaisys::device::getCommAPI(backend);
+ api->allreduce(tensorGetData(mlp_out), tensorGetData(mlp_out),
+ count, _config.dtype, LLAISYS_REDUCE_SUM, _comm, _comm_stream);
+ }
+
+ trace("mlp.residual");
+ llaisysTensor_t mlp_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!require_tensor(mlp_hidden, "mlp.residual")) {
+ tensorDestroy(mlp_norm);
+ tensorDestroy(gate);
+ tensorDestroy(up);
+ tensorDestroy(swiglu);
+ tensorDestroy(mlp_out);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ tensorDestroy(idx);
+ pos_ids = nullptr;
+ hidden = nullptr;
+ idx = nullptr;
+ return false;
+ }
+ ::llaisysAdd(mlp_hidden, hidden, mlp_out);
+
+ tensorDestroy(hidden);
+ hidden = mlp_hidden;
+
+ tensorDestroy(mlp_norm);
+ tensorDestroy(gate);
+ tensorDestroy(up);
+ tensorDestroy(swiglu);
+ tensorDestroy(mlp_out);
+ }
+
+ if (can_cache) {
+ _past_len = past_len + cur_len;
+ }
+
+ return true;
+}
+
+bool Decoder::prefill(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits) {
+ if (!out_last_logits) return false;
+ if (!ensure_data(out_last_logits, "head.logits.out")) return false;
+
+ size_t past_len = 0;
+ size_t cur_len = 0;
+ llaisysTensor_t idx = nullptr;
+ llaisysTensor_t pos_ids = nullptr;
+ llaisysTensor_t hidden = nullptr;
+ if (!runHidden(token_ids, ntoken, false, nullptr, 0, past_len, cur_len, idx, pos_ids, hidden)) return false;
+
+ if (!_weights || !_weights->out_norm_w || !_weights->out_embed) {
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+
+ trace("head.slice");
+ llaisysTensor_t last_hidden = tensorSlice(hidden, 0, cur_len - 1, cur_len);
+ if (!require_tensor(last_hidden, "head.last_hidden")) {
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+
+ size_t last_shape[2] = {1, _config.hs};
+ trace("head.norm");
+ llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, _device_ids.empty() ? 0 : _device_ids[0]);
+ if (!require_tensor(final_norm, "head.norm")) {
+ tensorDestroy(last_hidden);
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+ ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon);
+
+ trace("head.logits");
+ ::llaisysLinear(out_last_logits, final_norm, _weights->out_embed, nullptr);
+
+ tensorDestroy(last_hidden);
+ tensorDestroy(final_norm);
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return true;
+}
+
+bool Decoder::prefillPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ llaisysTensor_t out_last_logits) {
+ if (!out_last_logits || !token_ids || !token_offsets || nseq == 0 || ntoken == 0) return false;
+ if (!ensure_data(out_last_logits, "head.packed.logits.out")) return false;
+ if (tensorGetNdim(out_last_logits) != 2) return false;
+ size_t out_shape[2] = {0, 0};
+ tensorGetShape(out_last_logits, out_shape);
+ if (out_shape[0] != nseq || out_shape[1] != _config.voc) return false;
+ if (token_offsets[0] != 0 || static_cast<size_t>(token_offsets[nseq]) != ntoken) return false;
+ for (size_t i = 0; i < nseq; ++i) {
+ if (token_offsets[i] > token_offsets[i + 1]) return false;
+ if (token_offsets[i] == token_offsets[i + 1]) return false;
+ }
+
+ size_t past_len = 0;
+ size_t cur_len = 0;
+ llaisysTensor_t idx = nullptr;
+ llaisysTensor_t pos_ids = nullptr;
+ llaisysTensor_t hidden = nullptr;
+ if (!runHidden(token_ids, ntoken, false, token_offsets, nseq, past_len, cur_len, idx, pos_ids, hidden)) return false;
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ if (!_weights || !_weights->out_norm_w || !_weights->out_embed) {
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+
+ bool ok = true;
+ for (size_t i = 0; i < nseq && ok; ++i) {
+ const size_t seg_end = static_cast<size_t>(token_offsets[i + 1]);
+ const size_t last_pos = seg_end - 1;
+ llaisysTensor_t last_hidden = tensorSlice(hidden, 0, last_pos, last_pos + 1);
+ llaisysTensor_t row_logits = tensorSlice(out_last_logits, 0, i, i + 1);
+ size_t last_shape[2] = {1, _config.hs};
+ llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, device_id);
+ if (!last_hidden || !row_logits || !final_norm) {
+ ok = false;
+ } else {
+ ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon);
+ ::llaisysLinear(row_logits, final_norm, _weights->out_embed, nullptr);
+ }
+ destroy_if_not_null(last_hidden);
+ destroy_if_not_null(row_logits);
+ destroy_if_not_null(final_norm);
+ }
+
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return ok;
+}
+
+bool Decoder::decodePacked(const int64_t *token_ids,
+ size_t nseq,
+ const std::vector<LlaisysQwen2KVContext *> &contexts,
+ llaisysTensor_t out_last_logits,
+ size_t block_tokens_hint) {
+ if (!token_ids || nseq == 0 || contexts.size() != nseq || !out_last_logits) return false;
+ if (!ensure_data(out_last_logits, "head.decode_packed.logits.out")) return false;
+ if (tensorGetNdim(out_last_logits) != 2) return false;
+ size_t out_shape[2] = {0, 0};
+ tensorGetShape(out_last_logits, out_shape);
+ if (out_shape[0] != nseq || out_shape[1] != _config.voc) return false;
+
+ const int device_id = _device_ids.empty() ? 0 : _device_ids[0];
+ std::vector<size_t> past_lens(nseq, 0);
+ std::vector<int64_t> q_offsets(nseq + 1, 0);
+ std::vector<int64_t> kv_offsets(nseq + 1, 0);
+ std::vector<LlaisysQwen2KVBlock *> append_blocks(nseq, nullptr);
+ std::vector<size_t> append_pos(nseq, 0);
+ size_t kv_total = 0;
+
+ for (size_t i = 0; i < nseq; ++i) {
+ auto *ctx = contexts[i];
+ if (!ctx) return false;
+ if (ctx->dtype != _config.dtype || ctx->device != _device || ctx->device_id != device_id) return false;
+ if (ctx->nlayer != _config.nlayer || ctx->nkvh != _config.nkvh || ctx->dh != _config.dh) return false;
+ const size_t past = llaisysQwen2KVContextTokenCount(ctx);
+ if (past + 1 > _config.maxseq) return false;
+ past_lens[i] = past;
+ q_offsets[i + 1] = static_cast<int64_t>(i + 1);
+ kv_total += past + 1;
+ kv_offsets[i + 1] = static_cast<int64_t>(kv_total);
+
+ LlaisysQwen2KVBlock *target = nullptr;
+ size_t pos = 0;
+ if (!ctx->chain.empty()) {
+ auto *last = ctx->chain.back();
+ if (last && last->used_tokens < last->meta.max_tokens) {
+ target = last;
+ pos = last->used_tokens;
+ }
+ }
+ if (!target) {
+ const size_t max_tokens = block_tokens_hint > 0 ? block_tokens_hint : 64;
+ LlaisysQwen2KVBlockMeta meta{};
+ meta.dtype = _config.dtype;
+ meta.nlayer = _config.nlayer;
+ meta.nh = _config.nh;
+ meta.nkvh = _config.nkvh;
+ meta.dh = _config.dh;
+ meta.max_tokens = max_tokens;
+ auto *blk = llaisysQwen2KVBlockCreate(&meta, _device, device_id);
+ if (!blk) return false;
+ if (llaisysQwen2KVContextAttachBlock(ctx, blk) != 0) {
+ llaisysQwen2KVBlockRelease(blk);
+ return false;
+ }
+ llaisysQwen2KVBlockRelease(blk);
+ if (ctx->chain.empty() || !ctx->chain.back()) return false;
+ target = ctx->chain.back();
+ pos = target->used_tokens;
+ }
+ append_blocks[i] = target;
+ append_pos[i] = pos;
+ }
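+ // Offset domains (illustrative numbers, not from a real run): for nseq = 3
+ // with past lengths {4, 0, 2}, each sequence contributes past + 1 KV rows,
+ // so the loop above yields kv_offsets = {0, 5, 6, 9} and
+ // q_offsets = {0, 1, 2, 3} (one new query row per sequence).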
+
+ size_t idx_shape[1] = {nseq};
+ size_t hidden_shape[2] = {nseq, _config.hs};
+ llaisysTensor_t idx = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id);
+ llaisysTensor_t pos_ids = tensorCreate(idx_shape, 1, LLAISYS_DTYPE_I64, _device, device_id);
+ llaisysTensor_t hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!idx || !pos_ids || !hidden) {
+ destroy_if_not_null(idx);
+ destroy_if_not_null(pos_ids);
+ destroy_if_not_null(hidden);
+ return false;
+ }
+ tensorLoad(idx, token_ids);
+ std::vector<int64_t> pos_buf(nseq, 0);
+ for (size_t i = 0; i < nseq; ++i) pos_buf[i] = static_cast<int64_t>(past_lens[i]);
+ tensorLoad(pos_ids, pos_buf.data());
+ ::llaisysEmbedding(hidden, idx, _weights->in_embed);
+
+ const float scale = 1.0f / std::sqrt(static_cast<float>(_config.dh));
+ bool ok = true;
+ for (size_t layer = 0; layer < _config.nlayer && ok; ++layer) {
+ llaisysTensor_t norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ size_t q2d_shape[2] = {nseq, _config.nh * _config.dh};
+ size_t kv2d_shape[2] = {nseq, _config.nkvh * _config.dh};
+ llaisysTensor_t q2d = tensorCreate(q2d_shape, 2, _config.dtype, _device, device_id);
+ llaisysTensor_t k2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id);
+ llaisysTensor_t v2d = tensorCreate(kv2d_shape, 2, _config.dtype, _device, device_id);
+ if (!norm || !q2d || !k2d || !v2d) {
+ destroy_if_not_null(norm);
+ destroy_if_not_null(q2d);
+ destroy_if_not_null(k2d);
+ destroy_if_not_null(v2d);
+ ok = false;
+ break;
+ }
+ ::llaisysRmsNorm(norm, hidden, _weights->attn_norm_w[layer], _config.epsilon);
+ llaisysTensor_t q_bias = (_weights->attn_q_b && _weights->attn_q_b[layer]) ? _weights->attn_q_b[layer] : nullptr;
+ llaisysTensor_t k_bias = (_weights->attn_k_b && _weights->attn_k_b[layer]) ? _weights->attn_k_b[layer] : nullptr;
+ llaisysTensor_t v_bias = (_weights->attn_v_b && _weights->attn_v_b[layer]) ? _weights->attn_v_b[layer] : nullptr;
+ ::llaisysLinear(q2d, norm, _weights->attn_q_w[layer], q_bias);
+ ::llaisysLinear(k2d, norm, _weights->attn_k_w[layer], k_bias);
+ ::llaisysLinear(v2d, norm, _weights->attn_v_w[layer], v_bias);
+
+ size_t q3d_shape[3] = {nseq, _config.nh, _config.dh};
+ size_t k3d_shape[3] = {nseq, _config.nkvh, _config.dh};
+ llaisysTensor_t q3d = tensorView(q2d, q3d_shape, 3);
+ llaisysTensor_t k3d = tensorView(k2d, k3d_shape, 3);
+ llaisysTensor_t v3d = tensorView(v2d, k3d_shape, 3);
+ llaisysTensor_t q_rope = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id);
+ llaisysTensor_t k_rope = tensorCreate(k3d_shape, 3, _config.dtype, _device, device_id);
+ if (!q3d || !k3d || !v3d || !q_rope || !k_rope) {
+ destroy_if_not_null(norm);
+ destroy_if_not_null(q2d);
+ destroy_if_not_null(k2d);
+ destroy_if_not_null(v2d);
+ destroy_if_not_null(q3d);
+ destroy_if_not_null(k3d);
+ destroy_if_not_null(v3d);
+ destroy_if_not_null(q_rope);
+ destroy_if_not_null(k_rope);
+ ok = false;
+ break;
+ }
+ ::llaisysROPE(q_rope, q3d, pos_ids, _config.theta);
+ ::llaisysROPE(k_rope, k3d, pos_ids, _config.theta);
+
+ size_t kv_all_shape[3] = {kv_total, _config.nkvh, _config.dh};
+ llaisysTensor_t k_all = tensorCreate(kv_all_shape, 3, _config.dtype, _device, device_id);
+ llaisysTensor_t v_all = tensorCreate(kv_all_shape, 3, _config.dtype, _device, device_id);
+ if (!k_all || !v_all) {
+ destroy_if_not_null(norm);
+ destroy_if_not_null(q2d);
+ destroy_if_not_null(k2d);
+ destroy_if_not_null(v2d);
+ destroy_if_not_null(q3d);
+ destroy_if_not_null(k3d);
+ destroy_if_not_null(v3d);
+ destroy_if_not_null(q_rope);
+ destroy_if_not_null(k_rope);
+ destroy_if_not_null(k_all);
+ destroy_if_not_null(v_all);
+ ok = false;
+ break;
+ }
+ for (size_t i = 0; i < nseq && ok; ++i) {
+ auto *ctx = contexts[i];
+ const size_t kv_begin = static_cast<size_t>(kv_offsets[i]);
+ const size_t past = past_lens[i];
+ size_t copied = 0;
+ for (auto *blk : ctx->chain) {
+ if (!blk) {
+ ok = false;
+ break;
+ }
+ const size_t used = blk->used_tokens;
+ if (used == 0) continue;
+ llaisysTensor_t src_k = tensorSlice(blk->k_layers[layer], 0, 0, used);
+ llaisysTensor_t src_v = tensorSlice(blk->v_layers[layer], 0, 0, used);
+ llaisysTensor_t dst_k = tensorSlice(k_all, 0, kv_begin + copied, kv_begin + copied + used);
+ llaisysTensor_t dst_v = tensorSlice(v_all, 0, kv_begin + copied, kv_begin + copied + used);
+ if (!src_k || !src_v || !dst_k || !dst_v) {
+ destroy_if_not_null(src_k);
+ destroy_if_not_null(src_v);
+ destroy_if_not_null(dst_k);
+ destroy_if_not_null(dst_v);
+ ok = false;
+ break;
+ }
+ ::llaisysRearrange(dst_k, src_k);
+ ::llaisysRearrange(dst_v, src_v);
+ tensorDestroy(src_k);
+ tensorDestroy(src_v);
+ tensorDestroy(dst_k);
+ tensorDestroy(dst_v);
+ copied += used;
+ }
+ if (!ok || copied != past) {
+ ok = false;
+ break;
+ }
+
+ const size_t kv_new_pos = kv_begin + past;
+ llaisysTensor_t src_new_k = tensorSlice(k_rope, 0, i, i + 1);
+ llaisysTensor_t src_new_v = tensorSlice(v3d, 0, i, i + 1);
+ llaisysTensor_t dst_new_k = tensorSlice(k_all, 0, kv_new_pos, kv_new_pos + 1);
+ llaisysTensor_t dst_new_v = tensorSlice(v_all, 0, kv_new_pos, kv_new_pos + 1);
+ llaisysTensor_t dst_ctx_k = tensorSlice(append_blocks[i]->k_layers[layer], 0, append_pos[i], append_pos[i] + 1);
+ llaisysTensor_t dst_ctx_v = tensorSlice(append_blocks[i]->v_layers[layer], 0, append_pos[i], append_pos[i] + 1);
+ if (!src_new_k || !src_new_v || !dst_new_k || !dst_new_v || !dst_ctx_k || !dst_ctx_v) {
+ destroy_if_not_null(src_new_k);
+ destroy_if_not_null(src_new_v);
+ destroy_if_not_null(dst_new_k);
+ destroy_if_not_null(dst_new_v);
+ destroy_if_not_null(dst_ctx_k);
+ destroy_if_not_null(dst_ctx_v);
+ ok = false;
+ break;
+ }
+ ::llaisysRearrange(dst_new_k, src_new_k);
+ ::llaisysRearrange(dst_new_v, src_new_v);
+ ::llaisysRearrange(dst_ctx_k, src_new_k);
+ ::llaisysRearrange(dst_ctx_v, src_new_v);
+ tensorDestroy(src_new_k);
+ tensorDestroy(src_new_v);
+ tensorDestroy(dst_new_k);
+ tensorDestroy(dst_new_v);
+ tensorDestroy(dst_ctx_k);
+ tensorDestroy(dst_ctx_v);
+ }
+
+ llaisysTensor_t attn_out3d = nullptr;
+ llaisysTensor_t attn_out2d = nullptr;
+ llaisysTensor_t proj_out = nullptr;
+ llaisysTensor_t attn_hidden = nullptr;
+ llaisysTensor_t mlp_norm = nullptr;
+ llaisysTensor_t gate = nullptr;
+ llaisysTensor_t up = nullptr;
+ llaisysTensor_t swiglu = nullptr;
+ llaisysTensor_t mlp_out = nullptr;
+ llaisysTensor_t mlp_hidden = nullptr;
+
+ if (ok) {
+ attn_out3d = tensorCreate(q3d_shape, 3, _config.dtype, _device, device_id);
+ if (!attn_out3d) ok = false;
+ }
+ if (ok) {
+ ::llaisysSelfAttentionSegmented(
+ attn_out3d, q_rope, k_all, v_all, scale, q_offsets.data(), kv_offsets.data(), nseq);
+ attn_out2d = tensorView(attn_out3d, q2d_shape, 2);
+ proj_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ attn_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!attn_out2d || !proj_out || !attn_hidden) ok = false;
+ }
+ if (ok) {
+ ::llaisysLinear(proj_out, attn_out2d, _weights->attn_o_w[layer], nullptr);
+ ::llaisysAdd(attn_hidden, hidden, proj_out);
+ }
+
+ if (ok) {
+ mlp_norm = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ size_t mlp_shape[2] = {nseq, _config.di};
+ gate = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ up = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ swiglu = tensorCreate(mlp_shape, 2, _config.dtype, _device, device_id);
+ mlp_out = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ mlp_hidden = tensorCreate(hidden_shape, 2, _config.dtype, _device, device_id);
+ if (!mlp_norm || !gate || !up || !swiglu || !mlp_out || !mlp_hidden) {
+ ok = false;
+ }
+ }
+ if (ok) {
+ ::llaisysRmsNorm(mlp_norm, attn_hidden, _weights->mlp_norm_w[layer], _config.epsilon);
+ ::llaisysLinear(gate, mlp_norm, _weights->mlp_gate_w[layer], nullptr);
+ ::llaisysLinear(up, mlp_norm, _weights->mlp_up_w[layer], nullptr);
+ ::llaisysSwiGLU(swiglu, gate, up);
+ ::llaisysLinear(mlp_out, swiglu, _weights->mlp_down_w[layer], nullptr);
+ ::llaisysAdd(mlp_hidden, attn_hidden, mlp_out);
+ }
+
+ if (ok) {
+ tensorDestroy(hidden);
+ hidden = mlp_hidden;
+ mlp_hidden = nullptr;
+ }
+
+ destroy_if_not_null(norm);
+ destroy_if_not_null(q2d);
+ destroy_if_not_null(k2d);
+ destroy_if_not_null(v2d);
+ destroy_if_not_null(q3d);
+ destroy_if_not_null(k3d);
+ destroy_if_not_null(v3d);
+ destroy_if_not_null(q_rope);
+ destroy_if_not_null(k_rope);
+ destroy_if_not_null(k_all);
+ destroy_if_not_null(v_all);
+ destroy_if_not_null(attn_out3d);
+ destroy_if_not_null(attn_out2d);
+ destroy_if_not_null(proj_out);
+ destroy_if_not_null(attn_hidden);
+ destroy_if_not_null(mlp_norm);
+ destroy_if_not_null(gate);
+ destroy_if_not_null(up);
+ destroy_if_not_null(swiglu);
+ destroy_if_not_null(mlp_out);
+ destroy_if_not_null(mlp_hidden);
+ }
+
+ if (ok) {
+ for (size_t i = 0; i < nseq; ++i) {
+ if (append_blocks[i] && append_blocks[i]->used_tokens < append_pos[i] + 1) {
+ append_blocks[i]->used_tokens = append_pos[i] + 1;
+ }
+ }
+ for (size_t i = 0; i < nseq && ok; ++i) {
+ llaisysTensor_t last_hidden = tensorSlice(hidden, 0, i, i + 1);
+ llaisysTensor_t row_logits = tensorSlice(out_last_logits, 0, i, i + 1);
+ size_t last_shape[2] = {1, _config.hs};
+ llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, device_id);
+ if (!last_hidden || !row_logits || !final_norm) {
+ ok = false;
+ } else {
+ ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon);
+ ::llaisysLinear(row_logits, final_norm, _weights->out_embed, nullptr);
+ }
+ destroy_if_not_null(last_hidden);
+ destroy_if_not_null(row_logits);
+ destroy_if_not_null(final_norm);
+ }
+ }
+
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return ok;
+}
+
+bool Decoder::decodeStep(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits) {
+ if (!out_last_logits) return false;
+ if (!ensure_data(out_last_logits, "head.logits.out")) return false;
+
+ size_t past_len = 0;
+ size_t cur_len = 0;
+ llaisysTensor_t idx = nullptr;
+ llaisysTensor_t pos_ids = nullptr;
+ llaisysTensor_t hidden = nullptr;
+ if (!runHidden(token_ids, ntoken, true, nullptr, 0, past_len, cur_len, idx, pos_ids, hidden)) return false;
+
+ if (!_weights || !_weights->out_norm_w || !_weights->out_embed) {
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+
+ trace("head.slice");
+ llaisysTensor_t last_hidden = tensorSlice(hidden, 0, cur_len - 1, cur_len);
+ if (!require_tensor(last_hidden, "head.last_hidden")) {
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+
+ size_t last_shape[2] = {1, _config.hs};
+ trace("head.norm");
+ llaisysTensor_t final_norm = tensorCreate(last_shape, 2, _config.dtype, _device, _device_ids.empty() ? 0 : _device_ids[0]);
+ if (!require_tensor(final_norm, "head.norm")) {
+ tensorDestroy(last_hidden);
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return false;
+ }
+ ::llaisysRmsNorm(final_norm, last_hidden, _weights->out_norm_w, _config.epsilon);
+
+ trace("head.logits");
+ ::llaisysLinear(out_last_logits, final_norm, _weights->out_embed, nullptr);
+
+ tensorDestroy(last_hidden);
+ tensorDestroy(final_norm);
+ tensorDestroy(idx);
+ tensorDestroy(pos_ids);
+ tensorDestroy(hidden);
+ return true;
+}
+
+} // namespace llaisys::models::transformer
diff --git a/src/models/transformer/decoder/decoder.hpp b/src/models/transformer/decoder/decoder.hpp
new file mode 100644
index 000000000..5784a849e
--- /dev/null
+++ b/src/models/transformer/decoder/decoder.hpp
@@ -0,0 +1,95 @@
+#pragma once
+
+#include "llaisys/models/qwen2.h"
+#include "llaisys/comm.h"
+#include "llaisys/tensor.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace llaisys::models::transformer {
+
+struct DecoderConfig {
+ llaisysDataType_t dtype{};
+ size_t nlayer{};
+ size_t hs{};
+ size_t nh{};
+ size_t nkvh{};
+ size_t dh{};
+ size_t di{};
+ size_t maxseq{};
+ size_t voc{};
+ float epsilon{};
+ float theta{};
+};
+
+class Decoder {
+public:
+ Decoder(const DecoderConfig &config,
+ const LlaisysQwen2Weights *weights,
+ llaisysDeviceType_t device,
+ const std::vector<int> &device_ids);
+ ~Decoder();
+
+ // Prefill with a full sequence, returns last-step logits.
+ bool prefill(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits);
+ // Prefill packed independent sequences, outputs one logits row per sequence.
+ bool prefillPacked(const int64_t *token_ids,
+ size_t ntoken,
+ const int64_t *token_offsets,
+ size_t nseq,
+ llaisysTensor_t out_last_logits);
+
+ // Decode with only new tokens (append-only), returns last-step logits.
+ bool decodeStep(const int64_t *token_ids, size_t ntoken, llaisysTensor_t out_last_logits);
+ // Decode one token per sequence in a packed batch with per-sequence KV contexts.
+ bool decodePacked(const int64_t *token_ids,
+ size_t nseq,
+ const std::vector<void *> &contexts,
+ llaisysTensor_t out_last_logits,
+ size_t block_tokens_hint);
+
+ void resetKVCache();
+
+ void setKVCacheEnabled(bool enabled);
+ void bindExternalKVContext(void *ctx, size_t past_len_tokens);
+ void clearExternalKVContext();
+ bool hasExternalKVContext() const;
+ int exportKVContext(void *ctx, size_t block_tokens);
+
+ void setTensorParallel(llaisysComm_t comm, llaisysStream_t stream, int tp_size);
+
+private:
+ bool recoverExternalCache();
+ bool runHidden(const int64_t *token_ids,
+ size_t ntoken,
+ bool append_only,
+ const int64_t *segment_offsets,
+ size_t nseg,
+ size_t &past_len,
+ size_t &cur_len,
+ llaisysTensor_t &idx,
+ llaisysTensor_t &pos_ids,
+ llaisysTensor_t &hidden);
+ void ensureCache();
+ void releaseCache();
+
+ DecoderConfig _config{};
+ const LlaisysQwen2Weights *_weights{nullptr};
+ llaisysDeviceType_t _device{};
+ std::vector<int> _device_ids;
+ std::vector<llaisysTensor_t> _k_cache;
+ std::vector<llaisysTensor_t> _v_cache;
+ size_t _past_len{0};
+ bool _cache_inited{false};
+ bool _kv_cache_enabled{true};
+ void *_external_kv_ctx{nullptr};
+ size_t _external_past_len{0};
+ bool _external_cache_ready{false};
+ llaisysComm_t _comm{nullptr};
+ llaisysStream_t _comm_stream{nullptr};
+ int _tp_size{1};
+};
+
+} // namespace llaisys::models::transformer
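For orientation, the cache-length bookkeeping implied by `prefill` (reset plus full prompt) versus `decodeStep` (append-only new tokens) can be sketched as follows. `KVLen` is a hypothetical stand-in for illustration, not the actual `Decoder` internals:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch of _past_len bookkeeping: prefill resets the cache and
// writes the whole prompt; decodeStep appends only the newly generated tokens.
struct KVLen {
    std::size_t past_len = 0;

    void reset() { past_len = 0; }

    // Returns the position of the first new token, then advances the cached
    // length, as an append-only decode step would.
    std::size_t append(std::size_t ntoken) {
        std::size_t start = past_len;
        past_len += ntoken;
        return start;
    }
};
```

A prefill of an n-token prompt appends at position 0; each subsequent decode step appends one token at the current cached length.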
diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp
index 47f6a3d49..04d499d7b 100644
--- a/src/ops/add/cpu/add_cpu.cpp
+++ b/src/ops/add/cpu/add_cpu.cpp
@@ -5,29 +5,29 @@
#include <type_traits>
template <typename T>
-void add_(T *c, const T *a, const T *b, size_t numel) {
- for (size_t i = 0; i < numel; i++) {
- if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
- c[i] = llaisys::utils::cast<T>(llaisys::utils::cast<float>(a[i]) + llaisys::utils::cast<float>(b[i]));
- } else {
- c[i] = a[i] + b[i];
+ void add_(T *c, const T *a, const T *b, size_t numel) {
+ for (size_t i = 0; i < numel; i++) {
+ if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+ c[i] = llaisys::utils::cast<T>(llaisys::utils::cast<float>(a[i]) + llaisys::utils::cast<float>(b[i]));
+ } else {
+ c[i] = a[i] + b[i];
+ }
}
}
-}
namespace llaisys::ops::cpu {
-void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) {
- switch (type) {
- case LLAISYS_DTYPE_F32:
- return add_(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b), numel);
- case LLAISYS_DTYPE_BF16:
- return add_(reinterpret_cast<llaisys::bf16_t *>(c), reinterpret_cast<const llaisys::bf16_t *>(a),
- reinterpret_cast<const llaisys::bf16_t *>(b), numel);
- case LLAISYS_DTYPE_F16:
- return add_(reinterpret_cast<llaisys::fp16_t *>(c), reinterpret_cast<const llaisys::fp16_t *>(a),
- reinterpret_cast<const llaisys::fp16_t *>(b), numel);
- default:
- EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return add_(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b), numel);
+ case LLAISYS_DTYPE_BF16:
+ return add_(reinterpret_cast<llaisys::bf16_t *>(c), reinterpret_cast<const llaisys::bf16_t *>(a),
+ reinterpret_cast<const llaisys::bf16_t *>(b), numel);
+ case LLAISYS_DTYPE_F16:
+ return add_(reinterpret_cast<llaisys::fp16_t *>(c), reinterpret_cast<const llaisys::fp16_t *>(a),
+ reinterpret_cast<const llaisys::fp16_t *>(b), numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
}
-}
} // namespace llaisys::ops::cpu
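The fp16/bf16 branch above widens both operands to float, adds, and narrows back. A self-contained illustration of that round-trip, emulating bf16 as the high 16 bits of an IEEE-754 float (a simplifying assumption for the sketch; truncation without rounding, not llaisys's actual `cast` implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical bf16 emulation: the top 16 bits of a 32-bit float.
using bf16_t = std::uint16_t;

static float bf16_to_float(bf16_t x) {
    std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16_t float_to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bf16_t>(bits >> 16);  // truncate, no rounding
}

// Element-wise add performed in float, mirroring the narrow-dtype branch above.
static void add_bf16(bf16_t *c, const bf16_t *a, const bf16_t *b, std::size_t numel) {
    for (std::size_t i = 0; i < numel; ++i) {
        c[i] = float_to_bf16(bf16_to_float(a[i]) + bf16_to_float(b[i]));
    }
}
```

Doing the arithmetic in float avoids accumulating the large rounding error that adding directly in a 7-bit-mantissa format would introduce.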
diff --git a/src/ops/add/cpu/add_cpu.hpp b/src/ops/add/cpu/add_cpu.hpp
index 34d809a11..20f5396ef 100644
--- a/src/ops/add/cpu/add_cpu.hpp
+++ b/src/ops/add/cpu/add_cpu.hpp
@@ -4,5 +4,5 @@
#include
namespace llaisys::ops::cpu {
-void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size);
+ void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size);
}
\ No newline at end of file
diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu
new file mode 100644
index 000000000..49b7208cf
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.cu
@@ -0,0 +1,41 @@
+#include "add_nvidia.hpp"
+
+#include "../../../device/nvidia/cuda_utils.hpp"
+
+namespace llaisys::ops::nvidia {
+namespace {
+template <typename T>
+__global__ void add_kernel(T *c, const T *a, const T *b, size_t numel) {
+ size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+ if (idx < numel) {
+ float av = llaisys::device::nvidia::ScalarOps<T>::load(a + idx);
+ float bv = llaisys::device::nvidia::ScalarOps<T>::load(b + idx);
+ llaisys::device::nvidia::ScalarOps<T>::store(c + idx, av + bv);
+ }
+}
+
+template <typename T>
+void launch_add(T *c, const T *a, const T *b, size_t numel) {
+ const int threads = 256;
+ const int blocks = static_cast<int>((numel + threads - 1) / threads);
+ add_kernel<<<blocks, threads>>>(c, a, b, numel);
+ llaisys::device::nvidia::cuda_check(cudaGetLastError());
+}
+} // namespace
+
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return launch_add(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a),
+ reinterpret_cast<const float *>(b), numel);
+ case LLAISYS_DTYPE_BF16:
+ return launch_add(reinterpret_cast<llaisys::bf16_t *>(c), reinterpret_cast<const llaisys::bf16_t *>(a),
+ reinterpret_cast<const llaisys::bf16_t *>(b), numel);
+ case LLAISYS_DTYPE_F16:
+ return launch_add(reinterpret_cast<llaisys::fp16_t *>(c), reinterpret_cast<const llaisys::fp16_t *>(a),
+ reinterpret_cast<const llaisys::fp16_t *>(b), numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::nvidia
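The launch configuration above uses the usual ceiling division so that `blocks * threads >= numel`, with one thread per element and at most `threads - 1` idle lanes in the last block. The arithmetic in isolation (note a hypothetical `numel == 0` would yield zero blocks, which is an invalid launch, so callers are expected to pass non-empty tensors):

```cpp
#include <cassert>
#include <cstddef>

// Ceiling division used to size the grid: the smallest block count whose
// total thread count covers every element.
static std::size_t grid_blocks(std::size_t numel, std::size_t threads) {
    return (numel + threads - 1) / threads;
}
```

The guard `if (idx < numel)` inside the kernel then masks off the idle threads of the final block.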
diff --git a/src/ops/add/nvidia/add_nvidia.hpp b/src/ops/add/nvidia/add_nvidia.hpp
new file mode 100644
index 000000000..8424ad596
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "../../../utils.hpp"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp
index a057330d7..fea297cf7 100644
--- a/src/ops/add/op.cpp
+++ b/src/ops/add/op.cpp
@@ -4,9 +4,16 @@
#include "../../utils.hpp"
#include "cpu/add_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/add_nvidia.hpp"
+#endif
+#ifdef ENABLE_ILUVATAR_API
+#include "nvidia/add_nvidia.hpp"
+#endif
namespace llaisys::ops {
void add(tensor_t c, tensor_t a, tensor_t b) {
+ // Ensure all tensors live on the same device.
CHECK_SAME_DEVICE(c, a, b);
// Only support contiguous inputs with same shape for now.
CHECK_SAME_SHAPE(c->shape(), a->shape(), b->shape());
@@ -25,8 +32,11 @@ void add(tensor_t c, tensor_t a, tensor_t b) {
return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#ifdef ENABLE_NVIDIA_API
case LLAISYS_DEVICE_NVIDIA:
- TO_BE_IMPLEMENTED();
- return;
+ return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
+#endif
+#ifdef ENABLE_ILUVATAR_API
+ case LLAISYS_DEVICE_ILUVATAR:
+ return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#endif
default:
EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 000000000..ab96b2b2f
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,45 @@
+#include "argmax_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstddef>
+#include <cstdint>
+
+namespace {
+ template <typename T>
+ void argmax_impl(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) {
+ // Work in float for fp16/bf16 comparisons to avoid precision issues.
+ using value_t = T;
+ const value_t *v = reinterpret_cast<const value_t *>(vals);
+ int64_t *out_idx = reinterpret_cast<int64_t *>(max_idx);
+ value_t *out_val = reinterpret_cast<value_t *>(max_val);
+
+ float best = llaisys::utils::cast<float>(v[0]);
+ int64_t best_idx = 0;
+ for (size_t i = 1; i < numel; ++i) {
+ float cur = llaisys::utils::cast<float>(v[i]);
+ if (cur > best) {
+ best = cur;
+ best_idx = static_cast<int64_t>(i);
+ }
+ }
+
+ *out_idx = best_idx;
+ *out_val = llaisys::utils::cast<T>(best);
+ }
+}
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return argmax_impl<float>(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_BF16:
+ return argmax_impl<llaisys::bf16_t>(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_F16:
+ return argmax_impl<llaisys::fp16_t>(max_idx, max_val, vals, numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
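Stripped of the dtype plumbing, this is a single left-to-right scan; the strict `>` comparison means ties resolve to the lowest index. The float case in isolation:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Single-pass argmax over a float buffer: the same scan argmax_impl performs
// after widening fp16/bf16 values to float. Ties keep the lowest index.
static std::int64_t argmax_f32(const float *v, std::size_t numel) {
    float best = v[0];
    std::int64_t best_idx = 0;
    for (std::size_t i = 1; i < numel; ++i) {
        if (v[i] > best) {
            best = v[i];
            best_idx = static_cast<std::int64_t>(i);
        }
    }
    return best_idx;
}
```

Widening to float before comparing sidesteps the need for comparison operators on the 16-bit storage types.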
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 000000000..26ae3ef03
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu
new file mode 100644
index 000000000..8e6b0abf7
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.cu
@@ -0,0 +1,42 @@
+#include "argmax_nvidia.hpp"
+
+#include "../../../device/nvidia/cuda_utils.hpp"
+
+namespace llaisys::ops::nvidia {
+namespace {
+template <typename T>
+__global__ void argmax_kernel(int64_t *out_idx, T *out_val, const T *vals, size_t numel) {
+ float best = llaisys::device::nvidia::ScalarOps<T>::load(vals);
+ int64_t best_idx = 0;
+ for (size_t i = 1; i < numel; ++i) {
+ float v = llaisys::device::nvidia::ScalarOps<T>::load(vals + i);
+ if (v > best) {
+ best = v;
+ best_idx = static_cast<int64_t>(i);
+ }
+ }
+ *out_idx = best_idx;
+ llaisys::device::nvidia::ScalarOps<T>::store(out_val, best);
+}
+
+template <typename T>
+void launch_argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) {
+ argmax_kernel<<<1, 1>>>(reinterpret_cast<int64_t *>(max_idx), reinterpret_cast<T *>(max_val),
+ reinterpret_cast<const T *>(vals), numel);
+ llaisys::device::nvidia::cuda_check(cudaGetLastError());
+}
+} // namespace
+
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return launch_argmax<float>(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_BF16:
+ return launch_argmax<llaisys::bf16_t>(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_F16:
+ return launch_argmax<llaisys::fp16_t>(max_idx, max_val, vals, numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::nvidia
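The `<<<1, 1>>>` launch runs the scan serially in a single GPU thread, which is fine for the single-vocabulary-row case but leaves parallelism on the table. A common parallelization, sketched here on the host as an assumption about one possible follow-up (not code in this PR), splits the input into chunks, takes a local argmax per chunk, then reduces the winners; ties keep the lower index so the result matches the serial scan:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct ArgMax {
    float val;
    std::int64_t idx;
};

// Local argmax over [lo, hi), as each block/thread would compute.
static ArgMax chunk_argmax(const float *v, std::size_t lo, std::size_t hi) {
    ArgMax r{v[lo], static_cast<std::int64_t>(lo)};
    for (std::size_t i = lo + 1; i < hi; ++i) {
        if (v[i] > r.val) {
            r = {v[i], static_cast<std::int64_t>(i)};
        }
    }
    return r;
}

// Associative combine step of the reduction; on a tie the lower index wins,
// matching the strict '>' of the serial kernel.
static ArgMax combine(ArgMax a, ArgMax b) {
    if (b.val > a.val || (b.val == a.val && b.idx < a.idx)) {
        return b;
    }
    return a;
}
```

Because `combine` is associative, the per-chunk results can be reduced in any tree order, which is what a block-level reduction kernel would exploit.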
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.hpp b/src/ops/argmax/nvidia/argmax_nvidia.hpp
new file mode 100644
index 000000000..e51bc30d5
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "../../../utils.hpp"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d426..f565d0a83 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,46 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/argmax_nvidia.hpp"
+#endif
+#ifdef ENABLE_ILUVATAR_API
+#include "nvidia/argmax_nvidia.hpp"
+#endif
+
+
namespace llaisys::ops {
void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
- TO_BE_IMPLEMENTED();
+ CHECK_SAME_DEVICE(max_idx, max_val, vals);
+ CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype());
+ ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax: max_idx must be int64.");
+ // The current implementation flattens multi-dimensional input, i.e. a global argmax over all elements.
+ ASSERT(vals->numel() > 0, "Argmax: input must be non-empty.");
+ ASSERT(max_idx->numel() == 1 && max_val->numel() == 1, "Argmax: outputs must have a single element.");
+ ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(),
+ "Argmax: all tensors must be contiguous.");
+
+ if (vals->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+ }
+ llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+
+ switch (vals->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#endif
+#ifdef ENABLE_ILUVATAR_API
+ case LLAISYS_DEVICE_ILUVATAR:
+ return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 000000000..6839372d3
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,33 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstdint>
+#include <cstring>
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type,
+ size_t index_numel, size_t embd_dim, size_t weight_rows) {
+ size_t elem_size = 0;
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ case LLAISYS_DTYPE_F16:
+ case LLAISYS_DTYPE_BF16:
+ elem_size = llaisys::utils::dsize(type);
+ break;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+
+ const int64_t *idx_ptr = reinterpret_cast<const int64_t *>(index);
+ size_t row_bytes = embd_dim * elem_size;
+
+ for (size_t i = 0; i < index_numel; ++i) {
+ int64_t idx = idx_ptr[i];
+ ASSERT(idx >= 0 && static_cast<size_t>(idx) < weight_rows, "Embedding: index out of range.");
+ const std::byte *src = weight + static_cast<size_t>(idx) * row_bytes;
+ std::byte *dst = out + i * row_bytes;
+ std::memcpy(dst, src, row_bytes);
+ }
+}
+} // namespace llaisys::ops::cpu
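Since the rows are contiguous and the dtype only affects the element size, the whole lookup reduces to a per-index `memcpy` of one weight row. The core gather, specialized to float for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Row gather as in the CPU embedding above: each index selects one row of
// `dim` floats from `weight` and copies it into the corresponding output row.
// Indices are assumed pre-validated (0 <= index[i] < number of weight rows).
static void gather_rows(float *out, const std::int64_t *index, const float *weight,
                        std::size_t index_numel, std::size_t dim) {
    for (std::size_t i = 0; i < index_numel; ++i) {
        std::memcpy(out + i * dim,
                    weight + static_cast<std::size_t>(index[i]) * dim,
                    dim * sizeof(float));
    }
}
```

Working in bytes, as the real implementation does, lets the same loop serve f32, f16, and bf16 with only `elem_size` changing.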
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 000000000..1b1626278
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type,
+ size_t index_numel, size_t embd_dim, size_t weight_rows);
+}
diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu
new file mode 100644
index 000000000..7595c5cc8
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_nvidia.cu
@@ -0,0 +1,51 @@
+#include "embedding_nvidia.hpp"
+
+#include "../../../device/nvidia/cuda_utils.hpp"
+
+namespace llaisys::ops::nvidia {
+namespace {
+template <typename T>
+__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, size_t index_numel, size_t dim,
+ size_t vocab) {
+ size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+ size_t total = index_numel * dim;
+ if (idx >= total) {
+ return;
+ }
+ size_t row = idx / dim;
+ size_t col = idx % dim;
+ int64_t token = index[row];
+ if (token < 0 || static_cast<size_t>(token) >= vocab) {
+ return;
+ }
+ size_t w_idx = static_cast<size_t>(token) * dim + col;
+ float v = llaisys::device::nvidia::ScalarOps