diff --git a/README.md b/README.md
index 456067c8..716e1ebf 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
# Welcome to LLAISYS
+> Implementation summary, experiment results, and reproduction steps for the completed extension tasks are documented in [doc/implementation_report.md](doc/implementation_report.md).
+
English |
中文
diff --git a/doc/implementation_report.md b/doc/implementation_report.md
new file mode 100644
index 00000000..118018cd
--- /dev/null
+++ b/doc/implementation_report.md
@@ -0,0 +1,495 @@
+# LLAISYS Extension Implementation and Experiment Report
+
+## 1. About This Report
+
+This report reflects the current implementation status of the repository's extension tasks, focusing on:
+
+- Completed and partially completed projects and their corresponding implementations
+- Key features and current limitations
+- Main experiment results
+- Reproduction steps
+
+The work covers the following directions:
+
+- Project #1: CPU optimization
+- Project #2: Nvidia CUDA backend
+- Project #3: Chatbot with sampling-based generation
+- Project #4: Multi-user inference service
+- Project #5: Tensor parallelism for distributed inference
+- Project #6: Entry points for new model types
+
+## 2. Task Completion Overview
+
+| Project | Status | Notes |
+| --- | --- | --- |
+| #1 CPU optimization | Done | All core operators on the main CPU inference path have dedicated optimizations |
+| #2 CUDA integration | Done for one platform | Nvidia is complete; no other CUDA/CUDA-like platforms were attempted |
+| #3 Chatbot | Done | Sampling, streaming output, CLI, and session management are all implemented |
+| #4 Multi-user inference service | Working version done | Request pool, scheduler thread, worker pool, and session reuse are implemented; true backend batched decode is not |
+| #5 Distributed inference | Working version done | Nvidia/Qwen2 tensor parallelism is implemented; the current sharding strategy constrains `tp_size` |
+| #6 New model support | Basic entry points done | Model detection and creation entry points for `qwen2`, `llama`, and `mistral` |
+
+## 3. Implementation Details by Project
+
+### 3.1 Project #1: Optimize LLAISYS for CPU
+
+#### 3.1.1 Operator Optimizations
+
+The following CPU operators have received a dedicated optimization pass:
+
+- `add`
+- `argmax`
+- `embedding`
+- `linear`
+- `rearrange`
+- `rms_norm`
+- `rope`
+- `self_attention`
+- `swiglu`
+- `sample`
+
+Among these, `linear` was the primary optimization target. The main work included:
+
+- OpenMP parallelization
+- Multi-level blocking
+- An AVX2/FMA fast path
+- Routing some large `f32 GEMM` calls to OpenBLAS
+- Heuristic kernel selection based on matrix shape, rather than unconditionally calling BLAS
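A minimal sketch of that last point. The threshold and kernel names here are hypothetical; the real heuristic in `src/ops/linear/cpu/linear_cpu.cpp` may consider more than a single FLOP count:

```python
def choose_gemm_kernel(m: int, n: int, k: int,
                       blas_flop_threshold: float = 1e8) -> str:
    """Pick a kernel for an (m, k) x (k, n) f32 GEMM based on problem size.

    Small problems stay on the blocked AVX2/FMA path, where BLAS call
    overhead would dominate; large ones are routed to OpenBLAS.
    """
    flops = 2.0 * m * n * k  # multiply-adds in the GEMM
    if flops >= blas_flop_threshold:
        return "openblas"
    return "avx2_blocked"
```

Under this sketch, the `(64, 512) x (512, 512)` benchmark shape stays on the in-house kernel, while the near-model shapes in section 4.3 would go to BLAS.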
+
+The relevant implementations live in:
+
+- `src/ops/linear/cpu/linear_cpu.cpp`
+- `src/ops/add/cpu/add_cpu.cpp`
+- `src/ops/argmax/cpu/argmax_cpu.cpp`
+- `src/ops/embedding/cpu/embedding_cpu.cpp`
+- `src/ops/rearrange/cpu/rearrange_cpu.cpp`
+- `src/ops/rms_norm/cpu/rms_norm_cpu.cpp`
+- `src/ops/rope/cpu/rope_cpu.cpp`
+- `src/ops/self_attention/cpu/self_attention_cpu.cpp`
+- `src/ops/swiglu/cpu/swiglu_cpu.cpp`
+- `src/ops/sample/cpu/sample_cpu.cpp`
+
+#### 3.1.2 Model Inference Path Optimizations
+
+The Qwen2 inference path also gained a buffer-reuse mechanism to reduce the overhead of repeatedly allocating temporary tensors.
+
+Completed optimizations:
+
+- Buffer reuse on the decode path
+- Scratch-buffer reuse on the prefill path
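The idea can be sketched with a minimal pool, illustrative only; the actual C++ implementation manages llaisys tensors rather than raw byte buffers:

```python
class ScratchPool:
    """Reuse temporary buffers across decode steps instead of reallocating.

    Buffers are cached by (shape, itemsize); repeated decode steps with
    identical shapes hit the cache and pay no allocation cost.
    """

    def __init__(self):
        self._buffers = {}

    def get(self, shape, itemsize=4):
        key = (tuple(shape), itemsize)
        buf = self._buffers.get(key)
        if buf is None:
            size = itemsize
            for dim in shape:
                size *= dim
            buf = bytearray(size)  # allocate once, reuse on later calls
            self._buffers[key] = buf
        return buf
```

Since decode repeatedly runs with the same tensor shapes, every step after the first becomes allocation-free.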
+
+Implemented in:
+
+- `src/models/qwen2/qwen2.cpp`
+
+### 3.2 Project #2: Integrate CUDA into LLAISYS
+
+#### 3.2.1 Nvidia Runtime API
+
+The Nvidia Runtime API is implemented and wired into the build.
+
+Relevant files:
+
+- `xmake/nvidia.lua`
+- `xmake.lua`
+- `src/device/nvidia/cuda_utils.cuh`
+- `src/device/nvidia/nvidia_runtime_api.cu`
+
+This work also fixed a bug where the multi-device Runtime created CUDA streams on the wrong device; in multi-GPU setups each Runtime now uses a stream bound to its own device.
+
+Implemented in:
+
+- `src/core/runtime/runtime.cpp`
+
+#### 3.2.2 Nvidia Operator Implementations
+
+Rather than living in a single `.cu` file, the CUDA operators are split by operator directory. The Nvidia operators completed so far:
+
+- `src/ops/add/nvidia`
+- `src/ops/argmax/nvidia`
+- `src/ops/embedding/nvidia`
+- `src/ops/linear/nvidia`
+- `src/ops/rearrange/nvidia`
+- `src/ops/rms_norm/nvidia`
+- `src/ops/rope/nvidia`
+- `src/ops/sample/nvidia`
+- `src/ops/self_attention/nvidia`
+- `src/ops/swiglu/nvidia`
+- `src/ops/nvidia/nvidia_common.cuh`
+
+#### 3.2.3 CUDA Inference Path
+
+The Qwen2 CUDA inference path is fully functional: it runs real-model inference and sampling-based generation on Nvidia devices.
+
+Relevant paths:
+
+- `src/models/qwen2/qwen2.cpp`
+- `python/llaisys/models/qwen2.py`
+
+No second CUDA/CUDA-like platform was attempted, so the scope of project #2 is limited to Nvidia.
+
+### 3.3 Project #3: Build an AI chatbot
+
+#### 3.3.1 Random Sampling
+
+A new `sample` operator was added and wired into the model's generation path. It currently supports:
+
+- `temperature`
+- `top-k`
+- `top-p`
+- Falling back to argmax when `temperature <= 0` or `top_k == 1`
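A minimal pure-Python sketch of these rules. The actual `sample` operator works on logits tensors in C++/CUDA, so details such as tie-breaking may differ:

```python
import math
import random

def sample_token(logits, top_k=50, top_p=0.8, temperature=0.8, rng=random):
    """Temperature / top-k / top-p sampling with a greedy fallback."""
    if temperature <= 0 or top_k == 1:
        # Degenerate case: plain argmax over the logits.
        return max(range(len(logits)), key=lambda i: logits[i])
    m = max(logits)
    weights = [math.exp((x - m) / temperature) for x in logits]
    total = sum(weights)
    probs = [w / total for w in weights]
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    if top_k > 0:
        order = order[:top_k]  # keep only the k most likely tokens
    kept, cumulative = [], 0.0
    for i in order:  # nucleus cut: smallest prefix with mass >= top_p
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    kept_mass = sum(probs[i] for i in kept)
    return rng.choices(kept, weights=[probs[i] / kept_mass for i in kept])[0]
```

Note that top-k is applied before the nucleus (top-p) cut, and the surviving probabilities are renormalized before drawing.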
+
+Relevant paths:
+
+- `src/ops/sample/op.cpp`
+- `src/ops/sample/cpu/sample_cpu.cpp`
+- `src/ops/sample/nvidia/sample_nvidia.cu`
+- `python/llaisys/ops.py`
+
+#### 3.3.2 Chat Service and CLI
+
+An OpenAI-style chat endpoint is implemented:
+
+- `POST /v1/chat/completions`
+
+It supports:
+
+- SSE streaming output
+- Interactive command-line chat
+- Non-streaming requests
+相关路径:
+
+- `python/llaisys/chat_server.py`
+- `python/llaisys/chat_cli.py`
+
+#### 3.3.3 Session Management
+
+Basic session management is implemented, including:
+
+- Session creation, retrieval, update, and deletion
+- Viewing session history
+- Regenerating the last assistant reply
+- Editing a past user message and regenerating from that point
+- KV-state rollback via `truncate`
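The `truncate`-based rollback can be sketched as follows; it mirrors, in simplified form, the prefix-matching logic in `python/llaisys/chat_server.py` (`_longest_common_prefix_len` and `_prepare_worker_reuse`):

```python
def plan_rollback(cached_ids, target_ids):
    """Decide how to reuse a cached KV state when history is edited.

    Returns (truncate_to, suffix): roll the KV cache back to `truncate_to`
    tokens, then feed only `suffix` through the model.  At least one token
    is always re-fed so a next-token distribution can be produced.
    """
    common = 0
    limit = min(len(cached_ids), len(target_ids))
    while common < limit and cached_ids[common] == target_ids[common]:
        common += 1
    common = min(common, len(target_ids) - 1)  # keep one token to re-feed
    common = max(common, 0)
    return common, list(target_ids[common:])
```

Editing the last user message therefore only re-runs prefill for the edited suffix, not the whole conversation.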
+
+Server endpoints:
+
+- `GET /v1/sessions`
+- `POST /v1/sessions`
+- `GET /v1/sessions/{session_id}`
+- `PUT /v1/sessions/{session_id}`
+- `DELETE /v1/sessions/{session_id}`
+
+The CLI supports:
+
+- `/new`
+- `/list`
+- `/switch`
+- `/history`
+- `/regen`
+- `/edit`
+- `/delete`
+- `/session`
+- `/help`
+
+### 3.4 Project #4: Multi-user Inference Service
+
+#### 3.4.1 Service-layer Scheduling
+
+The multi-user service uses a three-layer structure:
+
+- HTTP layer: validates parameters, builds request objects, and returns responses
+- Scheduling layer: a background thread pulls requests from the request pool and dispatches them
+- Execution layer: workers hold model instances and actually run generation
+
+Implemented capabilities:
+
+- Request pool
+- Scheduler thread
+- Worker pool
+- Micro-batch dequeueing
+- Iteration-level continuous scheduling
+- Session-to-worker affinity
+- Prefix reuse within the same session
+- Concurrency protection per `session_id`
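Micro-batch dequeueing can be sketched like this. It is a simplification: the real scheduler in `chat_server.py` uses a deque guarded by a condition variable rather than a plain `queue.Queue`:

```python
import queue
import time

def dequeue_microbatch(pending: "queue.Queue", max_batch_size: int,
                       batch_wait_s: float):
    """Pull up to max_batch_size requests, waiting briefly to fill the batch.

    The first item is awaited indefinitely; afterwards we only wait up to
    batch_wait_s for stragglers before dispatching what we have.
    """
    batch = [pending.get()]
    deadline = time.monotonic() + batch_wait_s
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(pending.get(timeout=remaining))
        except queue.Empty:
            break
    return batch
```

This is the trade-off behind the `--batch-wait-ms` flag: a small wait increases batch size under load at the cost of a bounded latency hit for the first request.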
+
+Relevant paths:
+
+- `python/llaisys/chat_server.py`
+- `test/test_chat_api.py`
+- `test/chat_test_utils.py`
+
+#### 3.4.2 Current Limitations
+
+Project #4 has the core shape of service-layer continuous batching, but the following are not yet implemented:
+
+- True batched decode inside a single backend model object
+- A batched KV-cache
+- Merging multiple requests into a single batched forward per step
+- A batched matmul path
+
+In other words, the scheduler advances requests at iteration-level granularity, but each worker still executes one generation step for one request at a time.
+
+### 3.5 Project #5: Distributed Inference
+
+#### 3.5.1 Tensor Parallelism Implementation
+
+A tensor-parallel version of Qwen2 on the Nvidia platform is implemented, built on `torch.distributed` with NCCL.
+
+Model entry points:
+
+- `python/llaisys/models/tensor_parallel.py`
+- `python/llaisys/models/__init__.py`
+
+Supported capabilities:
+
+- `reset`
+- `truncate`
+- `generate_next`
+- `generate`
+- `stream_generate`
+
+The server supports the following flags:
+
+- `--tp-size`
+- `--tp-device-ids`
+
+#### 3.5.2 Current Sharding Strategy and Constraints
+
+The implementation shards uniformly:
+
+- `Q heads` split evenly
+- `KV heads` split evenly
+- `MLP intermediate` split evenly
+
+Consequently, `tp_size` must evenly divide all of these dimensions. For the `DeepSeek-R1-Distill-Qwen-1.5B` model used here:
+
+- `num_attention_heads = 12`
+- `num_key_value_heads = 2`
+- `intermediate_size = 8960`
+
+the allowed values of `tp_size` are:
+
+- `1`
+- `2`
+
+Two-GPU tensor parallelism therefore works, while four GPUs are not supported by the current implementation.
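The constraint can be checked mechanically; for the configuration above, a small helper reproduces the allowed values (the binding constraint is `num_key_value_heads = 2`):

```python
def valid_tp_sizes(num_attention_heads, num_key_value_heads,
                   intermediate_size, max_world_size=8):
    """tp_size values that evenly split every uniformly sharded dimension."""
    return [
        tp for tp in range(1, max_world_size + 1)
        if num_attention_heads % tp == 0
        and num_key_value_heads % tp == 0
        and intermediate_size % tp == 0
    ]
```

For `DeepSeek-R1-Distill-Qwen-1.5B` this yields `[1, 2]`, matching the list above.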
+
+#### 3.5.3 Current Limitations
+
+Project #5 does not yet include:
+
+- MPI/CPU distributed inference
+- More sophisticated KV-replicated tensor parallelism
+- Sharding strategies that support higher GPU counts for models with low `num_key_value_heads`
+
+### 3.6 Project #6: Support New Models
+
+A unified model-creation entry point is implemented, with automatic detection of `model_type` from `config.json`.
+
+Supported model types:
+
+- `qwen2`
+- `llama`
+- `mistral`
+
+Relevant paths:
+
+- `python/llaisys/models/__init__.py`
+
+The focus of this work was to unify the integration path and creation flow first; full end-to-end validation has so far centered on Qwen2.
+
+## 4. Key Experiment Results
+
+### 4.1 CPU `linear` Initial Baseline
+
+The earliest CPU `linear` profile showed the following baseline:
+
+- Matrix shapes: `(64, 512) x (512, 512)`
+- PyTorch: ~`0.269 ms`
+- LLAISYS: ~`25.427 ms`
+
+### 4.2 CPU `linear` After Optimization
+
+After the CPU optimizations, the small-scale `linear` benchmark reached:
+
+- Matrix shapes: `(64, 512) x (512, 512)`
+- LLAISYS: ~`0.0568 ms`
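Taken together with the 4.1 baseline, these measurements imply roughly a 448x speedup over the initial implementation and about 4.7x over the PyTorch reference at this shape:

```python
baseline_ms = 25.427   # initial LLAISYS CPU linear (section 4.1)
torch_ms = 0.269       # PyTorch reference (section 4.1)
optimized_ms = 0.0568  # optimized LLAISYS (section 4.2)

speedup_vs_baseline = baseline_ms / optimized_ms  # ~447.7x
speedup_vs_torch = torch_ms / optimized_ms        # ~4.7x
```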
+
+### 4.3 `linear` at Near-real Model Shapes
+
+For matrix shapes closer to those of the real model:
+
+- `out=(198, 8960), x=(198, 1536), w=(8960, 1536)`
+  - Torch: `4.88102 ms`
+  - LLAISYS: `10.06528 ms`
+
+- `out=(198, 1536), x=(198, 8960), w=(1536, 8960)`
+  - Torch: `1.96279 ms`
+  - LLAISYS: `4.23397 ms`
+
+### 4.4 CPU Real-model Inference
+
+On `DeepSeek-R1-Distill-Qwen-1.5B`, a short generation run measured:
+
+- Generating 8 tokens
+  - `short_generate_s=0.600684`
+  - `short_tok_per_s=13.318142`
+
+One run of the official regression test:
+
+- Hugging Face: ~`0.71 s`
+- LLAISYS CPU: ~`0.72 s`
+- Token sequences: identical
+
+### 4.5 CUDA Real-model Inference
+
+One run of the single-GPU Nvidia regression test:
+
+- Hugging Face: ~`0.79 s`
+- LLAISYS Nvidia: ~`0.14 s`
+- Token sequences: identical
+
+### 4.6 Two-GPU Tensor Parallelism
+
+The two-GPU tensor-parallel smoke test passes, and the generated tokens for the standard prompt match the expected output.
+
+## 5. Reproduction
+
+### 5.1 Build
+
+With CPU BLAS and Nvidia enabled:
+
+```bash
+xmake f --cpu-blas=y --openblas-prefix="$CONDA_PREFIX" --nv-gpu=y -cv
+xmake
+xmake install
+```
+
+Without BLAS:
+
+```bash
+xmake f --nv-gpu=y -cv
+xmake
+xmake install
+```
+
+### 5.2 CPU Regression Tests
+
+```bash
+conda run -n llaisys env PYTHONPATH=python:test python test/test_runtime.py --device cpu
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/linear.py --device cpu
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/sample.py --device cpu
+conda run -n llaisys env PYTHONPATH=python:test python test/test_generate_sampling.py --device cpu --model [model_dir]
+conda run -n llaisys env PYTHONPATH=python:test python test/test_infer.py --device cpu --model [model_dir] --test --max_steps 8
+```
+
+### 5.3 Nvidia Regression Tests
+
+```bash
+conda run -n llaisys env PYTHONPATH=python:test python test/test_runtime.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/add.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/argmax.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/embedding.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/linear.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/rms_norm.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/rope.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/swiglu.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/ops/self_attention.py --device nvidia
+conda run -n llaisys env PYTHONPATH=python:test python test/test_generate_sampling.py --device nvidia --model [model_dir]
+conda run -n llaisys env PYTHONPATH=python:test python test/test_infer.py --device nvidia --model [model_dir] --test --max_steps 8
+```
+
+### 5.4 Chat and Session Management Tests
+
+```bash
+conda run -n llaisys env PYTHONPATH=python:test python test/test_chat_api.py
+conda run -n llaisys env PYTHONPATH=python:test python test/test_chat_cli.py
+```
+
+### 5.5 Tensor Parallelism Test
+
+```bash
+conda run -n llaisys env PYTHONPATH=python:test python test/test_tensor_parallel.py --model [model_dir] --tp-size 2 --max-steps 8
+```
+
+### 5.6 Launch a Single-GPU CUDA Chat Service
+
+```bash
+PYTHONPATH=python python -m llaisys.chat_server \
+ --model-path [model_dir] \
+ --device nvidia \
+ --device-id 0 \
+ --host 127.0.0.1 \
+ --port 8000 \
+ --num-workers 1 \
+ --max-batch-size 1 \
+ --batch-wait-ms 0
+```
+
+### 5.7 Launch a Two-GPU Tensor-parallel Chat Service
+
+```bash
+PYTHONPATH=python python -m llaisys.chat_server \
+ --model-path [model_dir] \
+ --device nvidia \
+ --device-id 0 \
+ --tp-size 2 \
+ --tp-device-ids 0,1 \
+ --host 127.0.0.1 \
+ --port 8000 \
+ --num-workers 1 \
+ --max-batch-size 2 \
+ --batch-wait-ms 5
+```
+
+### 5.8 Interactive CLI
+
+```bash
+PYTHONPATH=python python -m llaisys.chat_cli \
+ --url http://127.0.0.1:8000 \
+ --model llaisys-qwen2 \
+ --max-tokens 512
+```
+
+### 5.9 curl Example
+
+```bash
+curl -N http://127.0.0.1:8000/v1/chat/completions \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "llaisys-qwen2",
+ "session_id": "demo",
+ "messages": [{"role": "user", "content": "请解释操作系统的页表机制"}],
+ "max_tokens": 512,
+ "temperature": 0.8,
+ "top_p": 0.8,
+ "top_k": 50,
+ "stream": true
+ }'
+```
+
+Service statistics endpoint:
+
+```bash
+curl http://127.0.0.1:8000/v1/service/stats
+```
+
+## 6. Scope Summary
+
+The repository has grown from a basic CPU/Qwen2 inference project into:
+
+- A CPU-optimized version
+- An Nvidia CUDA version
+- A chat service with sampling, streaming output, and session management
+- An inference service with a multi-user request pool and service-layer continuous scheduling
+- An Nvidia version with two-GPU tensor parallelism
+- A unified creation entry point for multiple model types
+
+Not yet in scope:
+
+- A second CUDA/CUDA-like platform
+- True backend-level batched decode / batched KV-cache
+- More sophisticated KV-replicated tensor parallelism
+- MPI/CPU distributed inference
diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h
index 7054626d..37936d04 100644
--- a/include/llaisys/models/qwen2.h
+++ b/include/llaisys/models/qwen2.h
@@ -14,8 +14,8 @@ __C {
struct LlaisysQwen2Weights {
llaisysTensor_t in_embed;
llaisysTensor_t out_embed;
- llaisysTensor_t out_norm_w; // a.k.a. model.norm.weight
- llaisysTensor_t *attn_norm_w; // a.k.a. input_layernorm.weight
+ llaisysTensor_t out_norm_w;
+ llaisysTensor_t *attn_norm_w;
llaisysTensor_t *attn_q_w;
llaisysTensor_t *attn_q_b;
llaisysTensor_t *attn_k_w;
@@ -23,7 +23,7 @@ __C {
llaisysTensor_t *attn_v_w;
llaisysTensor_t *attn_v_b;
llaisysTensor_t *attn_o_w;
- llaisysTensor_t *mlp_norm_w; // a.k.a. post_attention_layernorm.weight
+ llaisysTensor_t *mlp_norm_w;
llaisysTensor_t *mlp_gate_w;
llaisysTensor_t *mlp_up_w;
llaisysTensor_t *mlp_down_w;
@@ -37,6 +37,12 @@ __C {
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);
+ __export void llaisysQwen2ModelReset(struct LlaisysQwen2Model * model);
+
+ __export void llaisysQwen2ModelTruncate(struct LlaisysQwen2Model * model, size_t position);
+
__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+
+ __export int64_t llaisysQwen2ModelGenerateNext(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken, int top_k, float top_p, float temperature);
}
#endif // LLAISYS_MODELS_QWEN2_H
diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h
index ddb3be24..18e1b356 100644
--- a/include/llaisys/ops.h
+++ b/include/llaisys/ops.h
@@ -8,6 +8,7 @@ __C {
__export void llaisysArgmax(llaisysTensor_t max_idx, llaisysTensor_t max_val, llaisysTensor_t vals);
__export void llaisysEmbedding(llaisysTensor_t out, llaisysTensor_t index, llaisysTensor_t weight);
__export void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias);
+ __export int64_t llaisysSample(llaisysTensor_t logits, int top_k, float top_p, float temperature);
__export void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in);
__export void llaisysRmsNorm(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, float eps);
__export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta);
diff --git a/python/llaisys/chat_cli.py b/python/llaisys/chat_cli.py
new file mode 100644
index 00000000..f953a780
--- /dev/null
+++ b/python/llaisys/chat_cli.py
@@ -0,0 +1,414 @@
+import argparse
+import json
+import sys
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+try:
+ import httpx
+except ImportError as exc: # pragma: no cover - optional dependency
+ raise RuntimeError(
+ "HTTP client support is optional. Install with `pip install ./python[server]`."
+ ) from exc
+
+
+def _build_payload(
+ model: str,
+ messages: List[Dict[str, str]],
+ max_tokens: int,
+ top_k: int,
+ top_p: float,
+ temperature: float,
+ stream: bool,
+ session_id: str,
+) -> Dict[str, Any]:
+ return {
+ "model": model,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "top_k": top_k,
+ "top_p": top_p,
+ "temperature": temperature,
+ "stream": stream,
+ "session_id": session_id,
+ }
+
+
+def _stream_completion(client: httpx.Client, url: str, payload: Dict[str, Any]) -> str:
+ parts: List[str] = []
+ with client.stream("POST", url, json=payload, timeout=None) as response:
+ response.raise_for_status()
+ for line in response.iter_lines():
+ if not line or not line.startswith("data: "):
+ continue
+ data = line[6:]
+ if data == "[DONE]":
+ break
+ event = json.loads(data)
+ delta = event["choices"][0].get("delta", {})
+ content = delta.get("content", "")
+ if content:
+ print(content, end="", flush=True)
+ parts.append(content)
+ print()
+ return "".join(parts)
+
+
+def _non_stream_completion(client: httpx.Client, url: str, payload: Dict[str, Any]) -> str:
+ response = client.post(url, json=payload, timeout=None)
+ response.raise_for_status()
+ content = response.json()["choices"][0]["message"]["content"]
+ print(content)
+ return content
+
+
+def _request_json(client: httpx.Client, method: str, url: str, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ response = client.request(method, url, json=payload, timeout=None)
+ response.raise_for_status()
+ if not response.content:
+ return {}
+ return response.json()
+
+
+def _create_session(
+ client: httpx.Client,
+ base_url: str,
+ session_id: str,
+ messages: List[Dict[str, str]],
+) -> Dict[str, Any]:
+ payload: Dict[str, Any] = {"session_id": session_id}
+ if messages:
+ payload["messages"] = messages
+ return _request_json(client, "POST", f"{base_url}/v1/sessions", payload)
+
+
+def _get_session(client: httpx.Client, base_url: str, session_id: str) -> Dict[str, Any]:
+ return _request_json(client, "GET", f"{base_url}/v1/sessions/{session_id}")
+
+
+def _list_sessions(client: httpx.Client, base_url: str) -> List[Dict[str, Any]]:
+ return _request_json(client, "GET", f"{base_url}/v1/sessions").get("data", [])
+
+
+def _delete_session(client: httpx.Client, base_url: str, session_id: str) -> Dict[str, Any]:
+ return _request_json(client, "DELETE", f"{base_url}/v1/sessions/{session_id}")
+
+
+def _ensure_session(
+ client: httpx.Client,
+ base_url: str,
+ session_id: str,
+ seed_messages: List[Dict[str, str]],
+) -> Tuple[str, List[Dict[str, str]]]:
+ try:
+ payload = _get_session(client, base_url, session_id)
+ return session_id, payload.get("messages", [])
+ except httpx.HTTPStatusError as exc:
+ if exc.response.status_code != 404:
+ raise
+ _create_session(client, base_url, session_id, seed_messages)
+ return session_id, list(seed_messages)
+
+
+def _print_history(messages: List[Dict[str, str]]) -> None:
+ if not messages:
+ print("(empty session)")
+ return
+ for idx, message in enumerate(messages, start=1):
+ print(f"{idx}. {message['role']}: {message['content']}")
+
+
+def _print_sessions(sessions: List[Dict[str, Any]], current_session_id: Optional[str]) -> None:
+ if not sessions:
+ print("(no sessions)")
+ return
+ for session in sessions:
+ marker = "*" if session["id"] == current_session_id else " "
+ print(
+ f"{marker} {session['id']} "
+ f"{session['title']} "
+ f"messages={session['message_count']} "
+ f"tokens={session['token_count']}"
+ )
+
+
+def _print_help() -> None:
+ print("/help Show this help")
+ print("/session Show the current session id")
+ print("/new [session_id] Start a new session")
+ print("/list List sessions on the server")
+ print("/switch Switch to another session")
+ print("/history Show the current conversation")
+ print("/regen Regenerate the last assistant reply")
+ print("/edit Edit a past user message and regenerate from there")
+ print("/delete [session_id] Delete a session")
+ print("/quit Exit")
+
+
+def _complete(
+ client: httpx.Client,
+ endpoint: str,
+ model: str,
+ messages: List[Dict[str, str]],
+ max_tokens: int,
+ top_k: int,
+ top_p: float,
+ temperature: float,
+ stream: bool,
+ session_id: str,
+) -> str:
+ payload = _build_payload(
+ model,
+ messages,
+ max_tokens,
+ top_k,
+ top_p,
+ temperature,
+ stream,
+ session_id,
+ )
+ return (
+ _stream_completion(client, endpoint, payload)
+ if stream
+ else _non_stream_completion(client, endpoint, payload)
+ )
+
+
+def _handle_command(
+ command_line: str,
+ *,
+ client: httpx.Client,
+ base_url: str,
+ endpoint: str,
+ model: str,
+ max_tokens: int,
+ top_k: int,
+ top_p: float,
+ temperature: float,
+ stream: bool,
+ session_id: str,
+ messages: List[Dict[str, str]],
+ initial_messages: List[Dict[str, str]],
+) -> Tuple[str, List[Dict[str, str]], bool]:
+ command, _, rest = command_line.partition(" ")
+ command = command.lower()
+ rest = rest.strip()
+
+ if command == "/help":
+ _print_help()
+ return session_id, messages, False
+
+ if command == "/session":
+ print(session_id)
+ return session_id, messages, False
+
+ if command == "/list":
+ _print_sessions(_list_sessions(client, base_url), session_id)
+ return session_id, messages, False
+
+ if command == "/history":
+ _print_history(messages)
+ return session_id, messages, False
+
+ if command == "/new":
+ next_session_id = rest or f"session-{uuid.uuid4().hex}"
+ _create_session(client, base_url, next_session_id, initial_messages)
+ print(f"Switched to {next_session_id}")
+ return next_session_id, list(initial_messages), False
+
+ if command == "/switch":
+ if not rest:
+ print("Usage: /switch ")
+ return session_id, messages, False
+ payload = _get_session(client, base_url, rest)
+ print(f"Switched to {rest}")
+ return rest, payload.get("messages", []), False
+
+ if command == "/delete":
+ target_session_id = rest or session_id
+ _delete_session(client, base_url, target_session_id)
+ print(f"Deleted {target_session_id}")
+ if target_session_id == session_id:
+ next_session_id = f"session-{uuid.uuid4().hex}"
+ _create_session(client, base_url, next_session_id, initial_messages)
+ print(f"Switched to {next_session_id}")
+ return next_session_id, list(initial_messages), False
+ return session_id, messages, False
+
+ if command == "/regen":
+ if not messages or messages[-1]["role"] != "assistant":
+ print("No assistant reply to regenerate.")
+ return session_id, messages, False
+ next_messages = messages[:-1]
+ reply = _complete(
+ client,
+ endpoint,
+ model,
+ next_messages,
+ max_tokens,
+ top_k,
+ top_p,
+ temperature,
+ stream,
+ session_id,
+ )
+ next_messages = next_messages + [{"role": "assistant", "content": reply}]
+ return session_id, next_messages, False
+
+ if command == "/edit":
+ parts = rest.split(" ", 1)
+ if len(parts) != 2:
+ print("Usage: /edit ")
+ return session_id, messages, False
+ try:
+ index = int(parts[0])
+ except ValueError:
+ print("Message index must be an integer.")
+ return session_id, messages, False
+ if index < 1 or index > len(messages):
+ print("Message index out of range.")
+ return session_id, messages, False
+ if messages[index - 1]["role"] != "user":
+ print("Only user messages can be edited.")
+ return session_id, messages, False
+
+ next_messages = [{"role": message["role"], "content": message["content"]} for message in messages[:index]]
+ next_messages[-1]["content"] = parts[1]
+ reply = _complete(
+ client,
+ endpoint,
+ model,
+ next_messages,
+ max_tokens,
+ top_k,
+ top_p,
+ temperature,
+ stream,
+ session_id,
+ )
+ next_messages.append({"role": "assistant", "content": reply})
+ return session_id, next_messages, False
+
+ if command in {"/quit", "/exit"}:
+ return session_id, messages, True
+
+ print("Unknown command. Use /help to list session commands.")
+ return session_id, messages, False
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--url", default="http://127.0.0.1:8000", type=str)
+ parser.add_argument("--model", default="llaisys-qwen2", type=str)
+ parser.add_argument("--prompt", default=None, type=str)
+ parser.add_argument("--system-prompt", default=None, type=str)
+ parser.add_argument("--max-tokens", default=128, type=int)
+ parser.add_argument("--top-k", default=50, type=int)
+ parser.add_argument("--top-p", default=0.8, type=float)
+ parser.add_argument("--temperature", default=0.8, type=float)
+ parser.add_argument("--no-stream", action="store_true")
+ parser.add_argument("--session-id", default=None, type=str)
+ parser.add_argument("--list-sessions", action="store_true")
+ parser.add_argument("--show-session", action="store_true")
+ parser.add_argument("--create-session", action="store_true")
+ parser.add_argument("--delete-session", action="store_true")
+ args = parser.parse_args()
+
+ base_url = args.url.rstrip("/")
+ endpoint = f"{base_url}/v1/chat/completions"
+ stream = not args.no_stream
+ session_id = args.session_id or f"session-{uuid.uuid4().hex}"
+ initial_messages: List[Dict[str, str]] = []
+ if args.system_prompt:
+ initial_messages.append({"role": "system", "content": args.system_prompt})
+
+ if args.prompt is None and not sys.stdin.isatty() and not any(
+ [args.list_sessions, args.show_session, args.create_session, args.delete_session]
+ ):
+ raise RuntimeError(
+ "Interactive mode requires a TTY on stdin. "
+ "Run this command directly inside the activated `llaisys` environment, "
+ "or pass `--prompt` for one-shot mode."
+ )
+
+ with httpx.Client(trust_env=False) as client:
+ if args.list_sessions:
+ _print_sessions(_list_sessions(client, base_url), session_id)
+ return
+
+ if args.create_session:
+ payload = _create_session(client, base_url, session_id, initial_messages)
+ print(payload["id"])
+ return
+
+ if args.show_session:
+ payload = _get_session(client, base_url, session_id)
+ _print_history(payload.get("messages", []))
+ return
+
+ if args.delete_session:
+ payload = _delete_session(client, base_url, session_id)
+ print(json.dumps(payload, ensure_ascii=False))
+ return
+
+ session_id, messages = _ensure_session(client, base_url, session_id, initial_messages)
+ prompt = args.prompt
+
+ while True:
+ if prompt is None:
+ try:
+ prompt = input("user> ").strip()
+ except EOFError:
+ break
+ except KeyboardInterrupt:
+ print()
+ break
+
+ if not prompt:
+ break
+
+ if prompt.startswith("/"):
+ session_id, messages, should_exit = _handle_command(
+ prompt,
+ client=client,
+ base_url=base_url,
+ endpoint=endpoint,
+ model=args.model,
+ max_tokens=args.max_tokens,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ temperature=args.temperature,
+ stream=stream,
+ session_id=session_id,
+ messages=messages,
+ initial_messages=initial_messages,
+ )
+ if should_exit:
+ break
+ if args.prompt is not None:
+ break
+ prompt = None
+ continue
+
+ messages.append({"role": "user", "content": prompt})
+ reply = _complete(
+ client,
+ endpoint,
+ args.model,
+ messages,
+ args.max_tokens,
+ args.top_k,
+ args.top_p,
+ args.temperature,
+ stream,
+ session_id,
+ )
+ messages.append({"role": "assistant", "content": reply})
+
+ if args.prompt is not None:
+ break
+ prompt = None
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/llaisys/chat_server.py b/python/llaisys/chat_server.py
new file mode 100644
index 00000000..8867c724
--- /dev/null
+++ b/python/llaisys/chat_server.py
@@ -0,0 +1,1099 @@
+import argparse
+import json
+import queue
+import threading
+import time
+import uuid
+from collections import deque
+from dataclasses import dataclass, field
+from typing import Any, Callable, Deque, Dict, Iterable, List, Optional
+
+import llaisys
+from transformers import AutoTokenizer
+
+try:
+ from fastapi import Body, FastAPI, HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse
+except ImportError as exc: # pragma: no cover - optional dependency
+ raise RuntimeError(
+ "FastAPI support is optional. Install with `pip install ./python[server]`."
+ ) from exc
+
+
+_STREAM_END = object()
+
+
+def _device_from_name(device_name: str) -> llaisys.DeviceType:
+ if device_name == "cpu":
+ return llaisys.DeviceType.CPU
+ if device_name == "nvidia":
+ return llaisys.DeviceType.NVIDIA
+ raise ValueError(f"Unsupported device: {device_name}")
+
+
+def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+ normalized = []
+ for message in messages:
+ if not isinstance(message, dict):
+ raise ValueError("Each message must be an object")
+ role = str(message.get("role", "")).strip()
+ content = message.get("content", "")
+ if not role:
+ raise ValueError("Each message must include a role")
+ if isinstance(content, list):
+ text_parts = [str(part.get("text", "")) for part in content if isinstance(part, dict)]
+ content = "".join(text_parts)
+ normalized.append({"role": role, "content": str(content)})
+ if not normalized:
+ raise ValueError("messages must not be empty")
+ return normalized
+
+
+def _clone_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ return [{"role": message["role"], "content": message["content"]} for message in messages]
+
+
+def _build_input_ids(tokenizer, messages: List[Dict[str, Any]]) -> List[int]:
+ prompt = tokenizer.apply_chat_template(
+ conversation=_normalize_messages(messages),
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+ return list(tokenizer.encode(prompt))
+
+
+def _finish_reason(end_token: Optional[int], generated_ids: List[int], max_tokens: int) -> str:
+ if generated_ids and generated_ids[-1] == end_token:
+ return "stop"
+ if len(generated_ids) >= max_tokens:
+ return "length"
+ return "stop"
+
+
+def _usage(prompt_tokens: int, completion_tokens: int) -> Dict[str, int]:
+ return {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": prompt_tokens + completion_tokens,
+ }
+
+
+def _sse_payload(payload: Dict[str, Any]) -> str:
+ return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
+
+
+def _resolve_session_id(payload: Dict[str, Any]) -> str:
+ session_id = payload.get("session_id")
+ return str(session_id) if session_id else f"session-{uuid.uuid4().hex}"
+
+
+def _session_title(messages: List[Dict[str, str]]) -> str:
+ for message in messages:
+ if message["role"] == "user" and message["content"]:
+ title = message["content"].strip().replace("\n", " ")
+ return title[:64] if title else "Untitled Session"
+ return "Untitled Session"
+
+
+def _longest_common_prefix_len(lhs: List[int], rhs: List[int]) -> int:
+ limit = min(len(lhs), len(rhs))
+ idx = 0
+ while idx < limit and lhs[idx] == rhs[idx]:
+ idx += 1
+ return idx
+
+
+def _decode_text_delta(tokenizer, generated_ids: List[int], previous_text: str) -> tuple[str, str]:
+ decoded = tokenizer.decode(
+ generated_ids,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+ if decoded.startswith(previous_text):
+ return decoded[len(previous_text):], decoded
+ return decoded, decoded
+
+
+def _model_generate_next(
+ model: Any,
+ inputs: List[int],
+ *,
+ top_k: int,
+ top_p: float,
+ temperature: float,
+ reset_state: bool,
+) -> int:
+ generate_next = getattr(model, "generate_next", None)
+ if callable(generate_next):
+ return int(
+ generate_next(
+ inputs,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ reset_state=reset_state,
+ )
+ )
+
+ if reset_state and hasattr(model, "reset"):
+ model.reset()
+
+ private_generate_next = getattr(model, "_generate_next", None)
+ if callable(private_generate_next):
+ return int(
+ private_generate_next(
+ inputs,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ )
+ )
+
+ output_ids = model.generate(
+ inputs,
+ max_new_tokens=1,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ reset_state=reset_state,
+ )
+ return int(output_ids[-1])
+
+
+@dataclass
+class SessionState:
+ session_id: str
+ messages: List[Dict[str, str]]
+ token_ids: List[int]
+ created_at: float
+ updated_at: float
+ worker_id: Optional[int] = None
+
+
+@dataclass
+class ReusePlan:
+ reset_state: bool
+ generation_inputs: List[int]
+ reused_tokens: int
+ truncate_to: Optional[int]
+
+
+@dataclass
+class CompletionRequest:
+ completion_id: str
+ session_id: str
+ messages: List[Dict[str, str]]
+ input_ids: List[int]
+ max_tokens: int
+ top_k: int
+ top_p: float
+ temperature: float
+ stream: bool
+ created: int
+ request_model: str
+ result_queue: "queue.Queue[Any]"
+ enqueued_at: float
+ batch_id: Optional[int] = None
+ first_batch_id: Optional[int] = None
+ generated_ids: List[int] = field(default_factory=list)
+ last_text: str = ""
+ cache_reused_tokens: int = 0
+ started: bool = False
+ prompt_inputs: List[int] = field(default_factory=list)
+ reset_state: bool = True
+ steps_dispatched: int = 0
+ assigned_worker_id: Optional[int] = None
+ first_dispatched_at: Optional[float] = None
+ finished_at: Optional[float] = None
+
+
+@dataclass
+class WorkerState:
+ worker_id: int
+ model: Any
+ job_queue: "queue.Queue[Optional[CompletionRequest]]" = field(default_factory=queue.Queue)
+ thread: Optional[threading.Thread] = None
+ busy: bool = False
+ current_request_id: Optional[str] = None
+ active_session_id: Optional[str] = None
+ active_token_ids: List[int] = field(default_factory=list)
+ needs_reset: bool = True
+ total_requests: int = 0
+ total_tokens_generated: int = 0
+ total_steps: int = 0
+ last_used_at: float = 0.0
+
+
+@dataclass
+class ServiceState:
+ tokenizer: Any
+ model_name: str
+ workers: List[WorkerState]
+ max_batch_size: int
+ batch_wait_s: float
+ sessions: Dict[str, SessionState] = field(default_factory=dict)
+ session_inflight: Dict[str, int] = field(default_factory=dict)
+ pending_requests: Deque[CompletionRequest] = field(default_factory=deque)
+ condition: threading.Condition = field(default_factory=threading.Condition)
+ shutdown: bool = False
+ next_batch_id: int = 1
+ scheduled_batches: int = 0
+ max_observed_batch_size: int = 0
+ completed_requests: int = 0
+ failed_requests: int = 0
+ total_enqueued: int = 0
+ total_generated_tokens: int = 0
+ requeued_requests: int = 0
+ cache_reuse_hits: int = 0
+ total_queue_wait_s: float = 0.0
+ total_request_time_s: float = 0.0
+ scheduler_thread: Optional[threading.Thread] = None
+
+
+def _session_payload(
+ state: SessionState,
+ *,
+ include_messages: bool = False,
+ inflight: int = 0,
+) -> Dict[str, Any]:
+ payload = {
+ "id": state.session_id,
+ "object": "session",
+ "title": _session_title(state.messages),
+ "created": int(state.created_at),
+ "updated": int(state.updated_at),
+ "message_count": len(state.messages),
+ "token_count": len(state.token_ids),
+ "worker_id": state.worker_id,
+ "inflight": inflight,
+ }
+ if include_messages:
+ payload["messages"] = _clone_messages(state.messages)
+ return payload
+
+
+def _prepare_worker_reuse(worker: WorkerState, session_id: str, input_ids: List[int]) -> ReusePlan:
+ if worker.needs_reset or worker.active_session_id != session_id or not worker.active_token_ids:
+ return ReusePlan(reset_state=True, generation_inputs=input_ids, reused_tokens=0, truncate_to=None)
+
+ prefix_len = _longest_common_prefix_len(worker.active_token_ids, input_ids)
+ if prefix_len <= 0:
+ return ReusePlan(reset_state=True, generation_inputs=input_ids, reused_tokens=0, truncate_to=None)
+
+ reusable_prefix = prefix_len
+ # Never reuse the whole prompt: keep at least one trailing token to feed
+ # the model so it still produces a next-token distribution.
+ if reusable_prefix >= len(input_ids):
+ reusable_prefix = len(input_ids) - 1
+
+ if reusable_prefix <= 0:
+ return ReusePlan(reset_state=True, generation_inputs=input_ids, reused_tokens=0, truncate_to=None)
+
+ truncate_to = reusable_prefix if reusable_prefix < len(worker.active_token_ids) else None
+ return ReusePlan(
+ reset_state=False,
+ generation_inputs=input_ids[reusable_prefix:],
+ reused_tokens=reusable_prefix,
+ truncate_to=truncate_to,
+ )
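The prefix-reuse planning above can be sketched as a standalone function. `plan_reuse` below is an illustrative reimplementation (the name is hypothetical, not part of the service API) showing the two invariants: reuse the longest common prefix between the worker's cached tokens and the new prompt, and always leave at least one token to feed the model.

```python
def plan_reuse(cached: list[int], prompt: list[int]) -> tuple[int, list[int]]:
    """Return (reused_tokens, tokens_to_feed) for a worker's cached state."""
    prefix = 0
    for a, b in zip(cached, prompt):
        if a != b:
            break
        prefix += 1
    # Never reuse the entire prompt: keep one token to drive generation.
    prefix = min(prefix, len(prompt) - 1)
    if prefix <= 0:
        return 0, prompt  # no usable overlap: reset and feed everything
    return prefix, prompt[prefix:]
```

For example, a follow-up turn that extends the previous conversation reuses the old tokens and feeds only the new suffix; an identical prompt still re-feeds its final token.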
+
+
+def _detach_worker_session_locked(service: ServiceState, worker: WorkerState) -> None:
+ if worker.active_session_id is None:
+ worker.active_token_ids = []
+ worker.needs_reset = True
+ return
+
+ state = service.sessions.get(worker.active_session_id)
+ if state is not None and state.worker_id == worker.worker_id:
+ state.worker_id = None
+ worker.active_session_id = None
+ worker.active_token_ids = []
+ worker.needs_reset = True
+
+
+def _clear_session_cache_locked(service: ServiceState, session_id: str) -> None:
+ state = service.sessions.get(session_id)
+ if state is None:
+ return
+ worker_id = state.worker_id
+ state.worker_id = None
+ state.token_ids = []
+ if worker_id is None:
+ return
+ worker = service.workers[worker_id]
+ if worker.active_session_id == session_id:
+ worker.active_session_id = None
+ worker.active_token_ids = []
+ worker.needs_reset = True
+
+
+def _decrement_inflight_locked(service: ServiceState, session_id: str) -> None:
+ remaining = service.session_inflight.get(session_id, 0) - 1
+ if remaining > 0:
+ service.session_inflight[session_id] = remaining
+ else:
+ service.session_inflight.pop(session_id, None)
+
+
+def _commit_session_locked(
+ service: ServiceState,
+ worker: WorkerState,
+ session_id: str,
+ messages: List[Dict[str, str]],
+ cached_tokens: List[int],
+) -> None:
+ now = time.time()
+ previous = service.sessions.get(session_id)
+ created_at = previous.created_at if previous is not None else now
+ service.sessions[session_id] = SessionState(
+ session_id=session_id,
+ messages=_clone_messages(messages),
+ token_ids=list(cached_tokens),
+ created_at=created_at,
+ updated_at=now,
+ worker_id=worker.worker_id,
+ )
+ worker.active_session_id = session_id
+ worker.active_token_ids = list(cached_tokens)
+ worker.needs_reset = False
+
+
+def _select_worker_locked(
+ service: ServiceState,
+ request: CompletionRequest,
+ idle_workers: Dict[int, WorkerState],
+ pinned_session_ids: set[str],
+) -> Optional[WorkerState]:
+ session = service.sessions.get(request.session_id)
+ if session is not None and session.worker_id in idle_workers:
+ return idle_workers[session.worker_id]
+
+ for worker in idle_workers.values():
+ if worker.active_session_id == request.session_id and not worker.needs_reset:
+ return worker
+
+ never_used = [worker for worker in idle_workers.values() if worker.active_session_id is None]
+ if never_used:
+ return min(never_used, key=lambda item: item.last_used_at)
+
+ reusable_workers = [
+ worker
+ for worker in idle_workers.values()
+ if worker.active_session_id not in pinned_session_ids
+ ]
+ if reusable_workers:
+ return min(reusable_workers, key=lambda item: item.last_used_at)
+
+ return min(idle_workers.values(), key=lambda item: item.last_used_at, default=None)
+
+
+def _service_stats(service: ServiceState) -> Dict[str, Any]:
+ with service.condition:
+ workers = [
+ {
+ "worker_id": worker.worker_id,
+ "busy": worker.busy,
+ "active_session_id": worker.active_session_id,
+ "cached_tokens": len(worker.active_token_ids),
+ "needs_reset": worker.needs_reset,
+ "total_requests": worker.total_requests,
+ "total_tokens_generated": worker.total_tokens_generated,
+ "total_steps": worker.total_steps,
+ }
+ for worker in service.workers
+ ]
+ return {
+ "object": "service.stats",
+ "model": service.model_name,
+ "worker_count": len(service.workers),
+ "max_batch_size": service.max_batch_size,
+ "batch_wait_ms": int(service.batch_wait_s * 1000),
+ "queue_depth": len(service.pending_requests),
+ "active_requests": sum(1 for worker in service.workers if worker.busy),
+ "session_count": len(service.sessions),
+ "scheduled_batches": service.scheduled_batches,
+ "max_observed_batch_size": service.max_observed_batch_size,
+ "completed_requests": service.completed_requests,
+ "failed_requests": service.failed_requests,
+ "total_enqueued": service.total_enqueued,
+ "total_generated_tokens": service.total_generated_tokens,
+ "requeued_requests": service.requeued_requests,
+ "cache_reuse_hits": service.cache_reuse_hits,
+ "avg_queue_wait_ms": (
+ (service.total_queue_wait_s / service.completed_requests) * 1000.0
+ if service.completed_requests
+ else 0.0
+ ),
+ "avg_request_time_ms": (
+ (service.total_request_time_s / service.completed_requests) * 1000.0
+ if service.completed_requests
+ else 0.0
+ ),
+ "cached_session_count": sum(1 for session in service.sessions.values() if session.worker_id is not None),
+ "workers": workers,
+ }
+
+
+def _scheduler_loop(service: ServiceState) -> None:
+ while True:
+ assigned: List[tuple[WorkerState, CompletionRequest]] = []
+ with service.condition:
+ while not service.shutdown and not service.pending_requests:
+ service.condition.wait()
+ if service.shutdown:
+ return
+
+ if len(service.pending_requests) < service.max_batch_size and service.batch_wait_s > 0:
+ service.condition.wait(timeout=service.batch_wait_s)
+ if service.shutdown:
+ return
+
+ candidates: List[CompletionRequest] = []
+ while service.pending_requests and len(candidates) < service.max_batch_size:
+ candidates.append(service.pending_requests.popleft())
+
+ idle_workers = {
+ worker.worker_id: worker
+ for worker in service.workers
+ if not worker.busy
+ }
+ pinned_session_ids = {
+ pending_request.session_id
+ for pending_request in service.pending_requests
+ if pending_request.started
+ }
+ pinned_session_ids.update(
+ candidate.session_id
+ for candidate in candidates
+ if candidate.started
+ )
+ unassigned: List[CompletionRequest] = []
+ if idle_workers:
+ batch_id = service.next_batch_id
+ service.next_batch_id += 1
+ for request in candidates:
+ worker = _select_worker_locked(service, request, idle_workers, pinned_session_ids)
+ if worker is None:
+ unassigned.append(request)
+ continue
+ if worker.active_session_id is not None and worker.active_session_id != request.session_id:
+ _detach_worker_session_locked(service, worker)
+ worker.busy = True
+ worker.current_request_id = request.completion_id
+ request.batch_id = batch_id
+ if request.first_batch_id is None:
+ request.first_batch_id = batch_id
+ request.assigned_worker_id = worker.worker_id
+ request.steps_dispatched += 1
+ if request.first_dispatched_at is None:
+ request.first_dispatched_at = time.time()
+ assigned.append((worker, request))
+ idle_workers.pop(worker.worker_id, None)
+ pinned_session_ids.add(request.session_id)
+
+ if assigned:
+ service.scheduled_batches += 1
+ service.max_observed_batch_size = max(
+ service.max_observed_batch_size,
+ len(assigned),
+ )
+ else:
+ for request in reversed(candidates):
+ service.pending_requests.appendleft(request)
+ service.condition.wait(timeout=max(service.batch_wait_s, 0.01))
+ continue
+
+ for request in reversed(unassigned):
+ service.pending_requests.appendleft(request)
+
+ for worker, request in assigned:
+ worker.job_queue.put(request)
+
+
+def _worker_loop(service: ServiceState, worker: WorkerState) -> None:
+ while True:
+ request = worker.job_queue.get()
+ if request is None:
+ return
+ _run_request(service, worker, request)
+
+
+def _run_request(service: ServiceState, worker: WorkerState, request: CompletionRequest) -> None:
+ try:
+ emit_role_chunk = False
+ with service.condition:
+ batch_id = request.batch_id
+ worker_id = worker.worker_id
+ truncate_to = None
+
+ if not request.started:
+ plan = _prepare_worker_reuse(worker, request.session_id, request.input_ids)
+ request.cache_reused_tokens = plan.reused_tokens
+ request.prompt_inputs = list(plan.generation_inputs)
+ request.reset_state = plan.reset_state
+ request.started = True
+ request.assigned_worker_id = worker_id
+ truncate_to = plan.truncate_to
+ if request.cache_reused_tokens > 0:
+ service.cache_reuse_hits += 1
+ emit_role_chunk = request.stream
+
+ if truncate_to is not None:
+ worker.model.truncate(truncate_to)
+
+ if emit_role_chunk:
+ request.result_queue.put(
+ _sse_payload(
+ {
+ "id": request.completion_id,
+ "object": "chat.completion.chunk",
+ "created": request.created,
+ "model": request.request_model,
+ "session_id": request.session_id,
+ "worker_id": worker_id,
+ "batch_id": batch_id,
+ "cache_reused_tokens": request.cache_reused_tokens,
+ "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+ }
+ )
+ )
+
+ token_source = request.prompt_inputs if not request.generated_ids else [request.generated_ids[-1]]
+ next_token = _model_generate_next(
+ worker.model,
+ token_source,
+ top_k=request.top_k,
+ top_p=request.top_p,
+ temperature=request.temperature,
+ reset_state=request.reset_state if not request.generated_ids else False,
+ )
+ request.generated_ids.append(next_token)
+ text_chunk, request.last_text = _decode_text_delta(
+ service.tokenizer,
+ request.generated_ids,
+ request.last_text,
+ )
+
+ finished = (
+ next_token == getattr(worker.model.meta, "end_token", None)
+ or len(request.generated_ids) >= request.max_tokens
+ )
+ assistant_content = service.tokenizer.decode(
+ request.generated_ids,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+ finish_reason = _finish_reason(
+ getattr(worker.model.meta, "end_token", None),
+ request.generated_ids,
+ request.max_tokens,
+ )
+
+ with service.condition:
+ worker.busy = False
+ worker.current_request_id = None
+ worker.last_used_at = time.time()
+ worker.total_steps += 1
+ worker.total_tokens_generated += 1
+ service.total_generated_tokens += 1
+
+ if finished:
+ request.finished_at = time.time()
+ _commit_session_locked(
+ service,
+ worker,
+ request.session_id,
+ request.messages + [{"role": "assistant", "content": assistant_content}],
+ request.input_ids + request.generated_ids,
+ )
+ worker.total_requests += 1
+ service.completed_requests += 1
+ if request.first_dispatched_at is not None:
+ service.total_queue_wait_s += max(0.0, request.first_dispatched_at - request.enqueued_at)
+ service.total_request_time_s += max(0.0, request.finished_at - request.enqueued_at)
+ _decrement_inflight_locked(service, request.session_id)
+ else:
+ worker.active_session_id = request.session_id
+ worker.active_token_ids = list(request.input_ids + request.generated_ids)
+ worker.needs_reset = False
+ service.requeued_requests += 1
+ service.pending_requests.append(request)
+ service.condition.notify_all()
+
+ if request.stream and text_chunk:
+ request.result_queue.put(
+ _sse_payload(
+ {
+ "id": request.completion_id,
+ "object": "chat.completion.chunk",
+ "created": request.created,
+ "model": request.request_model,
+ "session_id": request.session_id,
+ "worker_id": worker_id,
+ "batch_id": batch_id,
+ "cache_reused_tokens": request.cache_reused_tokens,
+ "choices": [{"index": 0, "delta": {"content": text_chunk}, "finish_reason": None}],
+ }
+ )
+ )
+
+ if not finished:
+ return
+
+ if request.stream:
+ request.result_queue.put(
+ _sse_payload(
+ {
+ "id": request.completion_id,
+ "object": "chat.completion.chunk",
+ "created": request.created,
+ "model": request.request_model,
+ "session_id": request.session_id,
+ "worker_id": worker_id,
+ "batch_id": request.first_batch_id,
+ "last_batch_id": batch_id,
+ "dispatch_count": request.steps_dispatched,
+ "cache_reused_tokens": request.cache_reused_tokens,
+ "queue_wait_ms": (
+ max(0.0, (request.first_dispatched_at or request.enqueued_at) - request.enqueued_at) * 1000.0
+ ),
+ "total_time_ms": (
+ max(0.0, (request.finished_at or time.time()) - request.enqueued_at) * 1000.0
+ ),
+ "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}],
+ }
+ )
+ )
+ request.result_queue.put("data: [DONE]\n\n")
+ request.result_queue.put(_STREAM_END)
+ return
+
+ request.result_queue.put(
+ {
+ "id": request.completion_id,
+ "object": "chat.completion",
+ "created": request.created,
+ "model": request.request_model,
+ "session_id": request.session_id,
+ "worker_id": worker_id,
+ "batch_id": request.first_batch_id,
+ "last_batch_id": batch_id,
+ "dispatch_count": request.steps_dispatched,
+ "cache_reused_tokens": request.cache_reused_tokens,
+ "queue_wait_ms": (
+ max(0.0, (request.first_dispatched_at or request.enqueued_at) - request.enqueued_at) * 1000.0
+ ),
+ "total_time_ms": (
+ max(0.0, (request.finished_at or time.time()) - request.enqueued_at) * 1000.0
+ ),
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": assistant_content},
+ "finish_reason": finish_reason,
+ }
+ ],
+ "usage": _usage(len(request.input_ids), len(request.generated_ids)),
+ }
+ )
+ except Exception as exc: # pragma: no cover - defensive path
+ with service.condition:
+ _detach_worker_session_locked(service, worker)
+ worker.busy = False
+ worker.current_request_id = None
+ worker.last_used_at = time.time()
+ service.failed_requests += 1
+ _decrement_inflight_locked(service, request.session_id)
+ service.condition.notify_all()
+
+ if request.stream:
+ request.result_queue.put(
+ _sse_payload(
+ {
+ "id": request.completion_id,
+ "object": "chat.completion.chunk",
+ "created": request.created,
+ "model": request.request_model,
+ "session_id": request.session_id,
+ "worker_id": worker.worker_id,
+ "batch_id": request.batch_id,
+ "choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "error"}],
+ "error": str(exc),
+ }
+ )
+ )
+ request.result_queue.put("data: [DONE]\n\n")
+ request.result_queue.put(_STREAM_END)
+ return
+
+ request.result_queue.put(exc)
+
+
+def _shutdown_service(service: ServiceState) -> None:
+ with service.condition:
+ if service.shutdown:
+ return
+ service.shutdown = True
+ service.condition.notify_all()
+
+ if service.scheduler_thread is not None:
+ service.scheduler_thread.join(timeout=5)
+ service.scheduler_thread = None
+
+ for worker in service.workers:
+ worker.job_queue.put(None)
+ for worker in service.workers:
+ if worker.thread is not None:
+ worker.thread.join(timeout=5)
+ worker.thread = None
+
+
+def _resolve_runtime_components(
+ model_path: Optional[str],
+ device: str,
+ device_id: int,
+ tp_size: int,
+ tp_device_ids: Optional[List[int]],
+ *,
+ tokenizer,
+ model,
+ model_factory,
+ model_name: Optional[str],
+ num_workers: int,
+) -> tuple[Any, Callable[[], Any], str]:
+ if tokenizer is None:
+ if model_path is None:
+ raise ValueError("model_path is required when tokenizer is not injected")
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ if model is not None and model_factory is not None:
+ raise ValueError("Pass either `model` or `model_factory`, not both")
+
+ if model_factory is None:
+ if model is not None:
+ if num_workers != 1:
+ raise ValueError("model_factory is required when injecting a model with num_workers > 1")
+ model_factory = lambda: model
+ resolved_model_name = model_name or "llaisys-chat"
+ return tokenizer, model_factory, resolved_model_name
+
+ if model_path is None:
+ raise ValueError("model_path is required when model/model_factory are not injected")
+
+ device_type = _device_from_name(device)
+
+ def model_factory() -> Any:
+ return llaisys.models.create_model(
+ model_path,
+ device_type,
+ device_id,
+ tp_size=tp_size,
+ tp_device_ids=tp_device_ids,
+ )
+
+ resolved_model_name = model_name or llaisys.models.default_model_name(model_path)
+ return tokenizer, model_factory, resolved_model_name
+
+ resolved_model_name = model_name or "llaisys-chat"
+ return tokenizer, model_factory, resolved_model_name
+
+
+def create_app(
+ model_path: Optional[str] = None,
+ device: str = "cpu",
+ device_id: int = 0,
+ tp_size: int = 1,
+ tp_device_ids: Optional[List[int]] = None,
+ *,
+ tokenizer=None,
+ model=None,
+ model_factory: Optional[Callable[[], Any]] = None,
+ model_name: Optional[str] = None,
+ num_workers: int = 1,
+ max_batch_size: int = 1,
+ batch_wait_ms: int = 5,
+) -> FastAPI:
+ if num_workers <= 0:
+ raise ValueError("num_workers must be positive")
+ if max_batch_size <= 0:
+ raise ValueError("max_batch_size must be positive")
+ if batch_wait_ms < 0:
+ raise ValueError("batch_wait_ms must be non-negative")
+ if tp_size <= 0:
+ raise ValueError("tp_size must be positive")
+ if tp_size > 1 and device != "nvidia":
+ raise ValueError("Tensor parallel service mode currently supports only NVIDIA devices")
+ if tp_size > 1 and num_workers != 1:
+ raise ValueError("tp_size > 1 currently requires num_workers == 1")
+
+ tokenizer, model_factory, resolved_model_name = _resolve_runtime_components(
+ model_path,
+ device,
+ device_id,
+ tp_size,
+ tp_device_ids,
+ tokenizer=tokenizer,
+ model=model,
+ model_factory=model_factory,
+ model_name=model_name,
+ num_workers=num_workers,
+ )
+
+ workers = [WorkerState(worker_id=i, model=model_factory()) for i in range(num_workers)]
+ service = ServiceState(
+ tokenizer=tokenizer,
+ model_name=resolved_model_name,
+ workers=workers,
+ max_batch_size=max_batch_size,
+ batch_wait_s=batch_wait_ms / 1000.0,
+ )
+
+ for worker in service.workers:
+ worker.thread = threading.Thread(
+ target=_worker_loop,
+ args=(service, worker),
+ name=f"llaisys-worker-{worker.worker_id}",
+ daemon=True,
+ )
+ worker.thread.start()
+
+ service.scheduler_thread = threading.Thread(
+ target=_scheduler_loop,
+ args=(service,),
+ name="llaisys-scheduler",
+ daemon=True,
+ )
+ service.scheduler_thread.start()
+
+ app = FastAPI(title="LLAISYS Chat API", version="0.3.0")
+ app.state.service = service
+ app.state.tokenizer = tokenizer
+ app.state.model_name = resolved_model_name
+
+ @app.on_event("shutdown")
+ def shutdown_event() -> None:
+ _shutdown_service(app.state.service)
+
+ @app.get("/health")
+ def health() -> Dict[str, Any]:
+ stats = _service_stats(app.state.service)
+ return {
+ "status": "ok",
+ "model": app.state.model_name,
+ "sessions": stats["session_count"],
+ "queue_depth": stats["queue_depth"],
+ "active_requests": stats["active_requests"],
+ "worker_count": stats["worker_count"],
+ "tp_size": tp_size,
+ "tp_device_ids": list(tp_device_ids or []),
+ }
+
+ @app.get("/v1/service/stats")
+ def service_stats() -> Dict[str, Any]:
+ return _service_stats(app.state.service)
+
+ @app.get("/v1/sessions")
+ def list_sessions() -> Dict[str, Any]:
+ service = app.state.service
+ with service.condition:
+ sessions = sorted(service.sessions.values(), key=lambda item: item.updated_at, reverse=True)
+ return {
+ "object": "list",
+ "data": [
+ _session_payload(
+ session,
+ inflight=service.session_inflight.get(session.session_id, 0),
+ )
+ for session in sessions
+ ],
+ }
+
+ @app.post("/v1/sessions")
+ def create_session(payload: Optional[Dict[str, Any]] = Body(default=None)) -> JSONResponse:
+ payload = payload or {}
+ session_id = _resolve_session_id(payload)
+ messages = payload.get("messages")
+ normalized_messages = _normalize_messages(messages) if messages else []
+ service = app.state.service
+
+ with service.condition:
+ if session_id in service.sessions:
+ raise HTTPException(status_code=409, detail=f"Session `{session_id}` already exists")
+
+ now = time.time()
+ service.sessions[session_id] = SessionState(
+ session_id=session_id,
+ messages=normalized_messages,
+ token_ids=[],
+ created_at=now,
+ updated_at=now,
+ )
+ return JSONResponse(_session_payload(service.sessions[session_id], include_messages=True))
+
+ @app.get("/v1/sessions/{session_id}")
+ def get_session(session_id: str) -> Dict[str, Any]:
+ service = app.state.service
+ with service.condition:
+ state = service.sessions.get(session_id)
+ if state is None:
+ raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`")
+ payload = _session_payload(
+ state,
+ include_messages=True,
+ inflight=service.session_inflight.get(session_id, 0),
+ )
+ payload["active"] = state.worker_id is not None
+ return payload
+
+ @app.put("/v1/sessions/{session_id}")
+ def replace_session(session_id: str, payload: Dict[str, Any] = Body(...)) -> Dict[str, Any]:
+ if "messages" not in payload:
+ raise HTTPException(status_code=400, detail="`messages` is required")
+
+ try:
+ normalized_messages = _normalize_messages(payload["messages"])
+ except (TypeError, ValueError) as exc:
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+ service = app.state.service
+ with service.condition:
+ state = service.sessions.get(session_id)
+ if state is None:
+ raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`")
+ if service.session_inflight.get(session_id, 0) > 0:
+ raise HTTPException(status_code=409, detail=f"Session `{session_id}` is busy")
+
+ _clear_session_cache_locked(service, session_id)
+ state.messages = normalized_messages
+ state.updated_at = time.time()
+ return _session_payload(state, include_messages=True)
+
+ @app.delete("/v1/sessions/{session_id}")
+ def delete_session(session_id: str) -> Dict[str, Any]:
+ service = app.state.service
+ with service.condition:
+ state = service.sessions.get(session_id)
+ if state is None:
+ raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`")
+ if service.session_inflight.get(session_id, 0) > 0:
+ raise HTTPException(status_code=409, detail=f"Session `{session_id}` is busy")
+
+ _clear_session_cache_locked(service, session_id)
+ service.sessions.pop(session_id, None)
+ return {"id": session_id, "object": "session.deleted", "deleted": True}
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(payload: Dict[str, Any] = Body(...)):
+ service = app.state.service
+ session_id = _resolve_session_id(payload)
+ with service.condition:
+ stored_session = service.sessions.get(session_id)
+
+ raw_messages = payload.get("messages")
+ if raw_messages is None:
+ if stored_session is None:
+ raise HTTPException(status_code=400, detail="`messages` is required for a new session")
+ raw_messages = stored_session.messages
+
+ try:
+ messages = _normalize_messages(raw_messages)
+ input_ids = _build_input_ids(service.tokenizer, messages)
+ max_tokens = int(payload.get("max_tokens", payload.get("max_completion_tokens", 128)))
+ top_k = int(payload.get("top_k", 1))
+ top_p = float(payload.get("top_p", 1.0))
+ temperature = float(payload.get("temperature", 1.0))
+ stream = bool(payload.get("stream", False))
+ except (TypeError, ValueError) as exc:
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+ if max_tokens <= 0:
+ raise HTTPException(status_code=400, detail="`max_tokens` must be positive")
+
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+ created = int(time.time())
+ request_model = str(payload.get("model", app.state.model_name))
+ result_queue: "queue.Queue[Any]" = queue.Queue()
+ request = CompletionRequest(
+ completion_id=completion_id,
+ session_id=session_id,
+ messages=messages,
+ input_ids=input_ids,
+ max_tokens=max_tokens,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ stream=stream,
+ created=created,
+ request_model=request_model,
+ result_queue=result_queue,
+ enqueued_at=time.time(),
+ )
+
+ with service.condition:
+ if service.session_inflight.get(session_id, 0) > 0:
+ raise HTTPException(status_code=409, detail=f"Session `{session_id}` already has an in-flight request")
+ service.session_inflight[session_id] = service.session_inflight.get(session_id, 0) + 1
+ service.pending_requests.append(request)
+ service.total_enqueued += 1
+ service.condition.notify_all()
+
+ if stream:
+ def event_stream() -> Iterable[str]:
+ while True:
+ item = result_queue.get()
+ if item is _STREAM_END:
+ break
+ yield item
+
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+ result = result_queue.get()
+ if isinstance(result, Exception):
+ raise HTTPException(status_code=500, detail=str(result))
+ return JSONResponse(result)
+
+ return app
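The streaming endpoint above yields whatever `_sse_payload` produces; that helper is defined outside this excerpt. A typical shape (an assumption for illustration, not necessarily the service's exact helper) is one JSON object per Server-Sent Events `data:` frame, terminated by a blank line:

```python
import json

def sse_payload(obj: dict) -> str:
    # One SSE frame: "data: <json>\n\n"; compact separators keep frames small.
    return f"data: {json.dumps(obj, separators=(',', ':'))}\n\n"
```

The final `data: [DONE]\n\n` frame emitted by `_run_request` matches the OpenAI-style convention clients use to detect end of stream.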
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", required=True, type=str)
+ parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str)
+ parser.add_argument("--device-id", default=0, type=int)
+ parser.add_argument("--host", default="127.0.0.1", type=str)
+ parser.add_argument("--port", default=8000, type=int)
+ parser.add_argument("--model-name", default=None, type=str)
+ parser.add_argument("--num-workers", default=1, type=int)
+ parser.add_argument("--max-batch-size", default=1, type=int)
+ parser.add_argument("--batch-wait-ms", default=5, type=int)
+ parser.add_argument("--tp-size", default=1, type=int)
+ parser.add_argument("--tp-device-ids", default=None, type=str)
+ args = parser.parse_args()
+
+ try:
+ import uvicorn
+ except ImportError as exc: # pragma: no cover - optional dependency
+ raise RuntimeError(
+ "uvicorn is required to run the server; install it with `pip install ./python[server]`."
+ ) from exc
+
+ tp_device_ids = None
+ if args.tp_device_ids:
+ tp_device_ids = [int(part.strip()) for part in args.tp_device_ids.split(",") if part.strip()]
+
+ app = create_app(
+ model_path=args.model_path,
+ device=args.device,
+ device_id=args.device_id,
+ tp_size=args.tp_size,
+ tp_device_ids=tp_device_ids,
+ model_name=args.model_name,
+ num_workers=args.num_workers,
+ max_batch_size=args.max_batch_size,
+ batch_wait_ms=args.batch_wait_ms,
+ )
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py
index c5a0b467..af9998f0 100644
--- a/python/llaisys/libllaisys/llaisys_types.py
+++ b/python/llaisys/libllaisys/llaisys_types.py
@@ -2,7 +2,6 @@
from enum import IntEnum
-# Device Type enum
class DeviceType(IntEnum):
CPU = 0
NVIDIA = 1
@@ -12,7 +11,6 @@ class DeviceType(IntEnum):
llaisysDeviceType_t = ctypes.c_int
-# Data Type enum
class DataType(IntEnum):
INVALID = 0
BYTE = 1
@@ -39,7 +37,6 @@ class DataType(IntEnum):
llaisysDataType_t = ctypes.c_int
-# Memory Copy Kind enum
class MemcpyKind(IntEnum):
H2H = 0
H2D = 1
@@ -48,8 +45,13 @@ class MemcpyKind(IntEnum):
llaisysMemcpyKind_t = ctypes.c_int
+llaisysTensor_t = ctypes.c_void_p
+
+class LlaisysQwen2Model(ctypes.Structure):
+ pass
+llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model)
+llaisysQwen2Weights_p = ctypes.c_void_p
-# Stream type (opaque pointer)
llaisysStream_t = ctypes.c_void_p
__all__ = [
diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py
new file mode 100644
index 00000000..b8a44911
--- /dev/null
+++ b/python/llaisys/libllaisys/models.py
@@ -0,0 +1,42 @@
+import ctypes
+from .llaisys_types import llaisysDataType_t, llaisysTensor_t, llaisysDeviceType_t
+
+class LlaisysQwen2Meta(ctypes.Structure):
+ _fields_ = [
+ ("dtype", llaisysDataType_t),
+ ("nlayer", ctypes.c_size_t),
+ ("hs", ctypes.c_size_t),
+ ("nh", ctypes.c_size_t),
+ ("nkvh", ctypes.c_size_t),
+ ("dh", ctypes.c_size_t),
+ ("di", ctypes.c_size_t),
+ ("maxseq", ctypes.c_size_t),
+ ("voc", ctypes.c_size_t),
+ ("epsilon", ctypes.c_float),
+ ("theta", ctypes.c_float),
+ ("end_token", ctypes.c_int64),
+ ]
+
+class LlaisysQwen2Weights(ctypes.Structure):
+ _fields_ = [
+ ("in_embed", llaisysTensor_t),
+ ("out_embed", llaisysTensor_t),
+ ("out_norm_w", llaisysTensor_t),
+ ("attn_norm_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_q_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_q_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_k_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_k_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_v_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_v_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_o_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_norm_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_gate_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_up_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_down_w", ctypes.POINTER(llaisysTensor_t)),
+ ]
+
+class LlaisysQwen2Model(ctypes.Structure):
+ pass
+
+llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model)
diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py
index 5be095ef..c56f4749 100644
--- a/python/llaisys/libllaisys/ops.py
+++ b/python/llaisys/libllaisys/ops.py
@@ -1,5 +1,5 @@
from .tensor import llaisysTensor_t
-from ctypes import c_float
+from ctypes import c_float, c_int, c_int64
def load_ops(lib):
lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
@@ -14,6 +14,9 @@ def load_ops(lib):
lib.llaisysLinear.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
lib.llaisysLinear.restype = None
+ lib.llaisysSample.argtypes = [llaisysTensor_t, c_int, c_float, c_float]
+ lib.llaisysSample.restype = c_int64
+
lib.llaisysRearrange.argtypes = [llaisysTensor_t, llaisysTensor_t]
lib.llaisysRearrange.restype = None
diff --git a/python/llaisys/models/__init__.py b/python/llaisys/models/__init__.py
index af9918b0..5b2b096d 100644
--- a/python/llaisys/models/__init__.py
+++ b/python/llaisys/models/__init__.py
@@ -1 +1,59 @@
+import json
+from pathlib import Path
+from typing import Type
+
+from ..libllaisys import DeviceType
+from .llama import Llama
from .qwen2 import Qwen2
+from .tensor_parallel import TensorParallelQwen2
+
+MODEL_CLASSES: tuple[Type[Qwen2], ...] = (Qwen2, Llama)
+
+
+def detect_model_type(model_path: str) -> str:
+ config_path = Path(model_path) / "config.json"
+ with open(config_path, "r") as handle:
+ config = json.load(handle)
+ return str(config.get("model_type", "")).lower()
+
+
+def create_model(
+ model_path: str,
+ device: DeviceType = DeviceType.CPU,
+ device_id: int = 0,
+ *,
+ tp_size: int = 1,
+ tp_device_ids=None,
+):
+ model_type = detect_model_type(model_path)
+ if int(tp_size) > 1:
+ if device != DeviceType.NVIDIA:
+ raise ValueError("Tensor parallel inference currently supports only NVIDIA devices")
+ if TensorParallelQwen2.supports_model_type(model_type):
+ return TensorParallelQwen2(
+ model_path,
+ device,
+ device_id,
+ tp_size=int(tp_size),
+ tp_device_ids=tp_device_ids,
+ )
+ raise ValueError(f"Tensor parallel inference is not supported for model type: {model_type or 'unknown'}")
+ for cls in MODEL_CLASSES:
+ if cls.supports_model_type(model_type):
+ return cls(model_path, device, device_id)
+ raise ValueError(f"Unsupported model type: {model_type or 'unknown'}")
+
+
+def default_model_name(model_path: str) -> str:
+ model_type = detect_model_type(model_path)
+ return f"llaisys-{model_type or 'model'}"
+
+
+__all__ = [
+ "Qwen2",
+ "Llama",
+ "TensorParallelQwen2",
+ "create_model",
+ "detect_model_type",
+ "default_model_name",
+]
diff --git a/python/llaisys/models/llama.py b/python/llaisys/models/llama.py
new file mode 100644
index 00000000..888eaaf5
--- /dev/null
+++ b/python/llaisys/models/llama.py
@@ -0,0 +1,5 @@
+from .qwen2 import Qwen2
+
+
+class Llama(Qwen2):
+ SUPPORTED_MODEL_TYPES = {"llama", "mistral"}
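The `config.json`-based dispatch in `create_model`, combined with the per-class `SUPPORTED_MODEL_TYPES` sets shown above, can be exercised with toy stand-ins (the `Toy*` classes are illustrative, not the real model classes):

```python
import json
import tempfile
from pathlib import Path

class ToyQwen2:
    SUPPORTED_MODEL_TYPES = {"qwen2"}

class ToyLlama(ToyQwen2):
    SUPPORTED_MODEL_TYPES = {"llama", "mistral"}

def pick_model_class(model_path: str):
    # Mirror detect_model_type: read `model_type` from the checkpoint config.
    config = json.loads((Path(model_path) / "config.json").read_text())
    model_type = str(config.get("model_type", "")).lower()
    for cls in (ToyQwen2, ToyLlama):
        if model_type in cls.SUPPORTED_MODEL_TYPES:
            return cls
    raise ValueError(f"Unsupported model type: {model_type or 'unknown'}")

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").write_text(json.dumps({"model_type": "Mistral"}))
    picked_mistral = pick_model_class(tmp)
    (Path(tmp) / "config.json").write_text(json.dumps({"model_type": "qwen2"}))
    picked_qwen = pick_model_class(tmp)
```

Because `Llama` only overrides the supported-type set, adding a structurally compatible architecture is a two-line subclass plus registration in `MODEL_CLASSES`.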
diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py
index 0d07b0b2..562534ec 100644
--- a/python/llaisys/models/qwen2.py
+++ b/python/llaisys/models/qwen2.py
@@ -1,33 +1,387 @@
-from typing import Sequence
-from ..libllaisys import LIB_LLAISYS
-from ..libllaisys import DeviceType
-
+import ctypes
+import json
+import mmap
+import os
+import struct
from pathlib import Path
-import safetensors
+from typing import ClassVar, Dict, Iterator, List, Sequence, Tuple
+
+import numpy as np
+
+from ..libllaisys import LIB_LLAISYS, llaisysTensor_t, llaisysDataType_t, llaisysDeviceType_t, DataType, DeviceType
+from ..libllaisys.models import LlaisysQwen2Meta, LlaisysQwen2Weights, LlaisysQwen2Model, llaisysQwen2ModelHandle
+from ..tensor import Tensor
+
+LIB_LLAISYS.llaisysQwen2ModelCreate.argtypes = [
+    ctypes.POINTER(LlaisysQwen2Meta),
+    llaisysDeviceType_t,
+    ctypes.POINTER(ctypes.c_int),
+    ctypes.c_int,
+]
+LIB_LLAISYS.llaisysQwen2ModelCreate.restype = llaisysQwen2ModelHandle
+
+LIB_LLAISYS.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2ModelHandle]
+LIB_LLAISYS.llaisysQwen2ModelDestroy.restype = None
+
+LIB_LLAISYS.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2ModelHandle]
+LIB_LLAISYS.llaisysQwen2ModelWeights.restype = ctypes.POINTER(LlaisysQwen2Weights)
+
+LIB_LLAISYS.llaisysQwen2ModelReset.argtypes = [llaisysQwen2ModelHandle]
+LIB_LLAISYS.llaisysQwen2ModelReset.restype = None
+
+LIB_LLAISYS.llaisysQwen2ModelTruncate.argtypes = [llaisysQwen2ModelHandle, ctypes.c_size_t]
+LIB_LLAISYS.llaisysQwen2ModelTruncate.restype = None
+
+LIB_LLAISYS.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2ModelHandle, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t]
+LIB_LLAISYS.llaisysQwen2ModelInfer.restype = ctypes.c_int64
+
+LIB_LLAISYS.llaisysQwen2ModelGenerateNext.argtypes = [
+ llaisysQwen2ModelHandle,
+ ctypes.POINTER(ctypes.c_int64),
+ ctypes.c_size_t,
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.c_float,
+]
+LIB_LLAISYS.llaisysQwen2ModelGenerateNext.restype = ctypes.c_int64
+
+
class Qwen2:
+    SUPPORTED_MODEL_TYPES: ClassVar[set[str]] = {"qwen2"}
+
-    def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
-        # TODO: Implement model constructor
+    @classmethod
+    def supports_model_type(cls, model_type: str) -> bool:
+        return str(model_type).lower() in cls.SUPPORTED_MODEL_TYPES
+
+    def __init__(self, model_path: str, device: DeviceType = DeviceType.CPU, device_id: int = 0):
model_path = Path(model_path)
+ config_path = model_path / "config.json"
+
+ with open(config_path, "r") as f:
+ config = json.load(f)
+ self.config = config
+ self.model_type = str(config.get("model_type", "")).lower()
+
+ self.device = device
+ self.device_id = device_id
+
+ self.meta = LlaisysQwen2Meta()
+ self._configure_meta(config)
+
+ dev_ids = (ctypes.c_int * 1)(device_id)
+ self.handle = LIB_LLAISYS.llaisysQwen2ModelCreate(ctypes.byref(self.meta), device, dev_ids, 1)
+ self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.handle)
+ self.tensors_ref = []
for file in sorted(model_path.glob("*.safetensors")):
- data_ = safetensors.safe_open(file, framework="numpy", device="cpu")
- for name_ in data_.keys():
- ## TODO: load the model weights
- pass
+ print(f"Loading weights from {file}...")
+ weights_data = self._load_safetensors(file)
+ for key, (arr, shape, source_dtype) in weights_data.items():
+ raw = self._convert_array_for_model_dtype(arr, source_dtype)
+ if not raw.flags["C_CONTIGUOUS"]:
+ raw = np.ascontiguousarray(raw)
+
+ t = Tensor(list(shape), self.meta.dtype, device, device_id)
+ t.load(ctypes.c_void_p(raw.ctypes.data))
+ self._assign_weight(key, t)
+
+ self._finalize_weights()
+
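The weight files are plain safetensors; the on-disk layout the loader parses (an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets, then raw tensor bytes) can be reproduced and round-tripped entirely in memory:

```python
import json
import struct

import numpy as np

# Build a minimal in-memory .safetensors payload.
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
header = {"w": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, arr.nbytes]}}
header_bytes = json.dumps(header).encode("utf-8")
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + arr.tobytes()

# Parse it back the same way the loader does: header length, JSON header,
# then slice the data section using the per-tensor offsets.
(size,) = struct.unpack("<Q", blob[:8])
parsed = json.loads(blob[8 : 8 + size])
begin, end = parsed["w"]["data_offsets"]
data = np.frombuffer(blob[8 + size + begin : 8 + size + end], dtype=np.float32)
data = data.reshape(parsed["w"]["shape"])
```

Note that `data_offsets` are relative to the start of the data section, not the file, which is why the parse adds `8 + size` back in.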
+ def _configure_meta(self, config: Dict[str, object]) -> None:
+ dtype_str = str(config.get("torch_dtype", "float32"))
+ self.meta.dtype = self._runtime_dtype(dtype_str, self.device)
+ self.meta.nlayer = int(config.get("num_hidden_layers", 24))
+ self.meta.hs = int(config.get("hidden_size", 2048))
+ self.meta.nh = int(config.get("num_attention_heads", 16))
+ self.meta.nkvh = int(config.get("num_key_value_heads", self.meta.nh))
+ self.meta.dh = self.meta.hs // self.meta.nh
+ self.meta.di = int(config.get("intermediate_size", 11008))
+ self.meta.maxseq = int(config.get("max_position_embeddings", 8192))
+ self.meta.voc = int(config.get("vocab_size", 151936))
+ self.meta.epsilon = float(config.get("rms_norm_eps", 1e-6))
+ self.meta.theta = float(config.get("rope_theta", 1000000.0))
+ eos_token = config.get("eos_token_id", 151643)
+ if isinstance(eos_token, list):
+ eos_token = eos_token[0] if eos_token else 151643
+ self.meta.end_token = int(eos_token)
+
+ @staticmethod
+ def _llaisys_dtype_from_string(dtype_str: str) -> DataType:
+ normalized = str(dtype_str).lower()
+ if normalized in {"bfloat16", "bf16"}:
+ return DataType.BF16
+ if normalized in {"float16", "f16", "half"}:
+ return DataType.F16
+ return DataType.F32
+
+ @staticmethod
+ def _runtime_dtype(dtype_str: str, device: DeviceType) -> DataType:
+ # CPU inference is currently optimized only for F32 kernels. Keep that fast path
+ # as the default, and allow native 16-bit weights only as an explicit opt-in.
+ if device == DeviceType.CPU and os.getenv("LLAISYS_CPU_NATIVE_DTYPE", "").lower() not in {"1", "true", "yes", "on"}:
+ return DataType.F32
+ return Qwen2._llaisys_dtype_from_string(dtype_str)
+
+ def _load_safetensors(self, path: Path) -> Dict[str, Tuple[np.ndarray, Sequence[int], str]]:
+ tensors = {}
+ with open(path, "rb") as f:
+ length_bytes = f.read(8)
+ if not length_bytes:
+ return {}
+            header_size = struct.unpack("<Q", length_bytes)[0]
+            header = json.loads(f.read(header_size))
+            data_start = 8 + header_size
+            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+                for name, info in header.items():
+                    if name == "__metadata__":
+                        continue
+                    dtype_str = str(info["dtype"]).lower()
+                    shape = info["shape"]
+                    begin, end = info["data_offsets"]
+                    buf = mm[data_start + begin : data_start + end]
+                    if dtype_str in {"bf16", "bfloat16"}:
+                        arr = np.frombuffer(buf, dtype=np.uint16).reshape(shape)
+                        source_dtype = "bf16"
+                    elif dtype_str in {"f16", "float16"}:
+                        arr = np.frombuffer(buf, dtype=np.float16).reshape(shape)
+                        source_dtype = "f16"
+                    else:
+                        arr = np.frombuffer(buf, dtype=np.float32).reshape(shape)
+                        source_dtype = "f32"
+                    tensors[name] = (np.array(arr), tuple(shape), source_dtype)
+        return tensors
+
+    def _convert_array_for_model_dtype(self, arr: np.ndarray, source_dtype: str) -> np.ndarray:
+ target_dtype = self.meta.dtype
+ if target_dtype == DataType.F32:
+ return self._to_float32_array(arr, source_dtype)
+ if target_dtype == DataType.F16:
+ return self._to_float16_array(arr, source_dtype)
+ if target_dtype == DataType.BF16:
+ return self._to_bfloat16_bytes(arr, source_dtype)
+ raise ValueError(f"Unsupported model dtype: {target_dtype}")
+
+ @staticmethod
+ def _to_float32_array(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+ if source_dtype == "bf16":
+ u32 = arr.astype(np.uint32) << 16
+ return u32.view(np.float32)
+ if source_dtype == "f16":
+ return arr.astype(np.float32)
+ return np.asarray(arr, dtype=np.float32)
+
+ @staticmethod
+ def _to_float16_array(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+ if source_dtype == "f16":
+ return np.asarray(arr, dtype=np.float16)
+ return Qwen2._to_float32_array(arr, source_dtype).astype(np.float16)
+
+ @staticmethod
+ def _to_bfloat16_bytes(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+ if source_dtype == "bf16":
+ return np.asarray(arr, dtype=np.uint16)
+ float32_arr = Qwen2._to_float32_array(arr, source_dtype)
+ u32 = float32_arr.view(np.uint32)
+ return (u32 >> 16).astype(np.uint16)
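The bf16 bit trick used by the converters above relies on bfloat16 being exactly the top 16 bits of an IEEE float32, so widening is a left shift and narrowing is a truncating right shift (no rounding, matching `_to_bfloat16_bytes`):

```python
import numpy as np

def bf16_bits_to_f32(u16):
    # Widen: place the 16 stored bits in the high half of a float32.
    return (u16.astype(np.uint32) << 16).view(np.float32)

def f32_to_bf16_bits(f32):
    # Narrow: keep only the high 16 bits (truncation, not round-to-nearest).
    return (f32.view(np.uint32) >> 16).astype(np.uint16)

# These values are exactly representable in bf16, so the round trip is lossless.
x = np.array([1.0, -2.5, 0.15625], dtype=np.float32)
bits = f32_to_bf16_bits(x)
roundtrip = bf16_bits_to_f32(bits)
```

Values that need more than 8 mantissa bits would lose precision under the truncating narrow, which is acceptable here because the weights were stored in bf16 to begin with.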
+
+ def _finalize_weights(self) -> None:
+ weights = self.weights_ptr.contents
+ if weights.out_embed:
+ return
+ if self.config.get("tie_word_embeddings") and weights.in_embed:
+ weights.out_embed = weights.in_embed
+ return
+ raise ValueError("Missing output embedding weights: expected `lm_head.weight` or tied word embeddings")
+
+ def _assign_weight(self, name: str, t: Tensor):
+ w = self.weights_ptr.contents
+ # Keep Python Tensor wrappers alive because the model stores raw C tensor handles.
+ self.tensors_ref.append(t)
+ if name == "model.embed_tokens.weight":
+ w.in_embed = t.lib_tensor()
+ elif name == "lm_head.weight":
+ w.out_embed = t.lib_tensor()
+ elif name == "model.norm.weight":
+ w.out_norm_w = t.lib_tensor()
+ elif name.startswith("model.layers."):
+ parts = name.split(".")
+ layer_idx = int(parts[2])
+ suffix = ".".join(parts[3:])
+
+ def set_w(target_ptr):
+ target_ptr[layer_idx] = t.lib_tensor()
+
+ if suffix == "input_layernorm.weight":
+ set_w(w.attn_norm_w)
+ elif suffix == "self_attn.q_proj.weight":
+ set_w(w.attn_q_w)
+ elif suffix == "self_attn.q_proj.bias":
+ set_w(w.attn_q_b)
+ elif suffix == "self_attn.k_proj.weight":
+ set_w(w.attn_k_w)
+ elif suffix == "self_attn.k_proj.bias":
+ set_w(w.attn_k_b)
+ elif suffix == "self_attn.v_proj.weight":
+ set_w(w.attn_v_w)
+ elif suffix == "self_attn.v_proj.bias":
+ set_w(w.attn_v_b)
+ elif suffix == "self_attn.o_proj.weight":
+ set_w(w.attn_o_w)
+ elif suffix == "post_attention_layernorm.weight":
+ set_w(w.mlp_norm_w)
+ elif suffix == "mlp.gate_proj.weight":
+ set_w(w.mlp_gate_w)
+ elif suffix == "mlp.up_proj.weight":
+ set_w(w.mlp_up_w)
+ elif suffix == "mlp.down_proj.weight":
+ set_w(w.mlp_down_w)
+
+    def __del__(self):
+        if hasattr(self, "handle") and self.handle:
+            LIB_LLAISYS.llaisysQwen2ModelDestroy(self.handle)
+            self.handle = None
+
+ def reset(self) -> None:
+ LIB_LLAISYS.llaisysQwen2ModelReset(self.handle)
+
+ def truncate(self, position: int) -> None:
+ if position < 0:
+ raise ValueError("position must be non-negative")
+ LIB_LLAISYS.llaisysQwen2ModelTruncate(self.handle, int(position))
+
+ def _generate_next(
+ self,
+ inputs: Sequence[int],
+ top_k: int = 1,
+ top_p: float = 1.0,
+ temperature: float = 1.0,
+ ) -> int:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ arr = (ctypes.c_int64 * len(inputs))(*inputs)
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelGenerateNext(
+ self.handle,
+ arr,
+ len(inputs),
+ int(top_k),
+ float(top_p),
+ float(temperature),
+ )
+ )
+
+ def generate_next(
+ self,
+ inputs: Sequence[int],
+ top_k: int = 1,
+ top_p: float = 1.0,
+ temperature: float = 1.0,
+ *,
+ reset_state: bool = False,
+ ) -> int:
+ if reset_state:
+ self.reset()
+ return self._generate_next(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
def generate(
self,
inputs: Sequence[int],
- max_new_tokens: int = None,
+ max_new_tokens: int = 20,
top_k: int = 1,
top_p: float = 0.8,
temperature: float = 0.8,
- ):
+ *,
+ reset_state: bool = True,
+ ) -> List[int]:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ if reset_state:
+ self.reset()
+        generated: List[int] = []
+        tokens = list(inputs)
+
+        for _ in range(max_new_tokens):
+            next_token = self.generate_next(
+                tokens,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+            )
+            generated.append(next_token)
+            tokens = [next_token]
+
+            # Stop as soon as EOS is produced, including on the very first step.
+            if next_token == self.meta.end_token:
+                break
+
+        return list(inputs) + generated
+
+ def stream_generate(
+ self,
+ inputs: Sequence[int],
+ *,
+ tokenizer=None,
+ max_new_tokens: int = 20,
+ top_k: int = 1,
+ top_p: float = 0.8,
+ temperature: float = 0.8,
+ reset_state: bool = True,
+ ) -> Iterator[Tuple[int, str]]:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+
+ if reset_state:
+ self.reset()
+ prompt_tokens = list(inputs)
+ generated: List[int] = []
+ last_text = ""
+
+ for step in range(max_new_tokens):
+ token_source = prompt_tokens if step == 0 else [generated[-1]]
+ next_token = self.generate_next(
+ token_source,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ )
+ generated.append(next_token)
+
+ text_chunk = ""
+ if tokenizer is not None:
+ decoded = tokenizer.decode(
+ generated,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+ text_chunk = decoded[len(last_text):] if decoded.startswith(last_text) else decoded
+ last_text = decoded
- # TODO: Implement generate function
+ yield next_token, text_chunk
- return []
+ if next_token == self.meta.end_token:
+ break
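The prefix-diff streaming in `stream_generate` above (re-decode the full generated sequence each step and emit only the unseen suffix) can be sketched on its own, with plain strings standing in for tokenizer output:

```python
def stream_chunks(decoded_steps):
    # Emit only the suffix that was not yielded yet; fall back to the full
    # string if the decoder rewrote the prefix (e.g. merged a multi-byte char).
    last = ""
    for decoded in decoded_steps:
        chunk = decoded[len(last):] if decoded.startswith(last) else decoded
        last = decoded
        yield chunk

chunks = list(stream_chunks(["He", "Hel", "Hello", "Hello!"]))  # → ["He", "l", "lo", "!"]
```

Decoding the whole sequence each step costs O(n²) in total but avoids emitting broken partial characters, which per-token decoding can produce.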
diff --git a/python/llaisys/models/tensor_parallel.py b/python/llaisys/models/tensor_parallel.py
new file mode 100644
index 00000000..07d43f1f
--- /dev/null
+++ b/python/llaisys/models/tensor_parallel.py
@@ -0,0 +1,803 @@
+import ctypes
+import json
+import mmap
+import os
+import queue
+import socket
+import struct
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from ..libllaisys import DataType, DeviceType, MemcpyKind
+from ..ops import Ops
+from ..runtime import RuntimeAPI
+from ..tensor import Tensor
+from .qwen2 import Qwen2
+
+
+def _tp_log(rank: int, message: str) -> None:
+ if os.getenv("LLAISYS_TP_DEBUG", "").lower() not in {"1", "true", "yes", "on"}:
+ return
+ with open(f"/tmp/llaisys_tp_rank{rank}.log", "a", encoding="utf-8") as handle:
+ handle.write(f"{message}\n")
+
+
+@dataclass
+class _TensorParallelMeta:
+ dtype: DataType
+ nlayer: int
+ hs: int
+ nh: int
+ nkvh: int
+ dh: int
+ di: int
+ maxseq: int
+ voc: int
+ epsilon: float
+ theta: float
+ end_token: int
+
+
+def _divisors(value: int) -> List[int]:
+ divisors: List[int] = []
+ for candidate in range(1, int(math.isqrt(value)) + 1):
+ if value % candidate != 0:
+ continue
+ divisors.append(candidate)
+ paired = value // candidate
+ if paired != candidate:
+ divisors.append(paired)
+ return sorted(divisors)
+
+
+def _valid_tp_sizes(meta: _TensorParallelMeta) -> List[int]:
+ common = math.gcd(math.gcd(int(meta.nh), int(meta.nkvh)), int(meta.di))
+ return _divisors(common)
+
+
+def _validate_tp_size(meta: _TensorParallelMeta, tp_size: int) -> None:
+ tp_size = int(tp_size)
+ valid_sizes = _valid_tp_sizes(meta)
+ if tp_size in valid_sizes:
+ return
+ raise ValueError(
+ "Unsupported tp_size for current tensor-parallel implementation: "
+ f"tp_size={tp_size}, num_attention_heads={meta.nh}, "
+ f"num_key_value_heads={meta.nkvh}, intermediate_size={meta.di}, "
+ f"valid_tp_sizes={valid_sizes}. "
+ "This implementation shards attention heads, KV heads, and MLP intermediate dims evenly across ranks."
+ )
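The constraint enforced above follows from the sharding scheme: every rank must receive an equal, nonzero share of attention heads, KV heads, and MLP intermediate dims, so the valid world sizes are exactly the divisors of `gcd(nh, nkvh, di)`. A self-contained check (the head counts below are illustrative, Qwen2-1.5B-like, not read from a real config):

```python
import math

def divisors(value):
    # Collect divisors in O(sqrt(n)) by pairing each small divisor with value // c.
    out = []
    for c in range(1, math.isqrt(value) + 1):
        if value % c == 0:
            out.append(c)
            if value // c != c:
                out.append(value // c)
    return sorted(out)

nh, nkvh, di = 12, 2, 8960
valid = divisors(math.gcd(math.gcd(nh, nkvh), di))  # → [1, 2]
```

With only 2 KV heads, the GCD collapses to 2, so this sharding strategy caps `tp_size` at 2 for such a model regardless of how many GPUs are available.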
+
+
+def _find_free_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.bind(("127.0.0.1", 0))
+ return int(sock.getsockname()[1])
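`_find_free_port` relies on a standard socket trick: binding to port 0 asks the OS to pick an unused ephemeral port, which can be read back before the socket closes:

```python
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.bind(("127.0.0.1", 0))       # port 0 = "any free port"
    port = sock.getsockname()[1]      # the port the OS actually assigned
```

There is an inherent race (the port could be reclaimed between closing the probe socket and the NCCL rendezvous binding it), which is acceptable for a single-host launcher but worth noting.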
+
+
+def _llaisys_dtype_to_torch(dtype: DataType) -> torch.dtype:
+ if dtype == DataType.F32:
+ return torch.float32
+ if dtype == DataType.F16:
+ return torch.float16
+ if dtype == DataType.BF16:
+ return torch.bfloat16
+ raise ValueError(f"Unsupported runtime dtype: {dtype}")
+
+
+def _resolve_tp_device_ids(device_id: int, tp_size: int, tp_device_ids: Optional[Sequence[int]]) -> List[int]:
+ if tp_device_ids is not None:
+ device_ids = [int(item) for item in tp_device_ids]
+ else:
+ device_ids = list(range(int(device_id), int(device_id) + int(tp_size)))
+ if len(device_ids) != int(tp_size):
+ raise ValueError("tp_device_ids length must match tp_size")
+ return device_ids
+
+
+def _normalize_command_payload(command: str, seq: int, **payload) -> Dict[str, object]:
+ message = {"command": command, "seq": int(seq)}
+ message.update(payload)
+ return message
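The control messages the driver sends to workers all share the shape built by the helper above: a command name plus a sequence number used to match replies coming back on the result queue. Mirrored standalone:

```python
def normalize_command_payload(command, seq, **payload):
    # Every message carries "command" and a monotonically increasing "seq";
    # extra keyword arguments become command-specific fields.
    message = {"command": command, "seq": int(seq)}
    message.update(payload)
    return message

msg = normalize_command_payload("generate_next", 7, inputs=[1, 2, 3], top_k=1)
```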
+
+
+class _ShardedQwen2Rank:
+ def __init__(self, model_path: str, rank: int, world_size: int, device_id: int):
+ self.model_path = Path(model_path)
+ self.rank = int(rank)
+ self.world_size = int(world_size)
+ self.device_id = int(device_id)
+ self.device = DeviceType.NVIDIA
+ self.runtime = RuntimeAPI(DeviceType.NVIDIA)
+ self.runtime.set_device(self.device_id)
+ torch.cuda.set_device(self.device_id)
+
+ with open(self.model_path / "config.json", "r") as handle:
+ config = json.load(handle)
+ self.config = config
+ if str(config.get("model_type", "")).lower() != "qwen2":
+ raise ValueError("TensorParallelQwen2 currently supports only Qwen2 models")
+
+ dtype = Qwen2._runtime_dtype(str(config.get("torch_dtype", "float32")), DeviceType.NVIDIA)
+ eos_token = config.get("eos_token_id", 151643)
+ if isinstance(eos_token, list):
+ eos_token = eos_token[0] if eos_token else 151643
+ self.meta = _TensorParallelMeta(
+ dtype=dtype,
+ nlayer=int(config.get("num_hidden_layers", 24)),
+ hs=int(config.get("hidden_size", 2048)),
+ nh=int(config.get("num_attention_heads", 16)),
+ nkvh=int(config.get("num_key_value_heads", int(config.get("num_attention_heads", 16)))),
+ dh=int(config.get("hidden_size", 2048)) // int(config.get("num_attention_heads", 16)),
+ di=int(config.get("intermediate_size", 11008)),
+ maxseq=int(config.get("max_position_embeddings", 8192)),
+ voc=int(config.get("vocab_size", 151936)),
+ epsilon=float(config.get("rms_norm_eps", 1e-6)),
+ theta=float(config.get("rope_theta", 1000000.0)),
+ end_token=int(eos_token),
+ )
+
+ _validate_tp_size(self.meta, self.world_size)
+
+ self.local_nh = self.meta.nh // self.world_size
+ self.local_nkvh = self.meta.nkvh // self.world_size
+ self.local_q_dim = self.local_nh * self.meta.dh
+ self.local_kv_dim = self.local_nkvh * self.meta.dh
+ self.local_di = self.meta.di // self.world_size
+ self.torch_dtype = _llaisys_dtype_to_torch(self.meta.dtype)
+ self.cuda_device = torch.device(f"cuda:{self.device_id}")
+
+ self.in_embed: Optional[Tensor] = None
+ self.out_embed: Optional[Tensor] = None
+ self.out_norm_w: Optional[Tensor] = None
+ self.layers: List[Dict[str, Tensor]] = [dict() for _ in range(self.meta.nlayer)]
+ self._tensors: List[Tensor] = []
+ self._load_weights()
+
+ if self.out_embed is None and self.config.get("tie_word_embeddings") and self.in_embed is not None:
+ self.out_embed = self.in_embed
+ if self.in_embed is None or self.out_embed is None or self.out_norm_w is None:
+ raise ValueError("Missing required embedding/final norm weights for tensor-parallel Qwen2")
+
+ self.k_caches = [
+ Tensor((self.meta.maxseq, self.local_nkvh, self.meta.dh), self.meta.dtype, self.device, self.device_id)
+ for _ in range(self.meta.nlayer)
+ ]
+ self.v_caches = [
+ Tensor((self.meta.maxseq, self.local_nkvh, self.meta.dh), self.meta.dtype, self.device, self.device_id)
+ for _ in range(self.meta.nlayer)
+ ]
+ self._cur_pos = 0
+
+ def reset(self) -> None:
+ self._cur_pos = 0
+
+ def truncate(self, position: int) -> None:
+ position = int(position)
+        if position < 0 or position > self._cur_pos:
+            raise ValueError("truncate position must be within [0, current KV-cache length]")
+ self._cur_pos = position
+
+ def generate_next(self, inputs: Sequence[int], top_k: int, top_p: float, temperature: float) -> int:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+ if self._cur_pos + len(inputs) > self.meta.maxseq:
+ raise ValueError("sequence exceeds KV-cache capacity")
+
+ _tp_log(self.rank, f"generate_next start cur_pos={self._cur_pos} ntoken={len(inputs)}")
+ hidden = self._embedding(list(inputs))
+ pos_ids = self._make_int64_tensor(np.arange(self._cur_pos, self._cur_pos + len(inputs), dtype=np.int64))
+ scale = 1.0 / float(self.meta.dh) ** 0.5
+
+ for layer_idx, layer in enumerate(self.layers):
+ _tp_log(self.rank, f"layer {layer_idx} start")
+ normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id)
+ Ops.rms_norm(normed, hidden, layer["attn_norm_w"], self.meta.epsilon)
+ self._debug_sync(f"layer {layer_idx} attn_norm")
+
+ q = Tensor((len(inputs), self.local_q_dim), self.meta.dtype, self.device, self.device_id)
+ k = Tensor((len(inputs), self.local_kv_dim), self.meta.dtype, self.device, self.device_id)
+ v = Tensor((len(inputs), self.local_kv_dim), self.meta.dtype, self.device, self.device_id)
+ Ops.linear(q, normed, layer["attn_q_w"], layer["attn_q_b"])
+ self._debug_sync(f"layer {layer_idx} attn_q")
+ Ops.linear(k, normed, layer["attn_k_w"], layer["attn_k_b"])
+ self._debug_sync(f"layer {layer_idx} attn_k")
+ Ops.linear(v, normed, layer["attn_v_w"], layer["attn_v_b"])
+ self._debug_sync(f"layer {layer_idx} attn_v")
+
+ q_heads = q.view(len(inputs), self.local_nh, self.meta.dh)
+ k_heads = k.view(len(inputs), self.local_nkvh, self.meta.dh)
+ v_heads = v.view(len(inputs), self.local_nkvh, self.meta.dh)
+ Ops.rope(q_heads, q_heads, pos_ids, self.meta.theta)
+ self._debug_sync(f"layer {layer_idx} rope_q")
+ Ops.rope(k_heads, k_heads, pos_ids, self.meta.theta)
+ self._debug_sync(f"layer {layer_idx} rope_k")
+
+ k_slot = self.k_caches[layer_idx].slice(0, self._cur_pos, self._cur_pos + len(inputs))
+ v_slot = self.v_caches[layer_idx].slice(0, self._cur_pos, self._cur_pos + len(inputs))
+ Ops.rearrange(k_slot, k_heads)
+ self._debug_sync(f"layer {layer_idx} cache_k")
+ Ops.rearrange(v_slot, v_heads)
+ self._debug_sync(f"layer {layer_idx} cache_v")
+
+ k_full = self.k_caches[layer_idx].slice(0, 0, self._cur_pos + len(inputs))
+ v_full = self.v_caches[layer_idx].slice(0, 0, self._cur_pos + len(inputs))
+ attn_local = Tensor((len(inputs), self.local_nh, self.meta.dh), self.meta.dtype, self.device, self.device_id)
+ Ops.self_attention(attn_local, q_heads, k_full, v_full, scale)
+ self._debug_sync(f"layer {layer_idx} self_attention")
+
+ attn_flat = attn_local.view(len(inputs), self.local_q_dim)
+ attn_partial = Tensor((len(inputs), self.meta.hs), self.meta.dtype, self.device, self.device_id)
+ Ops.linear(attn_partial, attn_flat, layer["attn_o_w"], None)
+ self._debug_sync(f"layer {layer_idx} attn_o")
+ self._all_reduce_tensor(attn_partial)
+ self._debug_sync(f"layer {layer_idx} attn_all_reduce")
+ Ops.add(hidden, hidden, attn_partial)
+ self._debug_sync(f"layer {layer_idx} attn_residual")
+
+ mlp_normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id)
+ Ops.rms_norm(mlp_normed, hidden, layer["mlp_norm_w"], self.meta.epsilon)
+ self._debug_sync(f"layer {layer_idx} mlp_norm")
+ gate = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id)
+ up = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id)
+ Ops.linear(gate, mlp_normed, layer["mlp_gate_w"], None)
+ self._debug_sync(f"layer {layer_idx} mlp_gate")
+ Ops.linear(up, mlp_normed, layer["mlp_up_w"], None)
+ self._debug_sync(f"layer {layer_idx} mlp_up")
+ swiglu_out = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id)
+ Ops.swiglu(swiglu_out, gate, up)
+ self._debug_sync(f"layer {layer_idx} swiglu")
+ mlp_partial = Tensor((len(inputs), self.meta.hs), self.meta.dtype, self.device, self.device_id)
+ Ops.linear(mlp_partial, swiglu_out, layer["mlp_down_w"], None)
+ self._debug_sync(f"layer {layer_idx} mlp_down")
+ self._all_reduce_tensor(mlp_partial)
+ self._debug_sync(f"layer {layer_idx} mlp_all_reduce")
+ Ops.add(hidden, hidden, mlp_partial)
+ self._debug_sync(f"layer {layer_idx} mlp_residual")
+ _tp_log(self.rank, f"layer {layer_idx} done")
+
+ self._cur_pos += len(inputs)
+ _tp_log(self.rank, f"generate_next layers done cur_pos={self._cur_pos}")
+
+ next_token = -1
+ if self.rank == 0:
+ normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id)
+ Ops.rms_norm(normed, hidden, self.out_norm_w, self.meta.epsilon)
+ last_hidden = normed.slice(0, len(inputs) - 1, len(inputs)).view(1, self.meta.hs)
+ logits = Tensor((1, self.meta.voc), self.meta.dtype, self.device, self.device_id)
+ Ops.linear(logits, last_hidden, self.out_embed, None)
+ next_token = Ops.sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+
+ return self._broadcast_token(next_token)
+
+ def _debug_sync(self, label: str) -> None:
+ if os.getenv("LLAISYS_TP_DEBUG", "").lower() not in {"1", "true", "yes", "on"}:
+ return
+ self.runtime.set_device(self.device_id)
+ self.runtime.device_synchronize()
+ _tp_log(self.rank, f"{label} ok")
+
+ def _all_reduce_tensor(self, tensor: Tensor) -> None:
+ buffer = self._to_torch(tensor)
+ dist.all_reduce(buffer)
+ self._from_torch(buffer, tensor)
+
+ def _broadcast_token(self, token: int) -> int:
+ value = torch.tensor([int(token) if self.rank == 0 else 0], dtype=torch.int64, device=self.cuda_device)
+ dist.broadcast(value, src=0)
+ return int(value.item())
+
+ def _embedding(self, token_ids: List[int]) -> Tensor:
+ tokens = self._make_int64_tensor(np.asarray(token_ids, dtype=np.int64))
+ hidden = Tensor((len(token_ids), self.meta.hs), self.meta.dtype, self.device, self.device_id)
+ Ops.embedding(hidden, tokens, self.in_embed)
+ return hidden
+
+ def _to_torch(self, tensor: Tensor) -> torch.Tensor:
+ torch_tensor = torch.empty(tensor.shape(), dtype=self.torch_dtype, device=self.cuda_device)
+ self.runtime.set_device(self.device_id)
+ self.runtime.memcpy_sync(
+ torch_tensor.data_ptr(),
+ tensor.data_ptr(),
+ torch_tensor.numel() * torch_tensor.element_size(),
+ MemcpyKind.D2D,
+ )
+ return torch_tensor
+
+ def _from_torch(self, torch_tensor: torch.Tensor, tensor: Tensor) -> None:
+ self.runtime.set_device(self.device_id)
+ self.runtime.memcpy_sync(
+ tensor.data_ptr(),
+ torch_tensor.data_ptr(),
+ torch_tensor.numel() * torch_tensor.element_size(),
+ MemcpyKind.D2D,
+ )
+
+ def _make_int64_tensor(self, values: np.ndarray) -> Tensor:
+ tensor = Tensor(values.shape, DataType.I64, self.device, self.device_id)
+ tensor.load(ctypes.c_void_p(values.ctypes.data))
+ return tensor
+
+ def _tensor_from_array(self, array: np.ndarray, *, dtype: Optional[DataType] = None) -> Tensor:
+ target_dtype = self.meta.dtype if dtype is None else dtype
+ tensor = Tensor(array.shape, target_dtype, self.device, self.device_id)
+ tensor.load(ctypes.c_void_p(array.ctypes.data))
+ self._tensors.append(tensor)
+ return tensor
+
+ def _load_weights(self) -> None:
+ for file in sorted(self.model_path.glob("*.safetensors")):
+ for key, array, source_dtype in self._iter_safetensors(file):
+ self._assign_weight(key, array, source_dtype)
+
+ def _assign_weight(self, name: str, array: np.ndarray, source_dtype: str) -> None:
+ if name == "model.embed_tokens.weight":
+ converted = self._convert_array(array, source_dtype)
+ self.in_embed = self._tensor_from_array(np.ascontiguousarray(converted))
+ return
+ if name == "lm_head.weight":
+ converted = self._convert_array(array, source_dtype)
+ self.out_embed = self._tensor_from_array(np.ascontiguousarray(converted))
+ return
+ if name == "model.norm.weight":
+ converted = self._convert_array(array, source_dtype)
+ self.out_norm_w = self._tensor_from_array(np.ascontiguousarray(converted))
+ return
+ if not name.startswith("model.layers."):
+ return
+
+ parts = name.split(".")
+ layer_idx = int(parts[2])
+ suffix = ".".join(parts[3:])
+ layer = self.layers[layer_idx]
+
+ if suffix == "input_layernorm.weight":
+ layer["attn_norm_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(array, source_dtype)))
+ return
+ if suffix == "post_attention_layernorm.weight":
+ layer["mlp_norm_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(array, source_dtype)))
+ return
+ if suffix == "self_attn.q_proj.weight":
+ shard = self._shard_rows(array)
+ layer["attn_q_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.q_proj.bias":
+ shard = self._shard_rows_bias(array)
+ layer["attn_q_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.k_proj.weight":
+ shard = self._shard_rows(array, kv=True)
+ layer["attn_k_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.k_proj.bias":
+ shard = self._shard_rows_bias(array, kv=True)
+ layer["attn_k_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.v_proj.weight":
+ shard = self._shard_rows(array, kv=True)
+ layer["attn_v_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.v_proj.bias":
+ shard = self._shard_rows_bias(array, kv=True)
+ layer["attn_v_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "self_attn.o_proj.weight":
+ shard = self._shard_cols(array, self.local_q_dim)
+ layer["attn_o_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "mlp.gate_proj.weight":
+ shard = self._shard_rows(array, mlp=True)
+ layer["mlp_gate_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "mlp.up_proj.weight":
+ shard = self._shard_rows(array, mlp=True)
+ layer["mlp_up_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+ if suffix == "mlp.down_proj.weight":
+ shard = self._shard_cols(array, self.local_di)
+ layer["mlp_down_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype)))
+ return
+
+ def _shard_rows(self, array: np.ndarray, *, kv: bool = False, mlp: bool = False) -> np.ndarray:
+ if kv:
+ chunk = self.local_kv_dim
+ elif mlp:
+ chunk = self.local_di
+ else:
+ chunk = self.local_q_dim
+ start = self.rank * chunk
+ end = start + chunk
+ return np.asarray(array[start:end, :])
+
+ def _shard_rows_bias(self, array: np.ndarray, *, kv: bool = False) -> np.ndarray:
+ chunk = self.local_kv_dim if kv else self.local_q_dim
+ start = self.rank * chunk
+ end = start + chunk
+ return np.asarray(array[start:end])
+
+ def _shard_cols(self, array: np.ndarray, chunk: int) -> np.ndarray:
+ start = self.rank * chunk
+ end = start + chunk
+ return np.asarray(array[:, start:end])
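The correctness of the sharding helpers above can be verified with a toy NumPy model: a row-parallel first matmul followed by a column-parallel second matmul, with the per-rank partial outputs summed (the role `dist.all_reduce` plays across real ranks):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4))
w_up = rng.standard_normal((6, 4))    # rows sharded across ranks (_shard_rows)
w_down = rng.standard_normal((4, 6))  # columns sharded across ranks (_shard_cols)

full = x @ w_up.T @ w_down.T          # the unsharded reference result

world_size, chunk = 2, 3
partials = []
for rank in range(world_size):
    rows = slice(rank * chunk, (rank + 1) * chunk)
    h = x @ w_up[rows, :].T            # each rank computes its slice of the hidden dim
    partials.append(h @ w_down[:, rows].T)
out = sum(partials)                    # what the all-reduce produces
```

The sum of the rank-local products equals the full product because splitting `w_up` by rows and `w_down` by matching columns is exactly a block decomposition of the matrix product, which is why only one all-reduce per projection pair is needed.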
+
+ def _iter_safetensors(self, path: Path):
+ with open(path, "rb") as handle:
+ length_bytes = handle.read(8)
+ if not length_bytes:
+ return
+                header_size = struct.unpack("<Q", length_bytes)[0]
+                header = json.loads(handle.read(header_size))
+                data_start = 8 + header_size
+                with mmap.mmap(handle.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+                    for name, info in header.items():
+                        if name == "__metadata__":
+                            continue
+                        dtype_str = str(info["dtype"]).lower()
+                        shape = info["shape"]
+                        begin, end = info["data_offsets"]
+                        buf = mm[data_start + begin : data_start + end]
+                        if dtype_str in {"bf16", "bfloat16"}:
+                            arr = np.frombuffer(buf, dtype=np.uint16).reshape(shape)
+                            source_dtype = "bf16"
+                        elif dtype_str in {"f16", "float16"}:
+                            arr = np.frombuffer(buf, dtype=np.float16).reshape(shape)
+                            source_dtype = "f16"
+                        else:
+                            arr = np.frombuffer(buf, dtype=np.float32).reshape(shape)
+                            source_dtype = "f32"
+                        yield name, np.array(arr), source_dtype
+
+    def _convert_array(self, arr: np.ndarray, source_dtype: str) -> np.ndarray:
+ if self.meta.dtype == DataType.F32:
+ return Qwen2._to_float32_array(arr, source_dtype)
+ if self.meta.dtype == DataType.F16:
+ return Qwen2._to_float16_array(arr, source_dtype)
+ if self.meta.dtype == DataType.BF16:
+ return Qwen2._to_bfloat16_bytes(arr, source_dtype)
+ raise ValueError(f"Unsupported model dtype: {self.meta.dtype}")
+
+
+def _tensor_parallel_worker_main(
+ rank: int,
+ world_size: int,
+ model_path: str,
+ device_ids: List[int],
+ master_addr: str,
+ master_port: int,
+ command_queue,
+ result_queue,
+) -> None:
+ os.environ.setdefault("MASTER_ADDR", master_addr)
+ os.environ.setdefault("MASTER_PORT", str(master_port))
+ os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
+
+ torch.cuda.set_device(device_ids[rank])
+ _tp_log(rank, f"worker start device={device_ids[rank]}")
+ dist.init_process_group(
+ backend="nccl",
+ init_method=f"tcp://{master_addr}:{master_port}",
+ world_size=world_size,
+ rank=rank,
+ device_id=torch.device(f"cuda:{device_ids[rank]}"),
+ )
+ worker = None
+ try:
+ worker = _ShardedQwen2Rank(model_path, rank, world_size, device_ids[rank])
+ _tp_log(rank, "model loaded")
+ dist.barrier()
+ if rank == 0:
+ result_queue.put({"status": "ready"})
+
+ while True:
+ message = command_queue.get()
+ command = str(message["command"])
+ seq = int(message["seq"])
+ _tp_log(rank, f"command {command} seq={seq}")
+
+ if command == "shutdown":
+ if rank == 0:
+ result_queue.put({"seq": seq, "status": "ok"})
+ break
+ if command == "reset":
+ worker.reset()
+ if rank == 0:
+ result_queue.put({"seq": seq, "status": "ok"})
+ continue
+ if command == "truncate":
+ worker.truncate(int(message["position"]))
+ if rank == 0:
+ result_queue.put({"seq": seq, "status": "ok"})
+ continue
+ if command == "generate_next":
+ if bool(message.get("reset_state", False)):
+ worker.reset()
+ token = worker.generate_next(
+ message["inputs"],
+ top_k=int(message["top_k"]),
+ top_p=float(message["top_p"]),
+ temperature=float(message["temperature"]),
+ )
+ if rank == 0:
+ result_queue.put({"seq": seq, "status": "ok", "token": int(token)})
+ continue
+ raise ValueError(f"Unknown tensor-parallel command: {command}")
+ except Exception as exc: # pragma: no cover - defensive worker path
+ _tp_log(rank, f"worker exception {repr(exc)}")
+ result_queue.put({"status": "error", "rank": rank, "error": repr(exc)})
+ raise
+ finally:
+ if dist.is_initialized():
+ dist.destroy_process_group()
+
+
+class TensorParallelQwen2:
+ SUPPORTED_MODEL_TYPES: ClassVar[set[str]] = {"qwen2"}
+
+ @classmethod
+ def supports_model_type(cls, model_type: str) -> bool:
+ return str(model_type).lower() in cls.SUPPORTED_MODEL_TYPES
+
+ def __init__(
+ self,
+ model_path: str,
+ device: DeviceType = DeviceType.NVIDIA,
+ device_id: int = 0,
+ *,
+ tp_size: int = 2,
+ tp_device_ids: Optional[Sequence[int]] = None,
+ ):
+ if device != DeviceType.NVIDIA:
+ raise ValueError("TensorParallelQwen2 currently supports only NVIDIA devices")
+ if int(tp_size) <= 1:
+ raise ValueError("tp_size must be greater than 1 for tensor parallel inference")
+
+ self.model_path = str(model_path)
+ self.device = device
+ self.device_id = int(device_id)
+ self.device_ids = _resolve_tp_device_ids(device_id, tp_size, tp_device_ids)
+ self.tp_size = len(self.device_ids)
+
+ with open(Path(model_path) / "config.json", "r") as handle:
+ config = json.load(handle)
+ self.config = config
+ self.model_type = str(config.get("model_type", "")).lower()
+ if not self.supports_model_type(self.model_type):
+ raise ValueError(f"Unsupported tensor-parallel model type: {self.model_type or ''}")
+
+ dtype = Qwen2._runtime_dtype(str(config.get("torch_dtype", "float32")), DeviceType.NVIDIA)
+ eos_token = config.get("eos_token_id", 151643)
+ if isinstance(eos_token, list):
+ eos_token = eos_token[0] if eos_token else 151643
+ self.meta = _TensorParallelMeta(
+ dtype=dtype,
+ nlayer=int(config.get("num_hidden_layers", 24)),
+ hs=int(config.get("hidden_size", 2048)),
+ nh=int(config.get("num_attention_heads", 16)),
+ nkvh=int(config.get("num_key_value_heads", int(config.get("num_attention_heads", 16)))),
+ dh=int(config.get("hidden_size", 2048)) // int(config.get("num_attention_heads", 16)),
+ di=int(config.get("intermediate_size", 11008)),
+ maxseq=int(config.get("max_position_embeddings", 8192)),
+ voc=int(config.get("vocab_size", 151936)),
+ epsilon=float(config.get("rms_norm_eps", 1e-6)),
+ theta=float(config.get("rope_theta", 1000000.0)),
+ end_token=int(eos_token),
+ )
+ _validate_tp_size(self.meta, self.tp_size)
+
+ ctx = mp.get_context("spawn")
+ self._command_queues = [ctx.Queue() for _ in range(self.tp_size)]
+ self._result_queue = ctx.Queue()
+ self._processes = []
+ self._closed = False
+ self._seq = 0
+ self._master_addr = "127.0.0.1"
+ self._master_port = _find_free_port()
+
+ for rank in range(self.tp_size):
+ process = ctx.Process(
+ target=_tensor_parallel_worker_main,
+ args=(
+ rank,
+ self.tp_size,
+ self.model_path,
+ self.device_ids,
+ self._master_addr,
+ self._master_port,
+ self._command_queues[rank],
+ self._result_queue,
+ ),
+ daemon=True,
+ )
+ process.start()
+ self._processes.append(process)
+
+ self._await_ready()
+
+ def __del__(self):
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def close(self) -> None:
+ if self._closed:
+ return
+ self._closed = True
+ try:
+ self._request("shutdown")
+ except Exception:
+ pass
+ for process in self._processes:
+ process.join(timeout=10)
+ if process.is_alive():
+ process.kill()
+ self._processes.clear()
+
+ def reset(self) -> None:
+ self._request("reset")
+
+ def truncate(self, position: int) -> None:
+ self._request("truncate", position=int(position))
+
+ def generate_next(
+ self,
+ inputs: Sequence[int],
+ top_k: int = 1,
+ top_p: float = 1.0,
+ temperature: float = 1.0,
+ *,
+ reset_state: bool = False,
+ ) -> int:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+ response = self._request(
+ "generate_next",
+ inputs=[int(token) for token in inputs],
+ top_k=int(top_k),
+ top_p=float(top_p),
+ temperature=float(temperature),
+ reset_state=bool(reset_state),
+ )
+ return int(response["token"])
+
+ def generate(
+ self,
+ inputs: Sequence[int],
+ max_new_tokens: int = 20,
+ top_k: int = 1,
+ top_p: float = 0.8,
+ temperature: float = 0.8,
+ *,
+ reset_state: bool = True,
+ ) -> List[int]:
+ if not inputs:
+ raise ValueError("inputs must not be empty")
+ if reset_state:
+ self.reset()
+
+ generated: List[int] = []
+ token_source = list(inputs)
+ for _ in range(int(max_new_tokens)):
+ next_token = self.generate_next(
+ token_source,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ reset_state=False,
+ )
+ generated.append(next_token)
+ token_source = [next_token]
+ if next_token == self.meta.end_token:
+ break
+ return list(inputs) + generated
+
+ def stream_generate(
+ self,
+ inputs: Sequence[int],
+ *,
+ tokenizer=None,
+ max_new_tokens: int = 20,
+ top_k: int = 1,
+ top_p: float = 0.8,
+ temperature: float = 0.8,
+ reset_state: bool = True,
+ ) -> Iterator[Tuple[int, str]]:
+ if reset_state:
+ self.reset()
+ generated: List[int] = []
+ previous_text = ""
+ token_source = list(inputs)
+
+ for _ in range(int(max_new_tokens)):
+ next_token = self.generate_next(
+ token_source,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ reset_state=False,
+ )
+ generated.append(next_token)
+ token_source = [next_token]
+
+ text_chunk = ""
+ if tokenizer is not None:
+ decoded = tokenizer.decode(
+ generated,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+ text_chunk = decoded[len(previous_text):] if decoded.startswith(previous_text) else decoded
+ previous_text = decoded
+
+ yield next_token, text_chunk
+ if next_token == self.meta.end_token:
+ break
+
+ def _request(self, command: str, **payload):
+ self._seq += 1
+ message = _normalize_command_payload(command, self._seq, **payload)
+ for queue_ in self._command_queues:
+ queue_.put(message)
+
+ deadline_s = 600.0
+ waited_s = 0.0
+ while True:
+ try:
+ response = self._result_queue.get(timeout=5)
+ break
+ except queue.Empty:
+ waited_s += 5.0
+ dead = {
+ process.pid: process.exitcode
+ for process in self._processes
+ if process.exitcode not in (None, 0)
+ }
+ if dead:
+ raise RuntimeError(
+ f"Tensor-parallel worker exited while handling `{command}`: {dead}"
+ )
+ if waited_s >= deadline_s:
+ raise TimeoutError(f"Timed out waiting for tensor-parallel command `{command}`")
+
+ if response.get("status") == "error":
+ raise RuntimeError(response.get("error", "Unknown tensor-parallel worker error"))
+ if response.get("seq") != self._seq:
+ raise RuntimeError(f"Unexpected tensor-parallel response ordering: {response}")
+ return response
+
+ def _await_ready(self) -> None:
+ while True:
+ try:
+ ready = self._result_queue.get(timeout=5)
+ except queue.Empty:
+ dead = [process.pid for process in self._processes if process.exitcode not in (None, 0)]
+ if dead:
+ raise RuntimeError(f"Tensor-parallel worker exited before ready: pids={dead}")
+ continue
+ if ready.get("status") == "ready":
+ return
+ if ready.get("status") == "error":
+ raise RuntimeError(
+ f"Tensor-parallel worker rank {ready.get('rank')} failed to start: {ready.get('error')}"
+ )
+ raise RuntimeError(f"Unexpected tensor-parallel startup response: {ready}")
diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py
index ed0180bc..0e3b3b81 100644
--- a/python/llaisys/ops.py
+++ b/python/llaisys/ops.py
@@ -21,7 +21,21 @@ def embedding(out: Tensor, index: Tensor, weight: Tensor):
@staticmethod
def linear(out: Tensor, inp: Tensor, weight: Tensor, bias: Tensor):
LIB_LLAISYS.llaisysLinear(
- out.lib_tensor(), inp.lib_tensor(), weight.lib_tensor(), bias.lib_tensor()
+ out.lib_tensor(),
+ inp.lib_tensor(),
+ weight.lib_tensor(),
+ bias.lib_tensor() if bias is not None else None,
+ )
+
+ @staticmethod
+ def sample(logits: Tensor, top_k: int = 1, top_p: float = 1.0, temperature: float = 1.0) -> int:
+ return int(
+ LIB_LLAISYS.llaisysSample(
+ logits.lib_tensor(),
+ c_int(top_k),
+ c_float(top_p),
+ c_float(temperature),
+ )
)
@staticmethod
diff --git a/python/setup.cfg b/python/setup.cfg
index b35fc65f..3aeef424 100644
--- a/python/setup.cfg
+++ b/python/setup.cfg
@@ -14,6 +14,17 @@ install_requires =
transformers
accelerate
+[options.extras_require]
+server =
+ fastapi
+ uvicorn
+ httpx
+
+[options.entry_points]
+console_scripts =
+ llaisys-chat-server = llaisys.chat_server:main
+ llaisys-chat-cli = llaisys.chat_cli:main
+
[options.package_data]
llaisys =
libllaisys/*.so
diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp
index 44894b9e..63756fab 100644
--- a/src/core/context/context.cpp
+++ b/src/core/context/context.cpp
@@ -52,7 +52,7 @@ Context::~Context() {
void Context::setDevice(llaisysDeviceType_t device_type, int device_id) {
// If doest not match the current runtime.
if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) {
- auto runtimes = _runtime_map[device_type];
+ auto &runtimes = _runtime_map[device_type];
CHECK_ARGUMENT((size_t)device_id < runtimes.size() && device_id >= 0, "invalid device id");
if (_current_runtime != nullptr) {
_current_runtime->_deactivate();
diff --git a/src/core/runtime/runtime.cpp b/src/core/runtime/runtime.cpp
index 7f03a862..7d10fae7 100644
--- a/src/core/runtime/runtime.cpp
+++ b/src/core/runtime/runtime.cpp
@@ -7,6 +7,7 @@ namespace llaisys::core {
Runtime::Runtime(llaisysDeviceType_t device_type, int device_id)
: _device_type(device_type), _device_id(device_id), _is_active(false) {
_api = llaisys::device::getRuntimeAPI(_device_type);
+ _api->set_device(_device_id);
_stream = _api->create_stream();
_allocator = new allocators::NaiveAllocator(_api);
}
@@ -17,6 +18,7 @@ Runtime::~Runtime() {
}
delete _allocator;
_allocator = nullptr;
+ _api->set_device(_device_id);
_api->destroy_stream(_stream);
_api = nullptr;
}
diff --git a/src/device/nvidia/cuda_utils.cuh b/src/device/nvidia/cuda_utils.cuh
new file mode 100644
index 00000000..5096013c
--- /dev/null
+++ b/src/device/nvidia/cuda_utils.cuh
@@ -0,0 +1,160 @@
+#pragma once
+
+#include "../../utils.hpp"
+
+#include <cublas_v2.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+
+namespace llaisys::device::nvidia::cuda_utils {
+
+inline void checkCuda(cudaError_t status, const char *expr, const char *file, int line) {
+ if (status == cudaSuccess) {
+ return;
+ }
+ std::cerr << "[ERROR] CUDA call failed: " << expr << " -> " << cudaGetErrorString(status)
+ << " at " << file << ":" << line << std::endl;
+ throw std::runtime_error("CUDA error");
+}
+
+inline const char *cublasStatusName(cublasStatus_t status) {
+ switch (status) {
+ case CUBLAS_STATUS_SUCCESS:
+ return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE:
+ return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH:
+ return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR:
+ return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED:
+ return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR:
+ return "CUBLAS_STATUS_INTERNAL_ERROR";
+ default:
+ return "CUBLAS_STATUS_UNKNOWN";
+ }
+}
+
+inline void checkCublas(cublasStatus_t status, const char *expr, const char *file, int line) {
+ if (status == CUBLAS_STATUS_SUCCESS) {
+ return;
+ }
+ std::cerr << "[ERROR] cuBLAS call failed: " << expr << " -> " << cublasStatusName(status)
+ << " at " << file << ":" << line << std::endl;
+ throw std::runtime_error("cuBLAS error");
+}
+
+#define LLAISYS_CUDA_CHECK(EXPR__) ::llaisys::device::nvidia::cuda_utils::checkCuda((EXPR__), #EXPR__, __FILE__, __LINE__)
+#define LLAISYS_CUBLAS_CHECK(EXPR__) ::llaisys::device::nvidia::cuda_utils::checkCublas((EXPR__), #EXPR__, __FILE__, __LINE__)
+
+inline cudaMemcpyKind memcpyKind(llaisysMemcpyKind_t kind) {
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ return cudaMemcpyHostToHost;
+ case LLAISYS_MEMCPY_H2D:
+ return cudaMemcpyHostToDevice;
+ case LLAISYS_MEMCPY_D2H:
+ return cudaMemcpyDeviceToHost;
+ case LLAISYS_MEMCPY_D2D:
+ return cudaMemcpyDeviceToDevice;
+ default:
+ throw std::invalid_argument("Unsupported memcpy kind");
+ }
+}
+
+inline cudaDataType_t cublasDataType(llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ return CUDA_R_32F;
+ case LLAISYS_DTYPE_F16:
+ return CUDA_R_16F;
+ case LLAISYS_DTYPE_BF16:
+ return CUDA_R_16BF;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+
+inline cublasComputeType_t cublasComputeType(llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ case LLAISYS_DTYPE_F16:
+ case LLAISYS_DTYPE_BF16:
+ return CUBLAS_COMPUTE_32F;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+
+template <typename T>
+__device__ __forceinline__ float toFloat(T value) {
+ return static_cast<float>(value);
+}
+
+template <>
+__device__ __forceinline__ float toFloat(float value) {
+ return value;
+}
+
+template <>
+__device__ __forceinline__ float toFloat(llaisys::fp16_t value) {
+ union {
+ __half half_value;
+ __half_raw raw;
+ } bits{};
+ bits.raw.x = value._v;
+ return __half2float(bits.half_value);
+}
+
+template <>
+__device__ __forceinline__ float toFloat(llaisys::bf16_t value) {
+ union {
+ __nv_bfloat16 bf16_value;
+ __nv_bfloat16_raw raw;
+ } bits{};
+ bits.raw.x = value._v;
+ return __bfloat162float(bits.bf16_value);
+}
+
+template <typename T>
+__device__ __forceinline__ T fromFloat(float value) {
+ return static_cast<T>(value);
+}
+
+template <>
+__device__ __forceinline__ float fromFloat<float>(float value) {
+ return value;
+}
+
+template <>
+__device__ __forceinline__ llaisys::fp16_t fromFloat<llaisys::fp16_t>(float value) {
+ union {
+ __half half_value;
+ __half_raw raw;
+ } bits{};
+ bits.half_value = __float2half_rn(value);
+ return llaisys::fp16_t{bits.raw.x};
+}
+
+template <>
+__device__ __forceinline__ llaisys::bf16_t fromFloat<llaisys::bf16_t>(float value) {
+ union {
+ __nv_bfloat16 bf16_value;
+ __nv_bfloat16_raw raw;
+ } bits{};
+ bits.bf16_value = __float2bfloat16(value);
+ return llaisys::bf16_t{bits.raw.x};
+}
+
+} // namespace llaisys::device::nvidia::cuda_utils
diff --git a/src/device/nvidia/nvidia_resource.cu b/src/device/nvidia/nvidia_resource.cu
index 2e63647e..84c59c6f 100644
--- a/src/device/nvidia/nvidia_resource.cu
+++ b/src/device/nvidia/nvidia_resource.cu
@@ -3,5 +3,6 @@
namespace llaisys::device::nvidia {
Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) {}
+Resource::~Resource() = default;
} // namespace llaisys::device::nvidia
diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu
index cab92826..ba95472a 100644
--- a/src/device/nvidia/nvidia_runtime_api.cu
+++ b/src/device/nvidia/nvidia_runtime_api.cu
@@ -1,56 +1,80 @@
#include "../runtime_api.hpp"
-
-#include
-#include
+#include "cuda_utils.cuh"
namespace llaisys::device::nvidia {
namespace runtime_api {
int getDeviceCount() {
- TO_BE_IMPLEMENTED();
+ int count = 0;
+ LLAISYS_CUDA_CHECK(cudaGetDeviceCount(&count));
+ return count;
}
-void setDevice(int) {
- TO_BE_IMPLEMENTED();
+void setDevice(int device_id) {
+ LLAISYS_CUDA_CHECK(cudaSetDevice(device_id));
}
void deviceSynchronize() {
- TO_BE_IMPLEMENTED();
+ LLAISYS_CUDA_CHECK(cudaDeviceSynchronize());
}
llaisysStream_t createStream() {
- TO_BE_IMPLEMENTED();
+ cudaStream_t stream = nullptr;
+ LLAISYS_CUDA_CHECK(cudaStreamCreate(&stream));
+ return reinterpret_cast<llaisysStream_t>(stream);
}
void destroyStream(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ if (stream == nullptr) {
+ return;
+ }
+ LLAISYS_CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(stream)));
}
void streamSynchronize(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ if (stream == nullptr) {
+ return;
+ }
+ LLAISYS_CUDA_CHECK(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
}
void *mallocDevice(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ LLAISYS_CUDA_CHECK(cudaMalloc(&ptr, size));
+ return ptr;
}
void freeDevice(void *ptr) {
- TO_BE_IMPLEMENTED();
+ if (ptr == nullptr) {
+ return;
+ }
+ LLAISYS_CUDA_CHECK(cudaFree(ptr));
}
void *mallocHost(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ LLAISYS_CUDA_CHECK(cudaMallocHost(&ptr, size));
+ return ptr;
}
void freeHost(void *ptr) {
- TO_BE_IMPLEMENTED();
+ if (ptr == nullptr) {
+ return;
+ }
+ LLAISYS_CUDA_CHECK(cudaFreeHost(ptr));
}
void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+ LLAISYS_CUDA_CHECK(cudaDeviceSynchronize());
+ LLAISYS_CUDA_CHECK(cudaMemcpy(dst, src, size, cuda_utils::memcpyKind(kind)));
}
-void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) {
+ LLAISYS_CUDA_CHECK(cudaMemcpyAsync(
+ dst,
+ src,
+ size,
+ cuda_utils::memcpyKind(kind),
+ reinterpret_cast<cudaStream_t>(stream)));
}
static const LlaisysRuntimeAPI RUNTIME_API = {
diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc
new file mode 100644
index 00000000..209ea8b1
--- /dev/null
+++ b/src/llaisys/models/qwen2.cc
@@ -0,0 +1,31 @@
+#include "llaisys/models/qwen2.h"
+#include "../../models/qwen2/model.hpp"
+
+using namespace llaisys::models::qwen2;
+
+extern "C" {
+
+struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) {
+ int dev_id = (ndevice > 0 && device_ids != nullptr) ? device_ids[0] : 0;
+ Qwen2Model* model = new Qwen2Model(*meta, device, dev_id);
+ return reinterpret_cast<LlaisysQwen2Model *>(model);
+}
+
+void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) {
+ if (model) {
+ delete reinterpret_cast<Qwen2Model *>(model);
+ }
+}
+
+struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) {
+ if (!model) return nullptr;
+ return reinterpret_cast<Qwen2Model *>(model)->getWeightsStruct();
+}
+
+int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) {
+ if (!model) return -1;
+ std::vector<int64_t> tokens(token_ids, token_ids + ntoken);
+ return reinterpret_cast<Qwen2Model *>(model)->infer(tokens);
+}
+
+}
diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc
index c99fbc32..924978ba 100644
--- a/src/llaisys/ops.cc
+++ b/src/llaisys/ops.cc
@@ -6,6 +6,7 @@
#include "../ops/argmax/op.hpp"
#include "../ops/embedding/op.hpp"
#include "../ops/linear/op.hpp"
+#include "../ops/sample/op.hpp"
#include "../ops/rearrange/op.hpp"
#include "../ops/rms_norm/op.hpp"
#include "../ops/rope/op.hpp"
@@ -23,7 +24,10 @@ __C {
llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor);
}
void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) {
- llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor);
+ llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias ? bias->tensor : nullptr);
+ }
+ int64_t llaisysSample(llaisysTensor_t logits, int top_k, float top_p, float temperature) {
+ return llaisys::ops::sample(logits->tensor, top_k, top_p, temperature);
}
void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) {
llaisys::ops::rearrange(out->tensor, in->tensor);
diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc
new file mode 100644
index 00000000..24936310
--- /dev/null
+++ b/src/llaisys/qwen2.cc
@@ -0,0 +1,54 @@
+#include "llaisys/models/qwen2.h"
+#include "../models/qwen2/qwen2.hpp"
+
+extern "C" {
+
+struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) {
+ if (!meta || ndevice < 1) return nullptr;
+ // For now support single device
+ int device_id = device_ids ? device_ids[0] : 0;
+
+ // Copy meta
+ LlaisysQwen2Meta cpp_meta = *meta;
+
+ auto* model = new llaisys::Qwen2Model(cpp_meta, device, device_id);
+ return reinterpret_cast<LlaisysQwen2Model *>(model);
+}
+
+void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) {
+ if (model) {
+ delete reinterpret_cast<llaisys::Qwen2Model *>(model);
+ }
+}
+
+struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) {
+ if (!model) return nullptr;
+ auto* cpp_model = reinterpret_cast<llaisys::Qwen2Model *>(model);
+ return cpp_model->getWeights();
+}
+
+void llaisysQwen2ModelReset(struct LlaisysQwen2Model * model) {
+ if (!model) return;
+ auto* cpp_model = reinterpret_cast<llaisys::Qwen2Model *>(model);
+ cpp_model->reset();
+}
+
+void llaisysQwen2ModelTruncate(struct LlaisysQwen2Model * model, size_t position) {
+ if (!model) return;
+ auto* cpp_model = reinterpret_cast<llaisys::Qwen2Model *>(model);
+ cpp_model->truncate(position);
+}
+
+int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) {
+ if (!model) return -1;
+ auto* cpp_model = reinterpret_cast<llaisys::Qwen2Model *>(model);
+ return cpp_model->infer(token_ids, ntoken);
+}
+
+int64_t llaisysQwen2ModelGenerateNext(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken, int top_k, float top_p, float temperature) {
+ if (!model) return -1;
+ auto* cpp_model = reinterpret_cast<llaisys::Qwen2Model *>(model);
+ return cpp_model->generateNext(token_ids, ntoken, top_k, top_p, temperature);
+}
+
+}
diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp
new file mode 100644
index 00000000..60340e7b
--- /dev/null
+++ b/src/models/qwen2/qwen2.cpp
@@ -0,0 +1,328 @@
+#include "qwen2.hpp"
+#include "llaisys/ops.h"
+#include "../../ops/add/op.hpp"
+#include "../../ops/argmax/op.hpp"
+#include "../../ops/embedding/op.hpp"
+#include "../../ops/linear/op.hpp"
+#include "../../ops/sample/op.hpp"
+#include "../../ops/rearrange/op.hpp"
+#include "../../ops/rms_norm/op.hpp"
+#include "../../ops/rope/op.hpp"
+#include "../../ops/self_attention/op.hpp"
+#include "../../ops/swiglu/op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils/check.hpp"
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <vector>
+#include "../../llaisys/llaisys_tensor.hpp"
+
+namespace llaisys {
+
+inline tensor_t to_cpp(llaisysTensor_t t) {
+ if (!t) return nullptr;
+ return reinterpret_cast<LlaisysTensor *>(t)->tensor;
+}
+
+inline tensor_t to_cpp(llaisysTensor_t* t_array, size_t idx) {
+ if (!t_array || !t_array[idx]) return nullptr;
+ return reinterpret_cast<LlaisysTensor *>(t_array[idx])->tensor;
+}
+
+Qwen2Model::Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id)
+ : _meta(meta), _device_type(device), _device_id(device_id) {
+
+ _weights.attn_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_q_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_q_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_k_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_k_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_v_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_v_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.attn_o_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.mlp_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.mlp_gate_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.mlp_up_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+ _weights.mlp_down_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t));
+
+ init_buffers();
+}
+
+Qwen2Model::~Qwen2Model() {
+ free(_weights.attn_norm_w);
+ free(_weights.attn_q_w);
+ free(_weights.attn_q_b);
+ free(_weights.attn_k_w);
+ free(_weights.attn_k_b);
+ free(_weights.attn_v_w);
+ free(_weights.attn_v_b);
+ free(_weights.attn_o_w);
+ free(_weights.mlp_norm_w);
+ free(_weights.mlp_gate_w);
+ free(_weights.mlp_up_w);
+ free(_weights.mlp_down_w);
+}
+
+void Qwen2Model::init_buffers() {
+ core::context().setDevice(_device_type, _device_id);
+
+ const size_t q_dim = _meta.nh * _meta.dh;
+ const size_t k_dim = _meta.nkvh * _meta.dh;
+ const float scale = 1.0f / std::sqrt(static_cast<float>(_meta.dh));
+ (void)scale;
+
+ for(size_t i=0; i<_meta.nlayer; ++i) {
+ _kv_caches.push_back({
+ Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id),
+ Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id)
+ });
+
+ auto q_flat = Tensor::create({1, q_dim}, _meta.dtype, _device_type, _device_id);
+ auto k_flat = Tensor::create({1, k_dim}, _meta.dtype, _device_type, _device_id);
+ auto v_flat = Tensor::create({1, k_dim}, _meta.dtype, _device_type, _device_id);
+ auto attn_val_3d = Tensor::create({1, _meta.nh, _meta.dh}, _meta.dtype, _device_type, _device_id);
+
+ _decode_layers.push_back({
+ q_flat,
+ k_flat,
+ v_flat,
+ q_flat->view({1, _meta.nh, _meta.dh}),
+ k_flat->view({1, _meta.nkvh, _meta.dh}),
+ v_flat->view({1, _meta.nkvh, _meta.dh}),
+ attn_val_3d,
+ attn_val_3d->view({1, _meta.hs}),
+ Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id),
+ Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id),
+ Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id)
+ });
+ }
+
+ _hidden_states = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _ln_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _attn_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _mlp_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _logits = Tensor::create({1, _meta.voc}, _meta.dtype, _device_type, _device_id);
+ _tokens_tensor = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id);
+ _single_pos_id = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id);
+ _pos_ids = Tensor::create({_meta.maxseq}, LLAISYS_DTYPE_I64, _device_type, _device_id);
+ _argmax_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id);
+ _argmax_val = Tensor::create({1}, _meta.dtype, _device_type, _device_id);
+}
+
+void Qwen2Model::ensure_prefill_buffers(size_t ntoken) {
+ if (_prefill.capacity >= ntoken) {
+ return;
+ }
+
+ const size_t next_capacity = std::max(ntoken, _prefill.capacity == 0 ? ntoken : _prefill.capacity * 2);
+ const size_t q_dim = _meta.nh * _meta.dh;
+ const size_t k_dim = _meta.nkvh * _meta.dh;
+
+ _prefill.capacity = next_capacity;
+ _prefill.tokens = Tensor::create({next_capacity}, LLAISYS_DTYPE_I64, _device_type, _device_id);
+ _prefill.hidden_states = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _prefill.normed = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _prefill.q_flat = Tensor::create({next_capacity, q_dim}, _meta.dtype, _device_type, _device_id);
+ _prefill.k_flat = Tensor::create({next_capacity, k_dim}, _meta.dtype, _device_type, _device_id);
+ _prefill.v_flat = Tensor::create({next_capacity, k_dim}, _meta.dtype, _device_type, _device_id);
+ _prefill.attn_val_3d = Tensor::create({next_capacity, _meta.nh, _meta.dh}, _meta.dtype, _device_type, _device_id);
+ _prefill.attn_out = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id);
+ _prefill.gate = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id);
+ _prefill.up = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id);
+ _prefill.swiglu = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id);
+ _prefill.mlp_out = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id);
+}
+
+tensor_t Qwen2Model::forward(const int64_t* token_ids, size_t ntoken) {
+ CHECK_ARGUMENT(token_ids != nullptr, "token_ids must not be null");
+ CHECK_ARGUMENT(ntoken > 0, "ntoken must be greater than zero");
+ CHECK_ARGUMENT(_cur_pos + ntoken <= _meta.maxseq, "sequence exceeds KV-cache capacity");
+
+ core::context().setDevice(_device_type, _device_id);
+
+ if (ntoken == 1) {
+ return forward_single_token(token_ids);
+ }
+
+ ensure_prefill_buffers(ntoken);
+
+ tensor_t input_tokens = _prefill.tokens->slice(0, 0, ntoken);
+ input_tokens->load(token_ids);
+
+ tensor_t current_pos_ids = _pos_ids->slice(0, 0, ntoken);
+ std::vector<int64_t> pos_data(ntoken);
+ for (size_t i = 0; i < ntoken; ++i) {
+ pos_data[i] = static_cast<int64_t>(_cur_pos + i);
+ }
+ current_pos_ids->load(pos_data.data());
+
+ std::vector<size_t> seq_shape = {ntoken, _meta.hs};
+
+ tensor_t hidden_states = _prefill.hidden_states->slice(0, 0, ntoken);
+ hidden_states = hidden_states->view(seq_shape);
+ ops::embedding(hidden_states, input_tokens, to_cpp(_weights.in_embed));
+
+ const size_t q_dim = _meta.nh * _meta.dh;
+ const size_t k_dim = _meta.nkvh * _meta.dh;
+ const float scale = 1.0f / std::sqrt(static_cast<float>(_meta.dh));
+
+ for(size_t i=0; i<_meta.nlayer; ++i) {
+ tensor_t normed = _prefill.normed->slice(0, 0, ntoken);
+ normed = normed->view(seq_shape);
+ ops::rms_norm(normed, hidden_states, to_cpp(_weights.attn_norm_w, i), _meta.epsilon);
+
+ tensor_t q = _prefill.q_flat->slice(0, 0, ntoken);
+ q = q->view({ntoken, q_dim});
+ tensor_t k = _prefill.k_flat->slice(0, 0, ntoken);
+ k = k->view({ntoken, k_dim});
+ tensor_t v = _prefill.v_flat->slice(0, 0, ntoken);
+ v = v->view({ntoken, k_dim});
+
+ ops::linear(q, normed, to_cpp(_weights.attn_q_w, i), to_cpp(_weights.attn_q_b, i));
+ ops::linear(k, normed, to_cpp(_weights.attn_k_w, i), to_cpp(_weights.attn_k_b, i));
+ ops::linear(v, normed, to_cpp(_weights.attn_v_w, i), to_cpp(_weights.attn_v_b, i));
+
+ q = q->view({ntoken, _meta.nh, _meta.dh});
+ k = k->view({ntoken, _meta.nkvh, _meta.dh});
+ v = v->view({ntoken, _meta.nkvh, _meta.dh});
+
+ ops::rope(q, q, current_pos_ids, _meta.theta);
+ ops::rope(k, k, current_pos_ids, _meta.theta);
+
+ tensor_t k_cache_slot = _kv_caches[i].k->slice(0, _cur_pos, _cur_pos + ntoken);
+ tensor_t v_cache_slot = _kv_caches[i].v->slice(0, _cur_pos, _cur_pos + ntoken);
+
+ ops::rearrange(k_cache_slot, k);
+ ops::rearrange(v_cache_slot, v);
+
+ tensor_t k_full = _kv_caches[i].k->slice(0, 0, _cur_pos + ntoken);
+ tensor_t v_full = _kv_caches[i].v->slice(0, 0, _cur_pos + ntoken);
+
+ tensor_t attn_val = _prefill.attn_val_3d->slice(0, 0, ntoken);
+ attn_val = attn_val->view({ntoken, _meta.nh, _meta.dh});
+
+ ops::self_attention(attn_val, q, k_full, v_full, scale);
+
+ attn_val = attn_val->view({ntoken, _meta.hs});
+
+ tensor_t attn_output = _prefill.attn_out->slice(0, 0, ntoken);
+ attn_output = attn_output->view(seq_shape);
+ ops::linear(attn_output, attn_val, to_cpp(_weights.attn_o_w, i), nullptr);
+
+ ops::add(hidden_states, hidden_states, attn_output);
+
+ normed = _prefill.normed->slice(0, 0, ntoken);
+ normed = normed->view(seq_shape);
+ ops::rms_norm(normed, hidden_states, to_cpp(_weights.mlp_norm_w, i), _meta.epsilon);
+
+ tensor_t gate = _prefill.gate->slice(0, 0, ntoken);
+ gate = gate->view({ntoken, _meta.di});
+ tensor_t up = _prefill.up->slice(0, 0, ntoken);
+ up = up->view({ntoken, _meta.di});
+
+ ops::linear(gate, normed, to_cpp(_weights.mlp_gate_w, i), nullptr);
+ ops::linear(up, normed, to_cpp(_weights.mlp_up_w, i), nullptr);
+
+ tensor_t swiglu_out = _prefill.swiglu->slice(0, 0, ntoken);
+ swiglu_out = swiglu_out->view({ntoken, _meta.di});
+ ops::swiglu(swiglu_out, gate, up);
+
+ tensor_t mlp_output = _prefill.mlp_out->slice(0, 0, ntoken);
+ mlp_output = mlp_output->view(seq_shape);
+ ops::linear(mlp_output, swiglu_out, to_cpp(_weights.mlp_down_w, i), nullptr);
+
+ ops::add(hidden_states, hidden_states, mlp_output);
+ }
+
+ tensor_t final_normed = _prefill.normed->slice(0, 0, ntoken);
+ final_normed = final_normed->view(seq_shape);
+ ops::rms_norm(final_normed, hidden_states, to_cpp(_weights.out_norm_w), _meta.epsilon);
+
+ tensor_t last_hidden = final_normed->slice(0, ntoken-1, ntoken);
+ last_hidden = last_hidden->view({1, _meta.hs});
+
+ ops::linear(_logits, last_hidden, to_cpp(_weights.out_embed), nullptr);
+
+ _cur_pos += ntoken;
+
+ return _logits;
+}
+
+tensor_t Qwen2Model::forward_single_token(const int64_t* token_id) {
+ CHECK_ARGUMENT(token_id != nullptr, "token_id must not be null");
+
+ _tokens_tensor->load(token_id);
+ const int64_t current_pos = static_cast(_cur_pos);
+ _single_pos_id->load(&current_pos);
+
+ ops::embedding(_hidden_states, _tokens_tensor, to_cpp(_weights.in_embed));
+
+ const float scale = 1.0f / std::sqrt(static_cast<float>(_meta.dh));
+ for (size_t i = 0; i < _meta.nlayer; ++i) {
+ auto &layer = _decode_layers[i];
+
+ ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.attn_norm_w, i), _meta.epsilon);
+
+ ops::linear(layer.q_flat, _ln_out, to_cpp(_weights.attn_q_w, i), to_cpp(_weights.attn_q_b, i));
+ ops::linear(layer.k_flat, _ln_out, to_cpp(_weights.attn_k_w, i), to_cpp(_weights.attn_k_b, i));
+ ops::linear(layer.v_flat, _ln_out, to_cpp(_weights.attn_v_w, i), to_cpp(_weights.attn_v_b, i));
+
+ ops::rope(layer.q_view, layer.q_view, _single_pos_id, _meta.theta);
+ ops::rope(layer.k_view, layer.k_view, _single_pos_id, _meta.theta);
+
+ tensor_t k_cache_slot = _kv_caches[i].k->slice(0, _cur_pos, _cur_pos + 1);
+ tensor_t v_cache_slot = _kv_caches[i].v->slice(0, _cur_pos, _cur_pos + 1);
+ ops::rearrange(k_cache_slot, layer.k_view);
+ ops::rearrange(v_cache_slot, layer.v_view);
+
+ tensor_t k_full = _kv_caches[i].k->slice(0, 0, _cur_pos + 1);
+ tensor_t v_full = _kv_caches[i].v->slice(0, 0, _cur_pos + 1);
+ ops::self_attention(layer.attn_val_3d, layer.q_view, k_full, v_full, scale);
+
+ ops::linear(_attn_out, layer.attn_val_2d, to_cpp(_weights.attn_o_w, i), nullptr);
+ ops::add(_hidden_states, _hidden_states, _attn_out);
+
+ ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.mlp_norm_w, i), _meta.epsilon);
+ ops::linear(layer.gate, _ln_out, to_cpp(_weights.mlp_gate_w, i), nullptr);
+ ops::linear(layer.up, _ln_out, to_cpp(_weights.mlp_up_w, i), nullptr);
+ ops::swiglu(layer.swiglu, layer.gate, layer.up);
+ ops::linear(_mlp_out, layer.swiglu, to_cpp(_weights.mlp_down_w, i), nullptr);
+ ops::add(_hidden_states, _hidden_states, _mlp_out);
+ }
+
+ ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.out_norm_w), _meta.epsilon);
+ ops::linear(_logits, _ln_out, to_cpp(_weights.out_embed), nullptr);
+
+ _cur_pos += 1;
+ return _logits;
+}
+
+int64_t Qwen2Model::infer(const int64_t* token_ids, size_t ntoken) {
+ tensor_t logits = forward(token_ids, ntoken);
+ ops::argmax(_argmax_idx, _argmax_val, logits);
+
+ if (_device_type == LLAISYS_DEVICE_CPU) {
+ return *reinterpret_cast<const int64_t *>(_argmax_idx->data());
+ }
+ int64_t result = 0;
+ core::context().runtime().api()->memcpy_sync(
+ &result,
+ _argmax_idx->data(),
+ sizeof(result),
+ LLAISYS_MEMCPY_D2H);
+ return result;
+}
+
+int64_t Qwen2Model::generateNext(const int64_t* token_ids, size_t ntoken, int top_k, float top_p, float temperature) {
+ tensor_t logits = forward(token_ids, ntoken);
+ return ops::sample(logits, top_k, top_p, temperature);
+}
+
+void Qwen2Model::reset() {
+ _cur_pos = 0;
+}
+
+void Qwen2Model::truncate(size_t position) {
+ CHECK_ARGUMENT(position <= _cur_pos, "truncate position exceeds current cache length");
+ _cur_pos = position;
+}
+
+} // namespace llaisys
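The driver loop around `infer` is not part of this patch. As a self-contained sketch of the contract it implements (all names below are hypothetical; only the greedy-argmax behavior comes from the model code above):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for one forward step: pick the argmax token from a logits row,
// mirroring Qwen2Model::infer's greedy path (strict '>' keeps the first
// occurrence on ties).
int64_t greedy_next(const std::vector<float> &logits) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) best = i;
    }
    return static_cast<int64_t>(best);
}

// Hypothetical decode loop: prefill once on the whole prompt, then append
// one greedily chosen token per step until EOS or the budget runs out.
std::vector<int64_t> greedy_decode(const std::vector<int64_t> &prompt,
                                   const std::vector<std::vector<float>> &step_logits,
                                   int64_t eos_id, std::size_t max_new_tokens) {
    std::vector<int64_t> out = prompt;
    for (std::size_t s = 0; s < max_new_tokens && s < step_logits.size(); ++s) {
        const int64_t next = greedy_next(step_logits[s]);
        out.push_back(next);
        if (next == eos_id) break;  // model emitted end-of-sequence
    }
    return out;
}
```

With `generateNext`, the same loop would simply swap `greedy_next` for a top-k/top-p sample over the logits.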
diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp
new file mode 100644
index 00000000..894d1b0d
--- /dev/null
+++ b/src/models/qwen2/qwen2.hpp
@@ -0,0 +1,83 @@
+#pragma once
+#include "llaisys/models/qwen2.h"
+#include "../../tensor/tensor.hpp"
+#include <cstdint>
+#include <vector>
+
+namespace llaisys {
+
+class Qwen2Model {
+public:
+ Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id);
+ ~Qwen2Model();
+
+ LlaisysQwen2Weights* getWeights() { return &_weights; }
+
+ int64_t infer(const int64_t* token_ids, size_t ntoken);
+ int64_t generateNext(const int64_t* token_ids, size_t ntoken, int top_k, float top_p, float temperature);
+ void reset();
+ void truncate(size_t position);
+
+private:
+ LlaisysQwen2Meta _meta;
+ LlaisysQwen2Weights _weights;
+
+ llaisysDeviceType_t _device_type;
+ int _device_id;
+
+ struct KVCache {
+ tensor_t k;
+ tensor_t v;
+ };
+ struct DecodeLayerBuffers {
+ tensor_t q_flat;
+ tensor_t k_flat;
+ tensor_t v_flat;
+ tensor_t q_view;
+ tensor_t k_view;
+ tensor_t v_view;
+ tensor_t attn_val_3d;
+ tensor_t attn_val_2d;
+ tensor_t gate;
+ tensor_t up;
+ tensor_t swiglu;
+ };
+ struct PrefillBuffers {
+ size_t capacity = 0;
+ tensor_t tokens;
+ tensor_t hidden_states;
+ tensor_t normed;
+ tensor_t q_flat;
+ tensor_t k_flat;
+ tensor_t v_flat;
+ tensor_t attn_val_3d;
+ tensor_t attn_out;
+ tensor_t gate;
+ tensor_t up;
+ tensor_t swiglu;
+ tensor_t mlp_out;
+ };
+ std::vector<KVCache> _kv_caches;
+ std::vector<DecodeLayerBuffers> _decode_layers;
+ PrefillBuffers _prefill;
+
+ size_t _cur_pos = 0;
+
+ tensor_t _hidden_states;
+ tensor_t _ln_out;
+ tensor_t _attn_out;
+ tensor_t _mlp_out;
+ tensor_t _logits;
+ tensor_t _tokens_tensor;
+ tensor_t _single_pos_id;
+ tensor_t _pos_ids;
+ tensor_t _argmax_idx;
+ tensor_t _argmax_val;
+
+ void init_buffers();
+ void ensure_prefill_buffers(size_t ntoken);
+ tensor_t forward(const int64_t* token_ids, size_t ntoken);
+ tensor_t forward_single_token(const int64_t* token_id);
+};
+
+} // namespace llaisys
diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp
index 47f6a3d4..90c332a8 100644
--- a/src/ops/add/cpu/add_cpu.cpp
+++ b/src/ops/add/cpu/add_cpu.cpp
@@ -1,12 +1,58 @@
#include "add_cpu.hpp"
+#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#ifdef __C
+#pragma push_macro("__C")
+#undef __C
+#define LLAISYS_RESTORE_C_MACRO
+#endif
+#include <immintrin.h>
+#ifdef LLAISYS_RESTORE_C_MACRO
+#pragma pop_macro("__C")
+#undef LLAISYS_RESTORE_C_MACRO
+#endif
+#endif
+
#include "../../../utils.hpp"
-#include
+#include <cstddef>
+#include <type_traits>
+
+namespace {
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+inline bool has_avx2() {
+ return __builtin_cpu_supports("avx2");
+}
+
+__attribute__((target("avx2")))
+void add_f32_avx2(float *c, const float *a, const float *b, std::ptrdiff_t numel) {
+ const std::ptrdiff_t simd_numel = numel - (numel % 8);
+
+#pragma omp parallel for schedule(static) if (numel >= 4096)
+ for (std::ptrdiff_t i = 0; i < simd_numel; i += 8) {
+ const __m256 va = _mm256_loadu_ps(a + i);
+ const __m256 vb = _mm256_loadu_ps(b + i);
+ _mm256_storeu_ps(c + i, _mm256_add_ps(va, vb));
+ }
+
+ for (std::ptrdiff_t i = simd_numel; i < numel; ++i) {
+ c[i] = a[i] + b[i];
+ }
+}
+#endif
+
+void add_f32(float *c, const float *a, const float *b, std::ptrdiff_t numel) {
+#pragma omp parallel for schedule(static) if (numel >= 4096)
+ for (std::ptrdiff_t i = 0; i < numel; ++i) {
+ c[i] = a[i] + b[i];
+ }
+}
template <typename T>
-void add_(T *c, const T *a, const T *b, size_t numel) {
-    for (size_t i = 0; i < numel; i++) {
+void add_generic(T *c, const T *a, const T *b, std::ptrdiff_t numel) {
+#pragma omp parallel for schedule(static) if (numel >= 4096)
+    for (std::ptrdiff_t i = 0; i < numel; ++i) {
        if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
            c[i] = llaisys::utils::cast<T>(llaisys::utils::cast<float>(a[i]) + llaisys::utils::cast<float>(b[i]));
} else {
@@ -15,17 +61,25 @@ void add_(T *c, const T *a, const T *b, size_t numel) {
}
}
+} // namespace
+
namespace llaisys::ops::cpu {
void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) {
+ const auto elem_count = static_cast<std::ptrdiff_t>(numel);
switch (type) {
case LLAISYS_DTYPE_F32:
- return add_(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b), numel);
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+ if (has_avx2()) {
+ return add_f32_avx2(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b), elem_count);
+ }
+#endif
+ return add_f32(reinterpret_cast<float *>(c), reinterpret_cast<const float *>(a), reinterpret_cast<const float *>(b), elem_count);
case LLAISYS_DTYPE_BF16:
- return add_(reinterpret_cast<llaisys::bf16_t *>(c), reinterpret_cast<const llaisys::bf16_t *>(a),
- reinterpret_cast<const llaisys::bf16_t *>(b), numel);
+ return add_generic(reinterpret_cast<llaisys::bf16_t *>(c), reinterpret_cast<const llaisys::bf16_t *>(a),
+ reinterpret_cast<const llaisys::bf16_t *>(b), elem_count);
case LLAISYS_DTYPE_F16:
- return add_(reinterpret_cast<llaisys::fp16_t *>(c), reinterpret_cast<const llaisys::fp16_t *>(a),
- reinterpret_cast<const llaisys::fp16_t *>(b), numel);
+ return add_generic(reinterpret_cast<llaisys::fp16_t *>(c), reinterpret_cast<const llaisys::fp16_t *>(a),
+ reinterpret_cast<const llaisys::fp16_t *>(b), elem_count);
default:
EXCEPTION_UNSUPPORTED_DATATYPE(type);
}
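The guarded `__builtin_cpu_supports` dispatch used above reduces to a small portable pattern. This is a hedged sketch of that pattern only (a scalar fallback stands in for the AVX2 body so it compiles anywhere):

```cpp
#include <cassert>
#include <cstddef>

// Always-available scalar kernel: the fallback the dispatcher returns to.
void add_f32_scalar(float *c, const float *a, const float *b, std::ptrdiff_t n) {
    for (std::ptrdiff_t i = 0; i < n; ++i) {
        c[i] = a[i] + b[i];
    }
}

// Runtime dispatch: probe CPU features and pick the widest kernel available.
// On non-x86 targets the #if block compiles away and only the scalar path
// remains, matching the structure of add() in add_cpu.cpp.
void add_f32_dispatch(float *c, const float *a, const float *b, std::ptrdiff_t n) {
#if (defined(__x86_64__) || defined(__i386__)) && (defined(__GNUC__) || defined(__clang__))
    if (__builtin_cpu_supports("avx2")) {
        // A kernel built with __attribute__((target("avx2"))) would be
        // called here; this sketch falls through to the scalar path.
    }
#endif
    add_f32_scalar(c, a, b, n);
}
```

Keeping the AVX2 body in a separate `target("avx2")` function is what lets the translation unit be compiled without `-mavx2` while still emitting AVX2 code for the fast path.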
diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu
new file mode 100644
index 00000000..8485030e
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.cu
@@ -0,0 +1,47 @@
+#include "add_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+template <typename T>
+__global__ void add_kernel(T *c, const T *a, const T *b, size_t numel) {
+ const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x);
+ if (idx >= numel) {
+ return;
+ }
+
+ const float lhs = llaisys::device::nvidia::cuda_utils::toFloat(a[idx]);
+ const float rhs = llaisys::device::nvidia::cuda_utils::toFloat(b[idx]);
+ c[idx] = llaisys::device::nvidia::cuda_utils::fromFloat(lhs + rhs);
+}
+
+template <typename T>
+void add_impl(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) {
+ constexpr int threads = 256;
+ add_kernel<<<(numel + threads - 1) / threads, threads>>>(
+ reinterpret_cast<T *>(c),
+ reinterpret_cast<const T *>(a),
+ reinterpret_cast<const T *>(b),
+ numel);
+ LLAISYS_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace
+
+namespace llaisys::ops::nvidia {
+
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t dtype, size_t numel) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ return add_impl<float>(c, a, b, numel);
+ case LLAISYS_DTYPE_F16:
+ return add_impl<half>(c, a, b, numel);
+ case LLAISYS_DTYPE_BF16:
+ return add_impl<__nv_bfloat16>(c, a, b, numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+
+} // namespace llaisys::ops::nvidia
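All launches in this backend size their grids by ceil division over the element count. A tiny host-side helper makes the rounding explicit (hypothetical name; the `.cu` files above inline the expression at each launch site):

```cpp
#include <cassert>
#include <cstddef>

// One thread per element, rounded up to a whole number of blocks.
// Equivalent to the inline (numel + threads - 1) / threads at each
// kernel launch.
constexpr std::size_t grid_blocks(std::size_t numel, std::size_t threads) {
    return (numel + threads - 1) / threads;
}
```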
diff --git a/src/ops/add/nvidia/add_nvidia.cuh b/src/ops/add/nvidia/add_nvidia.cuh
new file mode 100644
index 00000000..cd7e4584
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.cuh
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t dtype, size_t numel);
+}
diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp
index a057330d..6c973fcc 100644
--- a/src/ops/add/op.cpp
+++ b/src/ops/add/op.cpp
@@ -4,6 +4,7 @@
#include "../../utils.hpp"
#include "cpu/add_cpu.hpp"
+#include "nvidia/add_nvidia.cuh"
namespace llaisys::ops {
void add(tensor_t c, tensor_t a, tensor_t b) {
@@ -25,8 +26,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) {
return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#ifdef ENABLE_NVIDIA_API
case LLAISYS_DEVICE_NVIDIA:
- TO_BE_IMPLEMENTED();
- return;
+ return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#endif
default:
EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 00000000..c29bc532
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,142 @@
+#include "argmax_cpu.hpp"
+
+#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#ifdef __C
+#pragma push_macro("__C")
+#undef __C
+#define LLAISYS_RESTORE_C_MACRO
+#endif
+#include <immintrin.h>
+#ifdef LLAISYS_RESTORE_C_MACRO
+#pragma pop_macro("__C")
+#undef LLAISYS_RESTORE_C_MACRO
+#endif
+#endif
+
+#include "../../../utils.hpp"
+
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+
+namespace {
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+inline bool has_avx2() {
+ return __builtin_cpu_supports("avx2");
+}
+
+__attribute__((target("avx2")))
+void argmax_f32_avx2(size_t *max_idx, float *max_val, const float *vals, std::ptrdiff_t numel) {
+ if (numel <= 0) {
+ *max_idx = 0;
+ *max_val = 0.0f;
+ return;
+ }
+
+ std::ptrdiff_t i = 0;
+ float best_val = vals[0];
+ size_t best_idx = 0;
+
+ if (numel >= 8) {
+ __m256 max_vals = _mm256_loadu_ps(vals);
+ __m256i max_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ __m256i cur_indices = _mm256_setr_epi32(8, 9, 10, 11, 12, 13, 14, 15);
+ const __m256i step = _mm256_set1_epi32(8);
+ i = 8;
+
+ for (; i + 8 <= numel; i += 8) {
+ const __m256 cur_vals = _mm256_loadu_ps(vals + i);
+ const __m256 mask = _mm256_cmp_ps(cur_vals, max_vals, _CMP_GT_OQ);
+ max_vals = _mm256_blendv_ps(max_vals, cur_vals, mask);
+ max_indices = _mm256_blendv_epi8(max_indices, cur_indices, _mm256_castps_si256(mask));
+ cur_indices = _mm256_add_epi32(cur_indices, step);
+ }
+
+ alignas(32) float lane_vals[8];
+ alignas(32) int lane_indices[8];
+ _mm256_store_ps(lane_vals, max_vals);
+ _mm256_store_si256(reinterpret_cast<__m256i *>(lane_indices), max_indices);
+ for (int lane = 0; lane < 8; ++lane) {
+ if (lane_vals[lane] > best_val) {
+ best_val = lane_vals[lane];
+ best_idx = static_cast<size_t>(lane_indices[lane]);
+ }
+ }
+ }
+
+ for (; i < numel; ++i) {
+ if (vals[i] > best_val) {
+ best_val = vals[i];
+ best_idx = static_cast<size_t>(i);
+ }
+ }
+
+ *max_idx = best_idx;
+ *max_val = best_val;
+}
+#endif
+
+template <typename T>
+void argmax_generic(size_t *max_idx, T *max_val, const T *vals, std::ptrdiff_t numel){
+ size_t max_index = 0;
+ float max_value = llaisys::utils::cast<float>(vals[0]);
+
+ std::ptrdiff_t i = 1;
+ for (; i + 3 < numel; i += 4) {
+ const float v0 = llaisys::utils::cast<float>(vals[i + 0]);
+ const float v1 = llaisys::utils::cast<float>(vals[i + 1]);
+ const float v2 = llaisys::utils::cast<float>(vals[i + 2]);
+ const float v3 = llaisys::utils::cast<float>(vals[i + 3]);
+ if (v0 > max_value) {
+ max_value = v0;
+ max_index = static_cast<size_t>(i + 0);
+ }
+ if (v1 > max_value) {
+ max_value = v1;
+ max_index = static_cast<size_t>(i + 1);
+ }
+ if (v2 > max_value) {
+ max_value = v2;
+ max_index = static_cast<size_t>(i + 2);
+ }
+ if (v3 > max_value) {
+ max_value = v3;
+ max_index = static_cast<size_t>(i + 3);
+ }
+ }
+
+ for (; i < numel; ++i) {
+ const float current_value = llaisys::utils::cast<float>(vals[i]);
+ if (current_value > max_value) {
+ max_value = current_value;
+ max_index = static_cast<size_t>(i);
+ }
+ }
+
+ *max_idx = max_index;
+ *max_val = llaisys::utils::cast<T>(max_value);
+}
+
+} // namespace
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel){
+ const auto elem_count = static_cast<std::ptrdiff_t>(numel);
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+ if (has_avx2()) {
+ return argmax_f32_avx2(reinterpret_cast<size_t *>(max_idx), reinterpret_cast<float *>(max_val), reinterpret_cast<const float *>(vals), elem_count);
+ }
+#endif
+ return argmax_generic(reinterpret_cast<size_t *>(max_idx), reinterpret_cast<float *>(max_val), reinterpret_cast<const float *>(vals), elem_count);
+ case LLAISYS_DTYPE_BF16:
+ return argmax_generic(reinterpret_cast<size_t *>(max_idx), reinterpret_cast<llaisys::bf16_t *>(max_val), reinterpret_cast<const llaisys::bf16_t *>(vals), elem_count);
+ case LLAISYS_DTYPE_F16:
+ return argmax_generic(reinterpret_cast<size_t *>(max_idx), reinterpret_cast<llaisys::fp16_t *>(max_val), reinterpret_cast<const llaisys::fp16_t *>(vals), elem_count);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
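A scalar reference is useful for validating the AVX2 path. One caveat worth noting: the vectorized reduction compares lane maxima in lane order, so when the same maximum value occurs at indices in different lanes it can report a different (still valid) index than the first-occurrence rule; for distinct logits the two agree. A minimal reference sketch (hypothetical helper, not part of the patch):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// First-occurrence argmax under strict '>': the semantics of the scalar
// tail loops in argmax_generic above.
std::pair<std::size_t, float> argmax_ref(const std::vector<float> &vals) {
    std::size_t idx = 0;
    float best = vals[0];
    for (std::size_t i = 1; i < vals.size(); ++i) {
        if (vals[i] > best) {  // strict comparison keeps the earliest index
            best = vals[i];
            idx = i;
        }
    }
    return {idx, best};
}
```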
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 00000000..5f58c207
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+ void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel);
+}
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu
new file mode 100644
index 00000000..337f24da
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.cu
@@ -0,0 +1,40 @@
+#include "argmax_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+#include "../cpu/argmax_cpu.hpp"
+
+#include <vector>
+
+namespace llaisys::ops::nvidia {
+
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel) {
+ std::vector<std::byte> host_vals(numel * utils::dsize(dtype));
+ std::vector<std::byte> host_idx(sizeof(size_t));
+ std::vector<std::byte> host_max_val(utils::dsize(dtype));
+
+ LLAISYS_CUDA_CHECK(cudaMemcpyAsync(
+ host_vals.data(),
+ vals,
+ host_vals.size(),
+ cudaMemcpyDeviceToHost,
+ current_stream()));
+ LLAISYS_CUDA_CHECK(cudaStreamSynchronize(current_stream()));
+
+ cpu::argmax(host_idx.data(), host_max_val.data(), host_vals.data(), dtype, numel);
+
+ LLAISYS_CUDA_CHECK(cudaMemcpyAsync(
+ max_idx,
+ host_idx.data(),
+ host_idx.size(),
+ cudaMemcpyHostToDevice,
+ current_stream()));
+ LLAISYS_CUDA_CHECK(cudaMemcpyAsync(
+ max_val,
+ host_max_val.data(),
+ host_max_val.size(),
+ cudaMemcpyHostToDevice,
+ current_stream()));
+ LLAISYS_CUDA_CHECK(cudaStreamSynchronize(current_stream()));
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cuh b/src/ops/argmax/nvidia/argmax_nvidia.cuh
new file mode 100644
index 00000000..710a2498
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.cuh
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel);
+}
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d42..c2b5ae5e 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,32 @@
#include "op.hpp"
+
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+#include "nvidia/argmax_nvidia.cuh"
+
namespace llaisys::ops {
void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
- TO_BE_IMPLEMENTED();
+ CHECK_SAME_DEVICE(max_idx, max_val, vals);
+ CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype());
+ ASSERT(vals->isContiguous(), "Argmax: vals tensor must be contiguous.");
+
+ if(vals->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+ }
+
+ llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+ switch (vals->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 00000000..d0010892
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,31 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstdint>
+#include <cstring>
+
+template <typename T>
+void embedding_(T *out, const int64_t *index, const T* weight, size_t numel, size_t embedding_dim){
+#pragma omp parallel for schedule(static) if (numel >= 16)
+ for(std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(numel); ++i){
+ T* out_row_dst = out + i * embedding_dim;
+ const T* weight_row_src = weight + index[i] * embedding_dim;
+ std::memcpy(out_row_dst, weight_row_src, embedding_dim * sizeof(T));
+ }
+}
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t numel, size_t embedding_dim){
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return embedding_(reinterpret_cast<float *>(out), reinterpret_cast<const int64_t *>(index), reinterpret_cast<const float *>(weight), numel, embedding_dim);
+ case LLAISYS_DTYPE_BF16:
+ return embedding_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const int64_t *>(index), reinterpret_cast<const llaisys::bf16_t *>(weight), numel, embedding_dim);
+ case LLAISYS_DTYPE_F16:
+ return embedding_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const int64_t *>(index), reinterpret_cast<const llaisys::fp16_t *>(weight), numel, embedding_dim);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
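The embedding op above is a pure row gather: output row i is weight row `index[i]`. A minimal reference over flattened row-major storage (hypothetical helper, not part of the patch):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// out[i, :] = weight[index[i], :], with weight stored row-major as a
// flattened (vocab, dim) table -- the same layout embedding_ copies
// row-by-row with std::memcpy.
std::vector<float> embedding_ref(const std::vector<int64_t> &index,
                                 const std::vector<float> &weight,
                                 std::size_t dim) {
    std::vector<float> out(index.size() * dim);
    for (std::size_t i = 0; i < index.size(); ++i) {
        const float *row = weight.data() + static_cast<std::size_t>(index[i]) * dim;
        std::copy(row, row + dim, out.begin() + static_cast<std::ptrdiff_t>(i * dim));
    }
    return out;
}
```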
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 00000000..7d4f0d82
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+ void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim);
+}
diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu
new file mode 100644
index 00000000..9dadd983
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_nvidia.cu
@@ -0,0 +1,50 @@
+#include "embedding_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+template <typename T>
+__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, size_t numel, size_t embedding_dim) {
+ const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x);
+ const size_t total = numel * embedding_dim;
+ if (idx >= total) {
+ return;
+ }
+
+ const size_t row = idx / embedding_dim;
+ const size_t col = idx % embedding_dim;
+ out[idx] = weight[static_cast<size_t>(index[row]) * embedding_dim + col];
+}
+
+template <typename T>
+void embedding_impl(std::byte *out, const std::byte *index, const std::byte *weight, size_t numel, size_t embedding_dim) {
+ constexpr int threads = 256;
+ const size_t total = numel * embedding_dim;
+ embedding_kernel<<<(total + threads - 1) / threads, threads>>>(
+ reinterpret_cast<T *>(out),
+ reinterpret_cast<const int64_t *>(index),
+ reinterpret_cast<const T *>(weight),
+ numel,
+ embedding_dim);
+ LLAISYS_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace
+
+namespace llaisys::ops::nvidia {
+
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ return embedding_impl<float>(out, index, weight, numel, embedding_dim);
+ case LLAISYS_DTYPE_F16:
+ return embedding_impl<half>(out, index, weight, numel, embedding_dim);
+ case LLAISYS_DTYPE_BF16:
+ return embedding_impl<__nv_bfloat16>(out, index, weight, numel, embedding_dim);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cuh b/src/ops/embedding/nvidia/embedding_nvidia.cuh
new file mode 100644
index 00000000..cd3fab17
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_nvidia.cuh
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim);
+}
diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp
index 84b9a5d0..2864563c 100644
--- a/src/ops/embedding/op.cpp
+++ b/src/ops/embedding/op.cpp
@@ -1,7 +1,35 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/embedding_cpu.hpp"
+#include "nvidia/embedding_nvidia.cuh"
namespace llaisys::ops {
void embedding(tensor_t out, tensor_t index, tensor_t weight) {
- TO_BE_IMPLEMENTED();
+ CHECK_SAME_DEVICE(out, index, weight);
+ CHECK_SAME_DTYPE(out->dtype(), weight->dtype());
+ ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index tensor must be int64.");
+ ASSERT(index->isContiguous(), "Embedding: index tensor must be contiguous.");
+ size_t embedding_dim = weight->shape().back();
+ ASSERT(out->shape().size() == 2 && out->shape()[1] == embedding_dim,
+ "Embedding: output tensor shape is invalid.");
+ ASSERT(index->shape().size() == 1 && index->shape()[0] == out->shape()[0],
+ "Embedding: index tensor shape is invalid.");
+ if(out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim);
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp
new file mode 100644
index 00000000..86c65e5c
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.cpp
@@ -0,0 +1,504 @@
+#include "linear_cpu.hpp"
+
+#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#ifdef __C
+#pragma push_macro("__C")
+#undef __C
+#define LLAISYS_RESTORE_C_MACRO
+#endif
+#include <immintrin.h>
+#ifdef LLAISYS_RESTORE_C_MACRO
+#pragma pop_macro("__C")
+#undef LLAISYS_RESTORE_C_MACRO
+#endif
+#endif
+
+#ifdef LLAISYS_USE_OPENBLAS
+extern "C" {
+#include <cblas.h>
+}
+#endif
+
+#include "../../../utils.hpp"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace {
+
+constexpr std::ptrdiff_t OUTPUT_TILE = 32;
+constexpr std::ptrdiff_t OUTPUT_UNROLL = 4;
+constexpr std::ptrdiff_t ROW_UNROLL = 2;
+constexpr std::ptrdiff_t REDUCTION_TILE = 256;
+
+#ifdef LLAISYS_USE_OPENBLAS
+inline bool should_use_openblas(const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+ return dimi >= 16 && dimk >= 1024 && dimj >= 1024 && dimk >= dimj;
+}
+#endif
+
+inline bool should_use_gemm_kernel(const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+ return dimi >= 16 && dimk >= 1024 && dimj >= 1024;
+}
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+inline bool has_avx2_fma() {
+ return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma");
+}
+
+__attribute__((target("avx2,fma,sse3")))
+inline float hsum256_ps(__m256 v) {
+ const __m128 low = _mm256_castps256_ps128(v);
+ const __m128 high = _mm256_extractf128_ps(v, 1);
+ __m128 sum = _mm_add_ps(low, high);
+ sum = _mm_hadd_ps(sum, sum);
+ sum = _mm_hadd_ps(sum, sum);
+ return _mm_cvtss_f32(sum);
+}
+
+__attribute__((target("avx2,fma")))
+void linear_f32_avx2_matvec(float *out, const float *in, const float *weight, const float *bias, const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+ const std::ptrdiff_t ksimd = dimk - (dimk % 8);
+
+#pragma omp parallel for collapse(2) schedule(static)
+ for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+ for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) {
+ const float *xrow = in + i * dimk;
+ float *outrow = out + i * dimj;
+ const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj);
+ std::ptrdiff_t j = j0;
+
+ for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) {
+ const float *w0 = weight + (j + 0) * dimk;
+ const float *w1 = weight + (j + 1) * dimk;
+ const float *w2 = weight + (j + 2) * dimk;
+ const float *w3 = weight + (j + 3) * dimk;
+
+ __m256 acc0 = _mm256_setzero_ps();
+ __m256 acc1 = _mm256_setzero_ps();
+ __m256 acc2 = _mm256_setzero_ps();
+ __m256 acc3 = _mm256_setzero_ps();
+
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ const __m256 x = _mm256_loadu_ps(xrow + k);
+ acc0 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w0 + k), acc0);
+ acc1 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w1 + k), acc1);
+ acc2 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w2 + k), acc2);
+ acc3 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w3 + k), acc3);
+ }
+
+ float sum0 = hsum256_ps(acc0) + (bias ? bias[j + 0] : 0.0f);
+ float sum1 = hsum256_ps(acc1) + (bias ? bias[j + 1] : 0.0f);
+ float sum2 = hsum256_ps(acc2) + (bias ? bias[j + 2] : 0.0f);
+ float sum3 = hsum256_ps(acc3) + (bias ? bias[j + 3] : 0.0f);
+
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ const float x = xrow[k];
+ sum0 += x * w0[k];
+ sum1 += x * w1[k];
+ sum2 += x * w2[k];
+ sum3 += x * w3[k];
+ }
+
+ outrow[j + 0] = sum0;
+ outrow[j + 1] = sum1;
+ outrow[j + 2] = sum2;
+ outrow[j + 3] = sum3;
+ }
+
+ for (; j < jend; ++j) {
+ const float *wrow = weight + j * dimk;
+ __m256 acc = _mm256_setzero_ps();
+
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ const __m256 x = _mm256_loadu_ps(xrow + k);
+ acc = _mm256_fmadd_ps(x, _mm256_loadu_ps(wrow + k), acc);
+ }
+
+ float sum = hsum256_ps(acc) + (bias ? bias[j] : 0.0f);
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ sum += xrow[k] * wrow[k];
+ }
+ outrow[j] = sum;
+ }
+ }
+ }
+}
+
+__attribute__((target("avx2,fma")))
+void linear_f32_avx2_gemm(float *out, const float *in, const float *weight, const float *bias, const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+ const std::ptrdiff_t ksimd = dimk - (dimk % 8);
+
+#pragma omp parallel for collapse(2) schedule(static)
+ for (std::ptrdiff_t i0 = 0; i0 < dimi; i0 += ROW_UNROLL) {
+ for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) {
+ const std::ptrdiff_t iend = std::min(i0 + ROW_UNROLL, dimi);
+ const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj);
+
+ if (iend - i0 < ROW_UNROLL) {
+ for (std::ptrdiff_t i = i0; i < iend; ++i) {
+ const float *xrow = in + i * dimk;
+ float *outrow = out + i * dimj;
+ std::ptrdiff_t j = j0;
+ for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) {
+ const float *w0 = weight + (j + 0) * dimk;
+ const float *w1 = weight + (j + 1) * dimk;
+ const float *w2 = weight + (j + 2) * dimk;
+ const float *w3 = weight + (j + 3) * dimk;
+
+ __m256 acc0 = _mm256_setzero_ps();
+ __m256 acc1 = _mm256_setzero_ps();
+ __m256 acc2 = _mm256_setzero_ps();
+ __m256 acc3 = _mm256_setzero_ps();
+
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ const __m256 x = _mm256_loadu_ps(xrow + k);
+ acc0 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w0 + k), acc0);
+ acc1 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w1 + k), acc1);
+ acc2 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w2 + k), acc2);
+ acc3 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w3 + k), acc3);
+ }
+
+ float sum0 = hsum256_ps(acc0) + (bias ? bias[j + 0] : 0.0f);
+ float sum1 = hsum256_ps(acc1) + (bias ? bias[j + 1] : 0.0f);
+ float sum2 = hsum256_ps(acc2) + (bias ? bias[j + 2] : 0.0f);
+ float sum3 = hsum256_ps(acc3) + (bias ? bias[j + 3] : 0.0f);
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ const float x = xrow[k];
+ sum0 += x * w0[k];
+ sum1 += x * w1[k];
+ sum2 += x * w2[k];
+ sum3 += x * w3[k];
+ }
+
+ outrow[j + 0] = sum0;
+ outrow[j + 1] = sum1;
+ outrow[j + 2] = sum2;
+ outrow[j + 3] = sum3;
+ }
+
+ for (; j < jend; ++j) {
+ const float *wrow = weight + j * dimk;
+ __m256 acc = _mm256_setzero_ps();
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ acc = _mm256_fmadd_ps(_mm256_loadu_ps(xrow + k), _mm256_loadu_ps(wrow + k), acc);
+ }
+ float sum = hsum256_ps(acc) + (bias ? bias[j] : 0.0f);
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ sum += xrow[k] * wrow[k];
+ }
+ outrow[j] = sum;
+ }
+ }
+ continue;
+ }
+
+ const float *xrow0 = in + i0 * dimk;
+ const float *xrow1 = in + (i0 + 1) * dimk;
+ float *outrow0 = out + i0 * dimj;
+ float *outrow1 = out + (i0 + 1) * dimj;
+ std::ptrdiff_t j = j0;
+
+ for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) {
+ const float *w0 = weight + (j + 0) * dimk;
+ const float *w1 = weight + (j + 1) * dimk;
+ const float *w2 = weight + (j + 2) * dimk;
+ const float *w3 = weight + (j + 3) * dimk;
+
+ __m256 acc00 = _mm256_setzero_ps();
+ __m256 acc01 = _mm256_setzero_ps();
+ __m256 acc02 = _mm256_setzero_ps();
+ __m256 acc03 = _mm256_setzero_ps();
+ __m256 acc10 = _mm256_setzero_ps();
+ __m256 acc11 = _mm256_setzero_ps();
+ __m256 acc12 = _mm256_setzero_ps();
+ __m256 acc13 = _mm256_setzero_ps();
+
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ const __m256 x0 = _mm256_loadu_ps(xrow0 + k);
+ const __m256 x1 = _mm256_loadu_ps(xrow1 + k);
+ const __m256 wv0 = _mm256_loadu_ps(w0 + k);
+ const __m256 wv1 = _mm256_loadu_ps(w1 + k);
+ const __m256 wv2 = _mm256_loadu_ps(w2 + k);
+ const __m256 wv3 = _mm256_loadu_ps(w3 + k);
+
+ acc00 = _mm256_fmadd_ps(x0, wv0, acc00);
+ acc01 = _mm256_fmadd_ps(x0, wv1, acc01);
+ acc02 = _mm256_fmadd_ps(x0, wv2, acc02);
+ acc03 = _mm256_fmadd_ps(x0, wv3, acc03);
+ acc10 = _mm256_fmadd_ps(x1, wv0, acc10);
+ acc11 = _mm256_fmadd_ps(x1, wv1, acc11);
+ acc12 = _mm256_fmadd_ps(x1, wv2, acc12);
+ acc13 = _mm256_fmadd_ps(x1, wv3, acc13);
+ }
+
+ float sum00 = hsum256_ps(acc00) + (bias ? bias[j + 0] : 0.0f);
+ float sum01 = hsum256_ps(acc01) + (bias ? bias[j + 1] : 0.0f);
+ float sum02 = hsum256_ps(acc02) + (bias ? bias[j + 2] : 0.0f);
+ float sum03 = hsum256_ps(acc03) + (bias ? bias[j + 3] : 0.0f);
+ float sum10 = hsum256_ps(acc10) + (bias ? bias[j + 0] : 0.0f);
+ float sum11 = hsum256_ps(acc11) + (bias ? bias[j + 1] : 0.0f);
+ float sum12 = hsum256_ps(acc12) + (bias ? bias[j + 2] : 0.0f);
+ float sum13 = hsum256_ps(acc13) + (bias ? bias[j + 3] : 0.0f);
+
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ const float x0 = xrow0[k];
+ const float x1 = xrow1[k];
+ const float wv0 = w0[k];
+ const float wv1 = w1[k];
+ const float wv2 = w2[k];
+ const float wv3 = w3[k];
+ sum00 += x0 * wv0;
+ sum01 += x0 * wv1;
+ sum02 += x0 * wv2;
+ sum03 += x0 * wv3;
+ sum10 += x1 * wv0;
+ sum11 += x1 * wv1;
+ sum12 += x1 * wv2;
+ sum13 += x1 * wv3;
+ }
+
+ outrow0[j + 0] = sum00;
+ outrow0[j + 1] = sum01;
+ outrow0[j + 2] = sum02;
+ outrow0[j + 3] = sum03;
+ outrow1[j + 0] = sum10;
+ outrow1[j + 1] = sum11;
+ outrow1[j + 2] = sum12;
+ outrow1[j + 3] = sum13;
+ }
+
+ for (; j < jend; ++j) {
+ const float *wrow = weight + j * dimk;
+ __m256 acc0 = _mm256_setzero_ps();
+ __m256 acc1 = _mm256_setzero_ps();
+ for (std::ptrdiff_t k = 0; k < ksimd; k += 8) {
+ const __m256 wv = _mm256_loadu_ps(wrow + k);
+ acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(xrow0 + k), wv, acc0);
+ acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(xrow1 + k), wv, acc1);
+ }
+ float sum0 = hsum256_ps(acc0) + (bias ? bias[j] : 0.0f);
+ float sum1 = hsum256_ps(acc1) + (bias ? bias[j] : 0.0f);
+ for (std::ptrdiff_t k = ksimd; k < dimk; ++k) {
+ const float wv = wrow[k];
+ sum0 += xrow0[k] * wv;
+ sum1 += xrow1[k] * wv;
+ }
+ outrow0[j] = sum0;
+ outrow1[j] = sum1;
+ }
+ }
+ }
+}
+#endif
+
+#ifdef LLAISYS_USE_OPENBLAS
+// out[i][j] = sum_k in[i][k] * weight[j][k] (+ bias[j]), via OpenBLAS SGEMM with B transposed.
+void linear_f32_openblas(float *out, const float *in, const float *weight, const float *bias, const std::vector<size_t> &shapes) {
+ const blasint dimi = static_cast<blasint>(shapes[0]);
+ const blasint dimk = static_cast<blasint>(shapes[1]);
+ const blasint dimj = static_cast<blasint>(shapes[2]);
+
+ cblas_sgemm(
+ CblasRowMajor,
+ CblasNoTrans,
+ CblasTrans,
+ dimi,
+ dimj,
+ dimk,
+ 1.0f,
+ in,
+ dimk,
+ weight,
+ dimk,
+ 0.0f,
+ out,
+ dimj);
+
+ if (bias == nullptr) {
+ return;
+ }
+
+ const std::ptrdiff_t dimi_p = static_cast<std::ptrdiff_t>(dimi);
+ const std::ptrdiff_t dimj_p = static_cast<std::ptrdiff_t>(dimj);
+#pragma omp parallel for schedule(static)
+ for (std::ptrdiff_t i = 0; i < dimi_p; ++i) {
+ float *outrow = out + i * dimj_p;
+ for (std::ptrdiff_t j0 = 0; j0 < dimj_p; j0 += OUTPUT_TILE) {
+ const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj_p);
+#pragma omp simd
+ for (std::ptrdiff_t j = j0; j < jend; ++j) {
+ outrow[j] += bias[j];
+ }
+ }
+ }
+}
+#endif
+
+// Portable f32 fallback: output tiling plus omp simd reductions over k; no ISA-specific intrinsics.
+void linear_f32(float *out, const float *in, const float *weight, const float *bias, const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+
+#pragma omp parallel for collapse(2) schedule(static)
+ for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+ for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) {
+ const float *xrow = in + i * dimk;
+ float *outrow = out + i * dimj;
+ const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj);
+ std::ptrdiff_t j = j0;
+
+ for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) {
+ const float *w0 = weight + (j + 0) * dimk;
+ const float *w1 = weight + (j + 1) * dimk;
+ const float *w2 = weight + (j + 2) * dimk;
+ const float *w3 = weight + (j + 3) * dimk;
+
+ float acc0 = bias ? bias[j + 0] : 0.0f;
+ float acc1 = bias ? bias[j + 1] : 0.0f;
+ float acc2 = bias ? bias[j + 2] : 0.0f;
+ float acc3 = bias ? bias[j + 3] : 0.0f;
+
+ for (std::ptrdiff_t k0 = 0; k0 < dimk; k0 += REDUCTION_TILE) {
+ const std::ptrdiff_t kend = std::min(k0 + REDUCTION_TILE, dimk);
+ float part0 = 0.0f;
+ float part1 = 0.0f;
+ float part2 = 0.0f;
+ float part3 = 0.0f;
+
+#pragma omp simd reduction(+ : part0, part1, part2, part3)
+ for (std::ptrdiff_t k = k0; k < kend; ++k) {
+ const float x = xrow[k];
+ part0 += x * w0[k];
+ part1 += x * w1[k];
+ part2 += x * w2[k];
+ part3 += x * w3[k];
+ }
+
+ acc0 += part0;
+ acc1 += part1;
+ acc2 += part2;
+ acc3 += part3;
+ }
+
+ outrow[j + 0] = acc0;
+ outrow[j + 1] = acc1;
+ outrow[j + 2] = acc2;
+ outrow[j + 3] = acc3;
+ }
+
+ for (; j < jend; ++j) {
+ const float *wrow = weight + j * dimk;
+ float acc = bias ? bias[j] : 0.0f;
+
+ for (std::ptrdiff_t k0 = 0; k0 < dimk; k0 += REDUCTION_TILE) {
+ const std::ptrdiff_t kend = std::min(k0 + REDUCTION_TILE, dimk);
+ float part = 0.0f;
+
+#pragma omp simd reduction(+ : part)
+ for (std::ptrdiff_t k = k0; k < kend; ++k) {
+ part += xrow[k] * wrow[k];
+ }
+
+ acc += part;
+ }
+
+ outrow[j] = acc;
+ }
+ }
+ }
+}
+
+// Generic path for non-f32 dtypes: accumulate in f32, cast back to T on store.
+template <typename T>
+void linear_generic(T *out, const T *in, const T *weight, const T *bias, const std::vector<size_t> &shapes) {
+ const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+ const std::ptrdiff_t dimk = static_cast<std::ptrdiff_t>(shapes[1]);
+ const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[2]);
+
+#pragma omp parallel for collapse(2) schedule(static)
+ for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+ for (std::ptrdiff_t j = 0; j < dimj; ++j) {
+ float sum = 0.0f;
+ for (std::ptrdiff_t k = 0; k < dimk; ++k) {
+ sum += llaisys::utils::cast<float>(in[i * dimk + k]) * llaisys::utils::cast<float>(weight[j * dimk + k]);
+ }
+ if (bias != nullptr) {
+ sum += llaisys::utils::cast<float>(bias[j]);
+ }
+ out[i * dimj + j] = llaisys::utils::cast<T>(sum);
+ }
+ }
+}
+
+} // namespace
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, std::vector<size_t> shapes) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+#ifdef LLAISYS_USE_OPENBLAS
+ if (should_use_openblas(shapes)) {
+ return linear_f32_openblas(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(weight),
+ bias ? reinterpret_cast<const float *>(bias) : nullptr,
+ shapes);
+ }
+#endif
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+ if (has_avx2_fma()) {
+ if (should_use_gemm_kernel(shapes)) {
+ return linear_f32_avx2_gemm(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(weight),
+ bias ? reinterpret_cast<const float *>(bias) : nullptr,
+ shapes);
+ }
+ return linear_f32_avx2_matvec(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(weight),
+ bias ? reinterpret_cast<const float *>(bias) : nullptr,
+ shapes);
+ }
+#endif
+ return linear_f32(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(weight),
+ bias ? reinterpret_cast<const float *>(bias) : nullptr,
+ shapes);
+ case LLAISYS_DTYPE_BF16:
+ return linear_generic(
+ reinterpret_cast<llaisys::bf16_t *>(out),
+ reinterpret_cast<const llaisys::bf16_t *>(in),
+ reinterpret_cast<const llaisys::bf16_t *>(weight),
+ bias ? reinterpret_cast<const llaisys::bf16_t *>(bias) : nullptr,
+ shapes);
+ case LLAISYS_DTYPE_F16:
+ return linear_generic(
+ reinterpret_cast<llaisys::fp16_t *>(out),
+ reinterpret_cast<const llaisys::fp16_t *>(in),
+ reinterpret_cast<const llaisys::fp16_t *>(weight),
+ bias ? reinterpret_cast<const llaisys::fp16_t *>(bias) : nullptr,
+ shapes);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 00000000..a9a98ed0
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::cpu {
+ void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t dtype, std::vector<size_t> shapes);
+}
diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu
new file mode 100644
index 00000000..154d9c03
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_nvidia.cu
@@ -0,0 +1,87 @@
+#include "linear_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+template <typename T>
+__global__ void add_bias_kernel(T *out, const T *bias, size_t dimi, size_t dimj) {
+ const size_t idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) + static_cast<size_t>(threadIdx.x);
+ const size_t total = dimi * dimj;
+ if (idx >= total) {
+ return;
+ }
+
+ const size_t col = idx % dimj;
+ const float out_v = llaisys::device::nvidia::cuda_utils::toFloat(out[idx]);
+ const float bias_v = llaisys::device::nvidia::cuda_utils::toFloat