diff --git a/README.md b/README.md index 456067c8..716e1ebf 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Welcome to LLAISYS +> Implementation summary, experiment results, and reproduction steps for the completed extension tasks are documented in [doc/implementation_report.md](doc/implementation_report.md). +

English中文 diff --git a/doc/implementation_report.md b/doc/implementation_report.md new file mode 100644 index 00000000..118018cd --- /dev/null +++ b/doc/implementation_report.md @@ -0,0 +1,495 @@ +# LLAISYS 扩展实现与实验报告 + +## 1. 报告说明 + +本报告对应仓库扩展任务的当前实现状态,重点记录以下内容: + +- 已完成或部分完成的项目及其对应实现 +- 关键功能与当前边界 +- 主要实验结果 +- 复现方式 + +本次工作围绕以下方向展开: + +- 项目 #1:CPU 优化 +- 项目 #2:Nvidia CUDA 后端 +- 项目 #3:聊天机器人与采样生成 +- 项目 #4:多用户推理服务 +- 项目 #5:分布式推理中的张量并行 +- 项目 #6:新增模型类型支持入口 + +## 2. 任务完成情况概览 + +| 项目 | 当前状态 | 说明 | +| --- | --- | --- | +| #1 CPU 优化 | 已完成 | CPU 主推理路径上的核心算子均已做专门优化 | +| #2 CUDA 集成 | 已完成一类平台 | 已完成 Nvidia 平台;未继续实现其他 CUDA/CUDA-ish 平台 | +| #3 聊天机器人 | 已完成 | 采样、流式输出、CLI、会话管理均已实现 | +| #4 多用户推理服务 | 已完成可运行版本 | 已实现请求池、调度线程、worker 池、会话复用;尚未实现真正的后端 batched decode | +| #5 分布式推理 | 已完成可运行版本 | 已实现 Nvidia/Qwen2 张量并行;当前切分策略对 `tp_size` 有约束 | +| #6 新模型支持 | 已完成基础支持入口 | 已支持 `qwen2`、`llama`、`mistral` 的模型识别与创建入口 | + +## 3. 各项目实现情况 + +### 3.1 项目 #1:Optimize LLAISYS for CPU + +#### 3.1.1 算子优化 + +CPU 路径上已经完成一轮专门优化的算子包括: + +- `add` +- `argmax` +- `embedding` +- `linear` +- `rearrange` +- `rms_norm` +- `rope` +- `self_attention` +- `swiglu` +- `sample` + +其中,`linear` 是重点优化对象,涉及的主要工作包括: + +- OpenMP 并行 +- 多级分块 +- AVX2/FMA 快路径 +- 对部分大规模 `f32 GEMM` 接入 OpenBLAS +- 根据矩阵形状启发式选择内核,而非无条件走 BLAS + +相关实现分布在以下路径: + +- `src/ops/linear/cpu/linear_cpu.cpp` +- `src/ops/add/cpu/add_cpu.cpp` +- `src/ops/argmax/cpu/argmax_cpu.cpp` +- `src/ops/embedding/cpu/embedding_cpu.cpp` +- `src/ops/rearrange/cpu/rearrange_cpu.cpp` +- `src/ops/rms_norm/cpu/rms_norm_cpu.cpp` +- `src/ops/rope/cpu/rope_cpu.cpp` +- `src/ops/self_attention/cpu/self_attention_cpu.cpp` +- `src/ops/swiglu/cpu/swiglu_cpu.cpp` +- `src/ops/sample/cpu/sample_cpu.cpp` + +#### 3.1.2 模型推理路径优化 + +在 Qwen2 推理路径中,还增加了 buffer 复用机制,以降低频繁创建临时张量带来的额外开销。 + +已完成的优化包括: + +- decode 路径 buffer 复用 +- prefill 路径 scratch buffer 复用 + +相关实现位于: + +- `src/models/qwen2/qwen2.cpp` + +### 3.2 项目 #2:Integrate CUDA into LLAISYS + +#### 3.2.1 Nvidia Runtime API + +已完成 Nvidia 平台 Runtime API 
的实现,并打通了构建流程。 + +相关文件包括: + +- `xmake/nvidia.lua` +- `xmake.lua` +- `src/device/nvidia/cuda_utils.cuh` +- `src/device/nvidia/nvidia_runtime_api.cu` + +同时修复了多设备 Runtime 在错误 device 上创建 CUDA stream 的问题,保证多 GPU 环境下每个 Runtime 使用与自身 device 对应的 stream。 + +相关实现位于: + +- `src/core/runtime/runtime.cpp` + +#### 3.2.2 Nvidia 算子实现 + +CUDA 算子没有集中放在单一 `.cu` 文件中,而是按照算子目录拆分实现。当前已完成的 Nvidia 算子包括: + +- `src/ops/add/nvidia` +- `src/ops/argmax/nvidia` +- `src/ops/embedding/nvidia` +- `src/ops/linear/nvidia` +- `src/ops/rearrange/nvidia` +- `src/ops/rms_norm/nvidia` +- `src/ops/rope/nvidia` +- `src/ops/sample/nvidia` +- `src/ops/self_attention/nvidia` +- `src/ops/swiglu/nvidia` +- `src/ops/nvidia/nvidia_common.cuh` + +#### 3.2.3 CUDA 推理路径 + +Qwen2 的 CUDA 推理路径已经打通,能够在 Nvidia 设备上完成真实模型推理与采样生成。 + +相关路径: + +- `src/models/qwen2/qwen2.cpp` +- `python/llaisys/models/qwen2.py` + +当前未继续实现第二个 CUDA/CUDA-ish 平台,因此项目 #2 的完成范围限定为 Nvidia。 + +### 3.3 项目 #3:Build an AI chatbot + +#### 3.3.1 随机采样 + +新增 `sample` 算子,并接入模型生成路径。当前支持: + +- `temperature` +- `top-k` +- `top-p` +- 当 `temperature <= 0` 或 `top_k == 1` 时退化为 argmax + +相关路径: + +- `src/ops/sample/op.cpp` +- `src/ops/sample/cpu/sample_cpu.cpp` +- `src/ops/sample/nvidia/sample_nvidia.cu` +- `python/llaisys/ops.py` + +#### 3.3.2 聊天服务与 CLI + +已实现 OpenAI 风格的聊天接口: + +- `POST /v1/chat/completions` + +并支持: + +- SSE 流式输出 +- 命令行交互式聊天 +- 非流式请求 + +相关路径: + +- `python/llaisys/chat_server.py` +- `python/llaisys/chat_cli.py` + +#### 3.3.3 会话管理 + +当前已经实现基础会话管理能力,包括: + +- 会话创建、查询、更新、删除 +- 会话历史查看 +- 重生成上一次助手回复 +- 编辑历史用户消息并重新生成 +- 基于 `truncate` 的 KV 状态回退 + +服务端接口包括: + +- `GET /v1/sessions` +- `POST /v1/sessions` +- `GET /v1/sessions/{session_id}` +- `PUT /v1/sessions/{session_id}` +- `DELETE /v1/sessions/{session_id}` + +CLI 支持: + +- `/new` +- `/list` +- `/switch` +- `/history` +- `/regen` +- `/edit` +- `/delete` +- `/session` +- `/help` + +### 3.4 项目 #4:Multi-user Inference Service + +#### 3.4.1 服务层调度 + +当前多用户服务采用三层结构: + +- HTTP 层:参数校验、建请求对象、返回响应 +- 调度层:后台线程从请求池取出请求并分发 +- 
执行层:worker 持有模型实例并实际执行生成 + +已实现的能力包括: + +- 请求池 +- 调度线程 +- worker 池 +- 微批次出队 +- 迭代级连续调度 +- 会话与 worker 亲和 +- 同一会话前缀复用 +- 同一 `session_id` 并发保护 + +相关路径: + +- `python/llaisys/chat_server.py` +- `test/test_chat_api.py` +- `test/chat_test_utils.py` + +#### 3.4.2 当前边界 + +项目 #4 当前已经具备“服务层 continuous batching”的核心形态,但尚未实现以下能力: + +- 单个后端模型对象内的真正 batched decode +- batched KV-cache +- 多请求单轮合并后的统一 batch forward +- batched matmul 路径 + +也就是说,当前调度器会以迭代级粒度推进请求,但单个 worker 一次仍然只执行一个请求的一步生成。 + +### 3.5 项目 #5:Distributed Inference + +#### 3.5.1 张量并行实现 + +当前已经实现 Nvidia 平台上的 Qwen2 张量并行版本,使用 `torch.distributed` 和 NCCL。 + +模型入口位于: + +- `python/llaisys/models/tensor_parallel.py` +- `python/llaisys/models/__init__.py` + +支持的能力包括: + +- `reset` +- `truncate` +- `generate_next` +- `generate` +- `stream_generate` + +服务端已支持以下参数: + +- `--tp-size` +- `--tp-device-ids` + +#### 3.5.2 当前切分策略与限制 + +当前实现采用均匀切分策略: + +- `Q heads` 均匀切分 +- `KV heads` 均匀切分 +- `MLP intermediate` 均匀切分 + +因此,`tp_size` 需要同时满足这些相关维度的均匀切分要求。对当前使用的 `DeepSeek-R1-Distill-Qwen-1.5B` 而言: + +- `num_attention_heads = 12` +- `num_key_value_heads = 2` +- `intermediate_size = 8960` + +在当前实现下,允许的 `tp_size` 为: + +- `1` +- `2` + +因此,两卡张量并行可以正常工作,四卡在当前实现下不支持。 + +#### 3.5.3 当前边界 + +当前项目 #5 尚不包括: + +- MPI/CPU 分布式推理 +- 更复杂的 KV 复制式张量并行 +- 针对低 `num_key_value_heads` 模型的更高卡数切分策略 + +### 3.6 项目 #6:Support New Models + +当前已完成统一模型创建入口,并支持自动识别 `config.json` 中的 `model_type`。 + +目前支持的模型类型包括: + +- `qwen2` +- `llama` +- `mistral` + +相关路径: + +- `python/llaisys/models/__init__.py` + +该部分工作的重点是先统一接入路径与创建流程。当前真实完整验证仍以 Qwen2 为主。 + +## 4. 
关键实验结果 + +### 4.1 CPU `linear` 初始基线 + +在最早的 CPU `linear` profile 中,观测到如下基线: + +- 矩阵形状:`(64, 512) x (512, 512)` +- PyTorch:约 `0.269 ms` +- LLAISYS:约 `25.427 ms` + +### 4.2 CPU `linear` 优化后结果 + +在 CPU 优化后,小规模 `linear` 的结果达到过: + +- 矩阵形状:`(64, 512) x (512, 512)` +- LLAISYS:约 `0.0568 ms` + +### 4.3 接近真实模型形状的 `linear` + +对更接近真实模型的矩阵形状,测得: + +- `out=(198, 8960), x=(198, 1536), w=(8960, 1536)` + - Torch:`4.88102 ms` + - LLAISYS:`10.06528 ms` + +- `out=(198, 1536), x=(198, 8960), w=(1536, 8960)` + - Torch:`1.96279 ms` + - LLAISYS:`4.23397 ms` + +### 4.4 CPU 真实模型推理 + +在 `DeepSeek-R1-Distill-Qwen-1.5B` 上,短生成曾测得: + +- 生成 8 个 token + - `short_generate_s=0.600684` + - `short_tok_per_s=13.318142` + +官方回归测试的一次结果: + +- Hugging Face:约 `0.71 s` +- LLAISYS CPU:约 `0.72 s` +- token 序列:完全一致 + +### 4.5 CUDA 真实模型推理 + +单卡 Nvidia 回归测试的一次结果: + +- Hugging Face:约 `0.79 s` +- LLAISYS Nvidia:约 `0.14 s` +- token 序列:完全一致 + +### 4.6 两卡张量并行 + +两卡张量并行 smoke test 已通过,标准 prompt 的生成 token 与期望结果一致。 + +## 5. 复现方式 + +### 5.1 构建 + +启用 CPU BLAS 与 Nvidia: + +```bash +xmake f --cpu-blas=y --openblas-prefix="$CONDA_PREFIX" --nv-gpu=y -cv +xmake +xmake install +``` + +若不启用 BLAS: + +```bash +xmake f --nv-gpu=y -cv +xmake +xmake install +``` + +### 5.2 CPU 回归测试 + +```bash +conda run -n llaisys env PYTHONPATH=python:test python test/test_runtime.py --device cpu +conda run -n llaisys env PYTHONPATH=python:test python test/ops/linear.py --device cpu +conda run -n llaisys env PYTHONPATH=python:test python test/ops/sample.py --device cpu +conda run -n llaisys env PYTHONPATH=python:test python test/test_generate_sampling.py --device cpu --model [模型目录] +conda run -n llaisys env PYTHONPATH=python:test python test/test_infer.py --device cpu --model [模型目录] --test --max_steps 8 +``` + +### 5.3 Nvidia 回归测试 + +```bash +conda run -n llaisys env PYTHONPATH=python:test python test/test_runtime.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/add.py --device nvidia +conda run -n llaisys env 
PYTHONPATH=python:test python test/ops/argmax.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/embedding.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/linear.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/rms_norm.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/rope.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/swiglu.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/ops/self_attention.py --device nvidia +conda run -n llaisys env PYTHONPATH=python:test python test/test_generate_sampling.py --device nvidia --model [模型目录] +conda run -n llaisys env PYTHONPATH=python:test python test/test_infer.py --device nvidia --model [模型目录] --test --max_steps 8 +``` + +### 5.4 聊天与会话管理测试 + +```bash +conda run -n llaisys env PYTHONPATH=python:test python test/test_chat_api.py +conda run -n llaisys env PYTHONPATH=python:test python test/test_chat_cli.py +``` + +### 5.5 张量并行测试 + +```bash +conda run -n llaisys env PYTHONPATH=python:test python test/test_tensor_parallel.py --model [模型目录] --tp-size 2 --max-steps 8 +``` + +### 5.6 启动单卡 CUDA 聊天服务 + +```bash +PYTHONPATH=python python -m llaisys.chat_server \ + --model-path [模型目录] \ + --device nvidia \ + --device-id 0 \ + --host 127.0.0.1 \ + --port 8000 \ + --num-workers 1 \ + --max-batch-size 1 \ + --batch-wait-ms 0 +``` + +### 5.7 启动两卡张量并行聊天服务 + +```bash +PYTHONPATH=python python -m llaisys.chat_server \ + --model-path [模型目录] \ + --device nvidia \ + --device-id 0 \ + --tp-size 2 \ + --tp-device-ids 0,1 \ + --host 127.0.0.1 \ + --port 8000 \ + --num-workers 1 \ + --max-batch-size 2 \ + --batch-wait-ms 5 +``` + +### 5.8 交互式 CLI + +```bash +PYTHONPATH=python python -m llaisys.chat_cli \ + --url http://127.0.0.1:8000 \ + --model llaisys-qwen2 \ + --max-tokens 512 +``` + +### 5.9 curl 示例 + +```bash +curl -N 
http://127.0.0.1:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "llaisys-qwen2", + "session_id": "demo", + "messages": [{"role": "user", "content": "请解释操作系统的页表机制"}], + "max_tokens": 512, + "temperature": 0.8, + "top_p": 0.8, + "top_k": 50, + "stream": true + }' +``` + +服务统计接口: + +```bash +curl http://127.0.0.1:8000/v1/service/stats +``` + +## 6. 当前边界说明 + +当前仓库已经从基础的 CPU/Qwen2 推理项目扩展为: + +- CPU 优化版本 +- Nvidia CUDA 版本 +- 支持采样、流式输出和会话管理的聊天服务 +- 支持多用户请求池与服务层连续调度的推理服务 +- 支持两卡张量并行的 Nvidia 推理版本 +- 支持多模型类型统一创建入口的版本 + +尚未纳入当前完成范围的内容主要包括: + +- 第二个 CUDA/CUDA-ish 平台 +- 真正后端层面的 batched decode / batched KV-cache +- 更复杂的 KV 复制式张量并行 +- MPI/CPU 分布式推理 diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d..37936d04 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -14,8 +14,8 @@ __C { struct LlaisysQwen2Weights { llaisysTensor_t in_embed; llaisysTensor_t out_embed; - llaisysTensor_t out_norm_w; // a.k.a. model.norm.weight - llaisysTensor_t *attn_norm_w; // a.k.a. input_layernorm.weight + llaisysTensor_t out_norm_w; + llaisysTensor_t *attn_norm_w; llaisysTensor_t *attn_q_w; llaisysTensor_t *attn_q_b; llaisysTensor_t *attn_k_w; @@ -23,7 +23,7 @@ __C { llaisysTensor_t *attn_v_w; llaisysTensor_t *attn_v_b; llaisysTensor_t *attn_o_w; - llaisysTensor_t *mlp_norm_w; // a.k.a. 
post_attention_layernorm.weight + llaisysTensor_t *mlp_norm_w; llaisysTensor_t *mlp_gate_w; llaisysTensor_t *mlp_up_w; llaisysTensor_t *mlp_down_w; @@ -37,6 +37,12 @@ __C { __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); + __export void llaisysQwen2ModelReset(struct LlaisysQwen2Model * model); + + __export void llaisysQwen2ModelTruncate(struct LlaisysQwen2Model * model, size_t position); + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + __export int64_t llaisysQwen2ModelGenerateNext(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken, int top_k, float top_p, float temperature); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h index ddb3be24..18e1b356 100644 --- a/include/llaisys/ops.h +++ b/include/llaisys/ops.h @@ -8,6 +8,7 @@ __C { __export void llaisysArgmax(llaisysTensor_t max_idx, llaisysTensor_t max_val, llaisysTensor_t vals); __export void llaisysEmbedding(llaisysTensor_t out, llaisysTensor_t index, llaisysTensor_t weight); __export void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias); + __export int64_t llaisysSample(llaisysTensor_t logits, int top_k, float top_p, float temperature); __export void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in); __export void llaisysRmsNorm(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, float eps); __export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta); diff --git a/python/llaisys/chat_cli.py b/python/llaisys/chat_cli.py new file mode 100644 index 00000000..f953a780 --- /dev/null +++ b/python/llaisys/chat_cli.py @@ -0,0 +1,414 @@ +import argparse +import json +import sys +import uuid +from typing import Any, Dict, List, Optional, Tuple + +try: + import httpx +except ImportError as exc: # pragma: no cover - optional 
dependency + raise RuntimeError( + "HTTP client support is optional. Install with `pip install ./python[server]`." + ) from exc + + +def _build_payload( + model: str, + messages: List[Dict[str, str]], + max_tokens: int, + top_k: int, + top_p: float, + temperature: float, + stream: bool, + session_id: str, +) -> Dict[str, Any]: + return { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "stream": stream, + "session_id": session_id, + } + + +def _stream_completion(client: httpx.Client, url: str, payload: Dict[str, Any]) -> str: + parts: List[str] = [] + with client.stream("POST", url, json=payload, timeout=None) as response: + response.raise_for_status() + for line in response.iter_lines(): + if not line or not line.startswith("data: "): + continue + data = line[6:] + if data == "[DONE]": + break + event = json.loads(data) + delta = event["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + print(content, end="", flush=True) + parts.append(content) + print() + return "".join(parts) + + +def _non_stream_completion(client: httpx.Client, url: str, payload: Dict[str, Any]) -> str: + response = client.post(url, json=payload, timeout=None) + response.raise_for_status() + content = response.json()["choices"][0]["message"]["content"] + print(content) + return content + + +def _request_json(client: httpx.Client, method: str, url: str, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + response = client.request(method, url, json=payload, timeout=None) + response.raise_for_status() + if not response.content: + return {} + return response.json() + + +def _create_session( + client: httpx.Client, + base_url: str, + session_id: str, + messages: List[Dict[str, str]], +) -> Dict[str, Any]: + payload: Dict[str, Any] = {"session_id": session_id} + if messages: + payload["messages"] = messages + return _request_json(client, "POST", 
f"{base_url}/v1/sessions", payload) + + +def _get_session(client: httpx.Client, base_url: str, session_id: str) -> Dict[str, Any]: + return _request_json(client, "GET", f"{base_url}/v1/sessions/{session_id}") + + +def _list_sessions(client: httpx.Client, base_url: str) -> List[Dict[str, Any]]: + return _request_json(client, "GET", f"{base_url}/v1/sessions").get("data", []) + + +def _delete_session(client: httpx.Client, base_url: str, session_id: str) -> Dict[str, Any]: + return _request_json(client, "DELETE", f"{base_url}/v1/sessions/{session_id}") + + +def _ensure_session( + client: httpx.Client, + base_url: str, + session_id: str, + seed_messages: List[Dict[str, str]], +) -> Tuple[str, List[Dict[str, str]]]: + try: + payload = _get_session(client, base_url, session_id) + return session_id, payload.get("messages", []) + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 404: + raise + _create_session(client, base_url, session_id, seed_messages) + return session_id, list(seed_messages) + + +def _print_history(messages: List[Dict[str, str]]) -> None: + if not messages: + print("(empty session)") + return + for idx, message in enumerate(messages, start=1): + print(f"{idx}. 
{message['role']}: {message['content']}") + + +def _print_sessions(sessions: List[Dict[str, Any]], current_session_id: Optional[str]) -> None: + if not sessions: + print("(no sessions)") + return + for session in sessions: + marker = "*" if session["id"] == current_session_id else " " + print( + f"{marker} {session['id']} " + f"{session['title']} " + f"messages={session['message_count']} " + f"tokens={session['token_count']}" + ) + + +def _print_help() -> None: + print("/help Show this help") + print("/session Show the current session id") + print("/new [session_id] Start a new session") + print("/list List sessions on the server") + print("/switch Switch to another session") + print("/history Show the current conversation") + print("/regen Regenerate the last assistant reply") + print("/edit Edit a past user message and regenerate from there") + print("/delete [session_id] Delete a session") + print("/quit Exit") + + +def _complete( + client: httpx.Client, + endpoint: str, + model: str, + messages: List[Dict[str, str]], + max_tokens: int, + top_k: int, + top_p: float, + temperature: float, + stream: bool, + session_id: str, +) -> str: + payload = _build_payload( + model, + messages, + max_tokens, + top_k, + top_p, + temperature, + stream, + session_id, + ) + return ( + _stream_completion(client, endpoint, payload) + if stream + else _non_stream_completion(client, endpoint, payload) + ) + + +def _handle_command( + command_line: str, + *, + client: httpx.Client, + base_url: str, + endpoint: str, + model: str, + max_tokens: int, + top_k: int, + top_p: float, + temperature: float, + stream: bool, + session_id: str, + messages: List[Dict[str, str]], + initial_messages: List[Dict[str, str]], +) -> Tuple[str, List[Dict[str, str]], bool]: + command, _, rest = command_line.partition(" ") + command = command.lower() + rest = rest.strip() + + if command == "/help": + _print_help() + return session_id, messages, False + + if command == "/session": + print(session_id) + return 
session_id, messages, False + + if command == "/list": + _print_sessions(_list_sessions(client, base_url), session_id) + return session_id, messages, False + + if command == "/history": + _print_history(messages) + return session_id, messages, False + + if command == "/new": + next_session_id = rest or f"session-{uuid.uuid4().hex}" + _create_session(client, base_url, next_session_id, initial_messages) + print(f"Switched to {next_session_id}") + return next_session_id, list(initial_messages), False + + if command == "/switch": + if not rest: + print("Usage: /switch ") + return session_id, messages, False + payload = _get_session(client, base_url, rest) + print(f"Switched to {rest}") + return rest, payload.get("messages", []), False + + if command == "/delete": + target_session_id = rest or session_id + _delete_session(client, base_url, target_session_id) + print(f"Deleted {target_session_id}") + if target_session_id == session_id: + next_session_id = f"session-{uuid.uuid4().hex}" + _create_session(client, base_url, next_session_id, initial_messages) + print(f"Switched to {next_session_id}") + return next_session_id, list(initial_messages), False + return session_id, messages, False + + if command == "/regen": + if not messages or messages[-1]["role"] != "assistant": + print("No assistant reply to regenerate.") + return session_id, messages, False + next_messages = messages[:-1] + reply = _complete( + client, + endpoint, + model, + next_messages, + max_tokens, + top_k, + top_p, + temperature, + stream, + session_id, + ) + next_messages = next_messages + [{"role": "assistant", "content": reply}] + return session_id, next_messages, False + + if command == "/edit": + parts = rest.split(" ", 1) + if len(parts) != 2: + print("Usage: /edit ") + return session_id, messages, False + try: + index = int(parts[0]) + except ValueError: + print("Message index must be an integer.") + return session_id, messages, False + if index < 1 or index > len(messages): + print("Message index 
out of range.") + return session_id, messages, False + if messages[index - 1]["role"] != "user": + print("Only user messages can be edited.") + return session_id, messages, False + + next_messages = [{"role": message["role"], "content": message["content"]} for message in messages[:index]] + next_messages[-1]["content"] = parts[1] + reply = _complete( + client, + endpoint, + model, + next_messages, + max_tokens, + top_k, + top_p, + temperature, + stream, + session_id, + ) + next_messages.append({"role": "assistant", "content": reply}) + return session_id, next_messages, False + + if command in {"/quit", "/exit"}: + return session_id, messages, True + + print("Unknown command. Use /help to list session commands.") + return session_id, messages, False + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--url", default="http://127.0.0.1:8000", type=str) + parser.add_argument("--model", default="llaisys-qwen2", type=str) + parser.add_argument("--prompt", default=None, type=str) + parser.add_argument("--system-prompt", default=None, type=str) + parser.add_argument("--max-tokens", default=128, type=int) + parser.add_argument("--top-k", default=50, type=int) + parser.add_argument("--top-p", default=0.8, type=float) + parser.add_argument("--temperature", default=0.8, type=float) + parser.add_argument("--no-stream", action="store_true") + parser.add_argument("--session-id", default=None, type=str) + parser.add_argument("--list-sessions", action="store_true") + parser.add_argument("--show-session", action="store_true") + parser.add_argument("--create-session", action="store_true") + parser.add_argument("--delete-session", action="store_true") + args = parser.parse_args() + + base_url = args.url.rstrip("/") + endpoint = f"{base_url}/v1/chat/completions" + stream = not args.no_stream + session_id = args.session_id or f"session-{uuid.uuid4().hex}" + initial_messages: List[Dict[str, str]] = [] + if args.system_prompt: + 
initial_messages.append({"role": "system", "content": args.system_prompt}) + + if args.prompt is None and not sys.stdin.isatty() and not any( + [args.list_sessions, args.show_session, args.create_session, args.delete_session] + ): + raise RuntimeError( + "Interactive mode requires a TTY on stdin. " + "Run this command directly inside the activated `llaisys` environment, " + "or pass `--prompt` for one-shot mode." + ) + + with httpx.Client(trust_env=False) as client: + if args.list_sessions: + _print_sessions(_list_sessions(client, base_url), session_id) + return + + if args.create_session: + payload = _create_session(client, base_url, session_id, initial_messages) + print(payload["id"]) + return + + if args.show_session: + payload = _get_session(client, base_url, session_id) + _print_history(payload.get("messages", [])) + return + + if args.delete_session: + payload = _delete_session(client, base_url, session_id) + print(json.dumps(payload, ensure_ascii=False)) + return + + session_id, messages = _ensure_session(client, base_url, session_id, initial_messages) + prompt = args.prompt + + while True: + if prompt is None: + try: + prompt = input("user> ").strip() + except EOFError: + break + except KeyboardInterrupt: + print() + break + + if not prompt: + break + + if prompt.startswith("/"): + session_id, messages, should_exit = _handle_command( + prompt, + client=client, + base_url=base_url, + endpoint=endpoint, + model=args.model, + max_tokens=args.max_tokens, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + stream=stream, + session_id=session_id, + messages=messages, + initial_messages=initial_messages, + ) + if should_exit: + break + if args.prompt is not None: + break + prompt = None + continue + + messages.append({"role": "user", "content": prompt}) + reply = _complete( + client, + endpoint, + args.model, + messages, + args.max_tokens, + args.top_k, + args.top_p, + args.temperature, + stream, + session_id, + ) + messages.append({"role": 
"assistant", "content": reply}) + + if args.prompt is not None: + break + prompt = None + + +if __name__ == "__main__": + main() diff --git a/python/llaisys/chat_server.py b/python/llaisys/chat_server.py new file mode 100644 index 00000000..8867c724 --- /dev/null +++ b/python/llaisys/chat_server.py @@ -0,0 +1,1099 @@ +import argparse +import json +import queue +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, Iterable, List, Optional + +import llaisys +from transformers import AutoTokenizer + +try: + from fastapi import Body, FastAPI, HTTPException + from fastapi.responses import JSONResponse, StreamingResponse +except ImportError as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "FastAPI support is optional. Install with `pip install ./python[server]`." + ) from exc + + +_STREAM_END = object() + + +def _device_from_name(device_name: str) -> llaisys.DeviceType: + if device_name == "cpu": + return llaisys.DeviceType.CPU + if device_name == "nvidia": + return llaisys.DeviceType.NVIDIA + raise ValueError(f"Unsupported device: {device_name}") + + +def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]: + normalized = [] + for message in messages: + if not isinstance(message, dict): + raise ValueError("Each message must be an object") + role = str(message.get("role", "")).strip() + content = message.get("content", "") + if not role: + raise ValueError("Each message must include a role") + if isinstance(content, list): + text_parts = [str(part.get("text", "")) for part in content if isinstance(part, dict)] + content = "".join(text_parts) + normalized.append({"role": role, "content": str(content)}) + if not normalized: + raise ValueError("messages must not be empty") + return normalized + + +def _clone_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]: + return [{"role": message["role"], 
"content": message["content"]} for message in messages] + + +def _build_input_ids(tokenizer, messages: List[Dict[str, Any]]) -> List[int]: + prompt = tokenizer.apply_chat_template( + conversation=_normalize_messages(messages), + add_generation_prompt=True, + tokenize=False, + ) + return list(tokenizer.encode(prompt)) + + +def _finish_reason(end_token: Optional[int], generated_ids: List[int], max_tokens: int) -> str: + if generated_ids and generated_ids[-1] == end_token: + return "stop" + if len(generated_ids) >= max_tokens: + return "length" + return "stop" + + +def _usage(prompt_tokens: int, completion_tokens: int) -> Dict[str, int]: + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + + +def _sse_payload(payload: Dict[str, Any]) -> str: + return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + + +def _resolve_session_id(payload: Dict[str, Any]) -> str: + session_id = payload.get("session_id") + return str(session_id) if session_id else f"session-{uuid.uuid4().hex}" + + +def _session_title(messages: List[Dict[str, str]]) -> str: + for message in messages: + if message["role"] == "user" and message["content"]: + title = message["content"].strip().replace("\n", " ") + return title[:64] if title else "Untitled Session" + return "Untitled Session" + + +def _longest_common_prefix_len(lhs: List[int], rhs: List[int]) -> int: + limit = min(len(lhs), len(rhs)) + idx = 0 + while idx < limit and lhs[idx] == rhs[idx]: + idx += 1 + return idx + + +def _decode_text_delta(tokenizer, generated_ids: List[int], previous_text: str) -> tuple[str, str]: + decoded = tokenizer.decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + if decoded.startswith(previous_text): + return decoded[len(previous_text):], decoded + return decoded, decoded + + +def _model_generate_next( + model: Any, + inputs: List[int], + *, + top_k: int, + top_p: float, + 
temperature: float, + reset_state: bool, +) -> int: + generate_next = getattr(model, "generate_next", None) + if callable(generate_next): + return int( + generate_next( + inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=reset_state, + ) + ) + + if reset_state and hasattr(model, "reset"): + model.reset() + + private_generate_next = getattr(model, "_generate_next", None) + if callable(private_generate_next): + return int( + private_generate_next( + inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + ) + + output_ids = model.generate( + inputs, + max_new_tokens=1, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=reset_state, + ) + return int(output_ids[-1]) + + +@dataclass +class SessionState: + session_id: str + messages: List[Dict[str, str]] + token_ids: List[int] + created_at: float + updated_at: float + worker_id: Optional[int] = None + + +@dataclass +class ReusePlan: + reset_state: bool + generation_inputs: List[int] + reused_tokens: int + truncate_to: Optional[int] + + +@dataclass +class CompletionRequest: + completion_id: str + session_id: str + messages: List[Dict[str, str]] + input_ids: List[int] + max_tokens: int + top_k: int + top_p: float + temperature: float + stream: bool + created: int + request_model: str + result_queue: "queue.Queue[Any]" + enqueued_at: float + batch_id: Optional[int] = None + first_batch_id: Optional[int] = None + generated_ids: List[int] = field(default_factory=list) + last_text: str = "" + cache_reused_tokens: int = 0 + started: bool = False + prompt_inputs: List[int] = field(default_factory=list) + reset_state: bool = True + steps_dispatched: int = 0 + assigned_worker_id: Optional[int] = None + first_dispatched_at: Optional[float] = None + finished_at: Optional[float] = None + + +@dataclass +class WorkerState: + worker_id: int + model: Any + job_queue: "queue.Queue[Optional[CompletionRequest]]" = field(default_factory=queue.Queue) + thread: 
Optional[threading.Thread] = None + busy: bool = False + current_request_id: Optional[str] = None + active_session_id: Optional[str] = None + active_token_ids: List[int] = field(default_factory=list) + needs_reset: bool = True + total_requests: int = 0 + total_tokens_generated: int = 0 + total_steps: int = 0 + last_used_at: float = 0.0 + + +@dataclass +class ServiceState: + tokenizer: Any + model_name: str + workers: List[WorkerState] + max_batch_size: int + batch_wait_s: float + sessions: Dict[str, SessionState] = field(default_factory=dict) + session_inflight: Dict[str, int] = field(default_factory=dict) + pending_requests: Deque[CompletionRequest] = field(default_factory=deque) + condition: threading.Condition = field(default_factory=threading.Condition) + shutdown: bool = False + next_batch_id: int = 1 + scheduled_batches: int = 0 + max_observed_batch_size: int = 0 + completed_requests: int = 0 + failed_requests: int = 0 + total_enqueued: int = 0 + total_generated_tokens: int = 0 + requeued_requests: int = 0 + cache_reuse_hits: int = 0 + total_queue_wait_s: float = 0.0 + total_request_time_s: float = 0.0 + scheduler_thread: Optional[threading.Thread] = None + + +def _session_payload( + state: SessionState, + *, + include_messages: bool = False, + inflight: int = 0, +) -> Dict[str, Any]: + payload = { + "id": state.session_id, + "object": "session", + "title": _session_title(state.messages), + "created": int(state.created_at), + "updated": int(state.updated_at), + "message_count": len(state.messages), + "token_count": len(state.token_ids), + "worker_id": state.worker_id, + "inflight": inflight, + } + if include_messages: + payload["messages"] = _clone_messages(state.messages) + return payload + + +def _prepare_worker_reuse(worker: WorkerState, session_id: str, input_ids: List[int]) -> ReusePlan: + if worker.needs_reset or worker.active_session_id != session_id or not worker.active_token_ids: + return ReusePlan(reset_state=True, generation_inputs=input_ids, 
reused_tokens=0, truncate_to=None) + + prefix_len = _longest_common_prefix_len(worker.active_token_ids, input_ids) + if prefix_len <= 0: + return ReusePlan(reset_state=True, generation_inputs=input_ids, reused_tokens=0, truncate_to=None) + + reusable_prefix = prefix_len + if reusable_prefix >= len(input_ids): + reusable_prefix = len(input_ids) - 1 + + if reusable_prefix <= 0: + return ReusePlan(reset_state=True, generation_inputs=input_ids, reused_tokens=0, truncate_to=None) + + truncate_to = reusable_prefix if reusable_prefix < len(worker.active_token_ids) else None + return ReusePlan( + reset_state=False, + generation_inputs=input_ids[reusable_prefix:], + reused_tokens=reusable_prefix, + truncate_to=truncate_to, + ) + + +def _detach_worker_session_locked(service: ServiceState, worker: WorkerState) -> None: + if worker.active_session_id is None: + worker.active_token_ids = [] + worker.needs_reset = True + return + + state = service.sessions.get(worker.active_session_id) + if state is not None and state.worker_id == worker.worker_id: + state.worker_id = None + worker.active_session_id = None + worker.active_token_ids = [] + worker.needs_reset = True + + +def _clear_session_cache_locked(service: ServiceState, session_id: str) -> None: + state = service.sessions.get(session_id) + if state is None: + return + worker_id = state.worker_id + state.worker_id = None + state.token_ids = [] + if worker_id is None: + return + worker = service.workers[worker_id] + if worker.active_session_id == session_id: + worker.active_session_id = None + worker.active_token_ids = [] + worker.needs_reset = True + + +def _decrement_inflight_locked(service: ServiceState, session_id: str) -> None: + remaining = service.session_inflight.get(session_id, 0) - 1 + if remaining > 0: + service.session_inflight[session_id] = remaining + else: + service.session_inflight.pop(session_id, None) + + +def _commit_session_locked( + service: ServiceState, + worker: WorkerState, + session_id: str, + messages: 
List[Dict[str, str]], + cached_tokens: List[int], +) -> None: + now = time.time() + previous = service.sessions.get(session_id) + created_at = previous.created_at if previous is not None else now + service.sessions[session_id] = SessionState( + session_id=session_id, + messages=_clone_messages(messages), + token_ids=list(cached_tokens), + created_at=created_at, + updated_at=now, + worker_id=worker.worker_id, + ) + worker.active_session_id = session_id + worker.active_token_ids = list(cached_tokens) + worker.needs_reset = False + + +def _select_worker_locked( + service: ServiceState, + request: CompletionRequest, + idle_workers: Dict[int, WorkerState], + pinned_session_ids: set[str], +) -> Optional[WorkerState]: + session = service.sessions.get(request.session_id) + if session is not None and session.worker_id in idle_workers: + return idle_workers[session.worker_id] + + for worker in idle_workers.values(): + if worker.active_session_id == request.session_id and not worker.needs_reset: + return worker + + never_used = [worker for worker in idle_workers.values() if worker.active_session_id is None] + if never_used: + return min(never_used, key=lambda item: item.last_used_at) + + reusable_workers = [ + worker + for worker in idle_workers.values() + if worker.active_session_id not in pinned_session_ids + ] + if reusable_workers: + return min(reusable_workers, key=lambda item: item.last_used_at) + + return min(idle_workers.values(), key=lambda item: item.last_used_at, default=None) + + +def _service_stats(service: ServiceState) -> Dict[str, Any]: + with service.condition: + workers = [ + { + "worker_id": worker.worker_id, + "busy": worker.busy, + "active_session_id": worker.active_session_id, + "cached_tokens": len(worker.active_token_ids), + "needs_reset": worker.needs_reset, + "total_requests": worker.total_requests, + "total_tokens_generated": worker.total_tokens_generated, + "total_steps": worker.total_steps, + } + for worker in service.workers + ] + return { + 
"object": "service.stats", + "model": service.model_name, + "worker_count": len(service.workers), + "max_batch_size": service.max_batch_size, + "batch_wait_ms": int(service.batch_wait_s * 1000), + "queue_depth": len(service.pending_requests), + "active_requests": sum(1 for worker in service.workers if worker.busy), + "session_count": len(service.sessions), + "scheduled_batches": service.scheduled_batches, + "max_observed_batch_size": service.max_observed_batch_size, + "completed_requests": service.completed_requests, + "failed_requests": service.failed_requests, + "total_enqueued": service.total_enqueued, + "total_generated_tokens": service.total_generated_tokens, + "requeued_requests": service.requeued_requests, + "cache_reuse_hits": service.cache_reuse_hits, + "avg_queue_wait_ms": ( + (service.total_queue_wait_s / service.completed_requests) * 1000.0 + if service.completed_requests + else 0.0 + ), + "avg_request_time_ms": ( + (service.total_request_time_s / service.completed_requests) * 1000.0 + if service.completed_requests + else 0.0 + ), + "cached_session_count": sum(1 for session in service.sessions.values() if session.worker_id is not None), + "workers": workers, + } + + +def _scheduler_loop(service: ServiceState) -> None: + while True: + assigned: List[tuple[WorkerState, CompletionRequest]] = [] + with service.condition: + while not service.shutdown and not service.pending_requests: + service.condition.wait() + if service.shutdown: + return + + if len(service.pending_requests) < service.max_batch_size and service.batch_wait_s > 0: + service.condition.wait(timeout=service.batch_wait_s) + if service.shutdown: + return + + candidates: List[CompletionRequest] = [] + while service.pending_requests and len(candidates) < service.max_batch_size: + candidates.append(service.pending_requests.popleft()) + + idle_workers = { + worker.worker_id: worker + for worker in service.workers + if not worker.busy + } + pinned_session_ids = { + pending_request.session_id + for 
pending_request in service.pending_requests + if pending_request.started + } + pinned_session_ids.update( + candidate.session_id + for candidate in candidates + if candidate.started + ) + unassigned: List[CompletionRequest] = [] + if idle_workers: + batch_id = service.next_batch_id + service.next_batch_id += 1 + for request in candidates: + worker = _select_worker_locked(service, request, idle_workers, pinned_session_ids) + if worker is None: + unassigned.append(request) + continue + if worker.active_session_id is not None and worker.active_session_id != request.session_id: + _detach_worker_session_locked(service, worker) + worker.busy = True + worker.current_request_id = request.completion_id + request.batch_id = batch_id + if request.first_batch_id is None: + request.first_batch_id = batch_id + request.assigned_worker_id = worker.worker_id + request.steps_dispatched += 1 + if request.first_dispatched_at is None: + request.first_dispatched_at = time.time() + assigned.append((worker, request)) + idle_workers.pop(worker.worker_id, None) + pinned_session_ids.add(request.session_id) + + if assigned: + service.scheduled_batches += 1 + service.max_observed_batch_size = max( + service.max_observed_batch_size, + len(assigned), + ) + else: + for request in reversed(candidates): + service.pending_requests.appendleft(request) + service.condition.wait(timeout=max(service.batch_wait_s, 0.01)) + continue + + for request in reversed(unassigned): + service.pending_requests.appendleft(request) + + for worker, request in assigned: + worker.job_queue.put(request) + + +def _worker_loop(service: ServiceState, worker: WorkerState) -> None: + while True: + request = worker.job_queue.get() + if request is None: + return + _run_request(service, worker, request) + + +def _run_request(service: ServiceState, worker: WorkerState, request: CompletionRequest) -> None: + try: + emit_role_chunk = False + with service.condition: + batch_id = request.batch_id + worker_id = worker.worker_id + 
truncate_to = None + + if not request.started: + plan = _prepare_worker_reuse(worker, request.session_id, request.input_ids) + request.cache_reused_tokens = plan.reused_tokens + request.prompt_inputs = list(plan.generation_inputs) + request.reset_state = plan.reset_state + request.started = True + request.assigned_worker_id = worker_id + truncate_to = plan.truncate_to + if request.cache_reused_tokens > 0: + service.cache_reuse_hits += 1 + emit_role_chunk = request.stream + + if truncate_to is not None: + worker.model.truncate(truncate_to) + + if emit_role_chunk: + request.result_queue.put( + _sse_payload( + { + "id": request.completion_id, + "object": "chat.completion.chunk", + "created": request.created, + "model": request.request_model, + "session_id": request.session_id, + "worker_id": worker_id, + "batch_id": batch_id, + "cache_reused_tokens": request.cache_reused_tokens, + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], + } + ) + ) + + token_source = request.prompt_inputs if not request.generated_ids else [request.generated_ids[-1]] + next_token = _model_generate_next( + worker.model, + token_source, + top_k=request.top_k, + top_p=request.top_p, + temperature=request.temperature, + reset_state=request.reset_state if not request.generated_ids else False, + ) + request.generated_ids.append(next_token) + text_chunk, request.last_text = _decode_text_delta( + service.tokenizer, + request.generated_ids, + request.last_text, + ) + + finished = ( + next_token == getattr(worker.model.meta, "end_token", None) + or len(request.generated_ids) >= request.max_tokens + ) + assistant_content = service.tokenizer.decode( + request.generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + finish_reason = _finish_reason( + getattr(worker.model.meta, "end_token", None), + request.generated_ids, + request.max_tokens, + ) + + with service.condition: + worker.busy = False + worker.current_request_id = None + 
worker.last_used_at = time.time() + worker.total_steps += 1 + worker.total_tokens_generated += 1 + service.total_generated_tokens += 1 + + if finished: + request.finished_at = time.time() + _commit_session_locked( + service, + worker, + request.session_id, + request.messages + [{"role": "assistant", "content": assistant_content}], + request.input_ids + request.generated_ids, + ) + worker.total_requests += 1 + service.completed_requests += 1 + if request.first_dispatched_at is not None: + service.total_queue_wait_s += max(0.0, request.first_dispatched_at - request.enqueued_at) + service.total_request_time_s += max(0.0, request.finished_at - request.enqueued_at) + _decrement_inflight_locked(service, request.session_id) + else: + worker.active_session_id = request.session_id + worker.active_token_ids = list(request.input_ids + request.generated_ids) + worker.needs_reset = False + service.requeued_requests += 1 + service.pending_requests.append(request) + service.condition.notify_all() + service.condition.notify_all() + + if request.stream and text_chunk: + request.result_queue.put( + _sse_payload( + { + "id": request.completion_id, + "object": "chat.completion.chunk", + "created": request.created, + "model": request.request_model, + "session_id": request.session_id, + "worker_id": worker_id, + "batch_id": batch_id, + "cache_reused_tokens": request.cache_reused_tokens, + "choices": [{"index": 0, "delta": {"content": text_chunk}, "finish_reason": None}], + } + ) + ) + + if not finished: + return + + if request.stream: + request.result_queue.put( + _sse_payload( + { + "id": request.completion_id, + "object": "chat.completion.chunk", + "created": request.created, + "model": request.request_model, + "session_id": request.session_id, + "worker_id": worker_id, + "batch_id": request.first_batch_id, + "last_batch_id": batch_id, + "dispatch_count": request.steps_dispatched, + "cache_reused_tokens": request.cache_reused_tokens, + "queue_wait_ms": ( + max(0.0, 
(request.first_dispatched_at or request.enqueued_at) - request.enqueued_at) * 1000.0 + ), + "total_time_ms": ( + max(0.0, (request.finished_at or time.time()) - request.enqueued_at) * 1000.0 + ), + "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}], + } + ) + ) + request.result_queue.put("data: [DONE]\n\n") + request.result_queue.put(_STREAM_END) + return + + request.result_queue.put( + { + "id": request.completion_id, + "object": "chat.completion", + "created": request.created, + "model": request.request_model, + "session_id": request.session_id, + "worker_id": worker_id, + "batch_id": request.first_batch_id, + "last_batch_id": batch_id, + "dispatch_count": request.steps_dispatched, + "cache_reused_tokens": request.cache_reused_tokens, + "queue_wait_ms": ( + max(0.0, (request.first_dispatched_at or request.enqueued_at) - request.enqueued_at) * 1000.0 + ), + "total_time_ms": ( + max(0.0, (request.finished_at or time.time()) - request.enqueued_at) * 1000.0 + ), + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": assistant_content}, + "finish_reason": finish_reason, + } + ], + "usage": _usage(len(request.input_ids), len(request.generated_ids)), + } + ) + except Exception as exc: # pragma: no cover - defensive path + with service.condition: + _detach_worker_session_locked(service, worker) + worker.busy = False + worker.current_request_id = None + worker.last_used_at = time.time() + service.failed_requests += 1 + _decrement_inflight_locked(service, request.session_id) + service.condition.notify_all() + + if request.stream: + request.result_queue.put( + _sse_payload( + { + "id": request.completion_id, + "object": "chat.completion.chunk", + "created": request.created, + "model": request.request_model, + "session_id": request.session_id, + "worker_id": worker.worker_id, + "batch_id": request.batch_id, + "choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "error"}], + "error": str(exc), + } + ) + ) + 
request.result_queue.put("data: [DONE]\n\n") + request.result_queue.put(_STREAM_END) + return + + request.result_queue.put(exc) + + +def _shutdown_service(service: ServiceState) -> None: + with service.condition: + if service.shutdown: + return + service.shutdown = True + service.condition.notify_all() + + if service.scheduler_thread is not None: + service.scheduler_thread.join(timeout=5) + service.scheduler_thread = None + + for worker in service.workers: + worker.job_queue.put(None) + for worker in service.workers: + if worker.thread is not None: + worker.thread.join(timeout=5) + worker.thread = None + + +def _resolve_runtime_components( + model_path: Optional[str], + device: str, + device_id: int, + tp_size: int, + tp_device_ids: Optional[List[int]], + *, + tokenizer, + model, + model_factory, + model_name: Optional[str], + num_workers: int, +) -> tuple[Any, Callable[[], Any], str]: + if tokenizer is None: + if model_path is None: + raise ValueError("model_path is required when tokenizer is not injected") + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + if model is not None and model_factory is not None: + raise ValueError("Pass either `model` or `model_factory`, not both") + + if model_factory is None: + if model is not None: + if num_workers != 1: + raise ValueError("model_factory is required when injecting a model with num_workers > 1") + model_factory = lambda: model + resolved_model_name = model_name or "llaisys-chat" + return tokenizer, model_factory, resolved_model_name + + if model_path is None: + raise ValueError("model_path is required when model/model_factory are not injected") + + device_type = _device_from_name(device) + + def model_factory() -> Any: + return llaisys.models.create_model( + model_path, + device_type, + device_id, + tp_size=tp_size, + tp_device_ids=tp_device_ids, + ) + + resolved_model_name = model_name or llaisys.models.default_model_name(model_path) + return tokenizer, model_factory, 
resolved_model_name + + resolved_model_name = model_name or "llaisys-chat" + return tokenizer, model_factory, resolved_model_name + + +def create_app( + model_path: Optional[str] = None, + device: str = "cpu", + device_id: int = 0, + tp_size: int = 1, + tp_device_ids: Optional[List[int]] = None, + *, + tokenizer=None, + model=None, + model_factory: Optional[Callable[[], Any]] = None, + model_name: Optional[str] = None, + num_workers: int = 1, + max_batch_size: int = 1, + batch_wait_ms: int = 5, +) -> FastAPI: + if num_workers <= 0: + raise ValueError("num_workers must be positive") + if max_batch_size <= 0: + raise ValueError("max_batch_size must be positive") + if batch_wait_ms < 0: + raise ValueError("batch_wait_ms must be non-negative") + if tp_size <= 0: + raise ValueError("tp_size must be positive") + if tp_size > 1 and device != "nvidia": + raise ValueError("Tensor parallel service mode currently supports only NVIDIA devices") + if tp_size > 1 and num_workers != 1: + raise ValueError("tp_size > 1 currently requires num_workers == 1") + + tokenizer, model_factory, resolved_model_name = _resolve_runtime_components( + model_path, + device, + device_id, + tp_size, + tp_device_ids, + tokenizer=tokenizer, + model=model, + model_factory=model_factory, + model_name=model_name, + num_workers=num_workers, + ) + + workers = [WorkerState(worker_id=i, model=model_factory()) for i in range(num_workers)] + service = ServiceState( + tokenizer=tokenizer, + model_name=resolved_model_name, + workers=workers, + max_batch_size=max_batch_size, + batch_wait_s=batch_wait_ms / 1000.0, + ) + + for worker in service.workers: + worker.thread = threading.Thread( + target=_worker_loop, + args=(service, worker), + name=f"llaisys-worker-{worker.worker_id}", + daemon=True, + ) + worker.thread.start() + + service.scheduler_thread = threading.Thread( + target=_scheduler_loop, + args=(service,), + name="llaisys-scheduler", + daemon=True, + ) + service.scheduler_thread.start() + + app = 
FastAPI(title="LLAISYS Chat API", version="0.3.0") + app.state.service = service + app.state.tokenizer = tokenizer + app.state.model_name = resolved_model_name + + @app.on_event("shutdown") + def shutdown_event() -> None: + _shutdown_service(app.state.service) + + @app.get("/health") + def health() -> Dict[str, Any]: + stats = _service_stats(app.state.service) + return { + "status": "ok", + "model": app.state.model_name, + "sessions": stats["session_count"], + "queue_depth": stats["queue_depth"], + "active_requests": stats["active_requests"], + "worker_count": stats["worker_count"], + "tp_size": tp_size, + "tp_device_ids": list(tp_device_ids or []), + } + + @app.get("/v1/service/stats") + def service_stats() -> Dict[str, Any]: + return _service_stats(app.state.service) + + @app.get("/v1/sessions") + def list_sessions() -> Dict[str, Any]: + service = app.state.service + with service.condition: + sessions = sorted(service.sessions.values(), key=lambda item: item.updated_at, reverse=True) + return { + "object": "list", + "data": [ + _session_payload( + session, + inflight=service.session_inflight.get(session.session_id, 0), + ) + for session in sessions + ], + } + + @app.post("/v1/sessions") + def create_session(payload: Optional[Dict[str, Any]] = Body(default=None)) -> JSONResponse: + payload = payload or {} + session_id = _resolve_session_id(payload) + messages = payload.get("messages") + normalized_messages = _normalize_messages(messages) if messages else [] + service = app.state.service + + with service.condition: + if session_id in service.sessions: + raise HTTPException(status_code=409, detail=f"Session `{session_id}` already exists") + + now = time.time() + service.sessions[session_id] = SessionState( + session_id=session_id, + messages=normalized_messages, + token_ids=[], + created_at=now, + updated_at=now, + ) + return JSONResponse(_session_payload(service.sessions[session_id], include_messages=True)) + + @app.get("/v1/sessions/{session_id}") + def 
get_session(session_id: str) -> Dict[str, Any]: + service = app.state.service + with service.condition: + state = service.sessions.get(session_id) + if state is None: + raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`") + payload = _session_payload( + state, + include_messages=True, + inflight=service.session_inflight.get(session_id, 0), + ) + payload["active"] = state.worker_id is not None + return payload + + @app.put("/v1/sessions/{session_id}") + def replace_session(session_id: str, payload: Dict[str, Any] = Body(...)) -> Dict[str, Any]: + if "messages" not in payload: + raise HTTPException(status_code=400, detail="`messages` is required") + + try: + normalized_messages = _normalize_messages(payload["messages"]) + except (TypeError, ValueError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + service = app.state.service + with service.condition: + state = service.sessions.get(session_id) + if state is None: + raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`") + if service.session_inflight.get(session_id, 0) > 0: + raise HTTPException(status_code=409, detail=f"Session `{session_id}` is busy") + + _clear_session_cache_locked(service, session_id) + state.messages = normalized_messages + state.updated_at = time.time() + return _session_payload(state, include_messages=True) + + @app.delete("/v1/sessions/{session_id}") + def delete_session(session_id: str) -> Dict[str, Any]: + service = app.state.service + with service.condition: + state = service.sessions.get(session_id) + if state is None: + raise HTTPException(status_code=404, detail=f"Unknown session `{session_id}`") + if service.session_inflight.get(session_id, 0) > 0: + raise HTTPException(status_code=409, detail=f"Session `{session_id}` is busy") + + _clear_session_cache_locked(service, session_id) + service.sessions.pop(session_id, None) + return {"id": session_id, "object": "session.deleted", "deleted": True} + + 
@app.post("/v1/chat/completions") + async def chat_completions(payload: Dict[str, Any] = Body(...)): + service = app.state.service + session_id = _resolve_session_id(payload) + with service.condition: + stored_session = service.sessions.get(session_id) + + raw_messages = payload.get("messages") + if raw_messages is None: + if stored_session is None: + raise HTTPException(status_code=400, detail="`messages` is required for a new session") + raw_messages = stored_session.messages + + try: + messages = _normalize_messages(raw_messages) + input_ids = _build_input_ids(service.tokenizer, messages) + max_tokens = int(payload.get("max_tokens", payload.get("max_completion_tokens", 128))) + top_k = int(payload.get("top_k", 1)) + top_p = float(payload.get("top_p", 1.0)) + temperature = float(payload.get("temperature", 1.0)) + stream = bool(payload.get("stream", False)) + except (TypeError, ValueError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + if max_tokens <= 0: + raise HTTPException(status_code=400, detail="`max_tokens` must be positive") + + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + created = int(time.time()) + request_model = str(payload.get("model", app.state.model_name)) + result_queue: "queue.Queue[Any]" = queue.Queue() + request = CompletionRequest( + completion_id=completion_id, + session_id=session_id, + messages=messages, + input_ids=input_ids, + max_tokens=max_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + stream=stream, + created=created, + request_model=request_model, + result_queue=result_queue, + enqueued_at=time.time(), + ) + + with service.condition: + if service.session_inflight.get(session_id, 0) > 0: + raise HTTPException(status_code=409, detail=f"Session `{session_id}` already has an in-flight request") + service.session_inflight[session_id] = service.session_inflight.get(session_id, 0) + 1 + service.pending_requests.append(request) + service.total_enqueued += 1 + service.condition.notify_all() 
+ + if stream: + def event_stream() -> Iterable[str]: + while True: + item = result_queue.get() + if item is _STREAM_END: + break + yield item + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + result = result_queue.get() + if isinstance(result, Exception): + raise HTTPException(status_code=500, detail=str(result)) + return JSONResponse(result) + + return app + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", required=True, type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device-id", default=0, type=int) + parser.add_argument("--host", default="127.0.0.1", type=str) + parser.add_argument("--port", default=8000, type=int) + parser.add_argument("--model-name", default=None, type=str) + parser.add_argument("--num-workers", default=1, type=int) + parser.add_argument("--max-batch-size", default=1, type=int) + parser.add_argument("--batch-wait-ms", default=5, type=int) + parser.add_argument("--tp-size", default=1, type=int) + parser.add_argument("--tp-device-ids", default=None, type=str) + args = parser.parse_args() + + try: + import uvicorn + except ImportError as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "Uvicorn is optional. Install with `pip install ./python[server]`." 
+ ) from exc + + tp_device_ids = None + if args.tp_device_ids: + tp_device_ids = [int(part.strip()) for part in args.tp_device_ids.split(",") if part.strip()] + + app = create_app( + model_path=args.model_path, + device=args.device, + device_id=args.device_id, + tp_size=args.tp_size, + tp_device_ids=tp_device_ids, + model_name=args.model_name, + num_workers=args.num_workers, + max_batch_size=args.max_batch_size, + batch_wait_ms=args.batch_wait_ms, + ) + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b467..af9998f0 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -2,7 +2,6 @@ from enum import IntEnum -# Device Type enum class DeviceType(IntEnum): CPU = 0 NVIDIA = 1 @@ -12,7 +11,6 @@ class DeviceType(IntEnum): llaisysDeviceType_t = ctypes.c_int -# Data Type enum class DataType(IntEnum): INVALID = 0 BYTE = 1 @@ -39,7 +37,6 @@ class DataType(IntEnum): llaisysDataType_t = ctypes.c_int -# Memory Copy Kind enum class MemcpyKind(IntEnum): H2H = 0 H2D = 1 @@ -48,8 +45,13 @@ class MemcpyKind(IntEnum): llaisysMemcpyKind_t = ctypes.c_int +llaisysTensor_t = ctypes.c_void_p + +class LlaisysQwen2Model(ctypes.Structure): + pass +llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model) +llaisysQwen2Weights_p = ctypes.c_void_p -# Stream type (opaque pointer) llaisysStream_t = ctypes.c_void_p __all__ = [ diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 00000000..b8a44911 --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,42 @@ +import ctypes +from .llaisys_types import llaisysDataType_t, llaisysTensor_t, llaisysDeviceType_t + +class LlaisysQwen2Meta(ctypes.Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", ctypes.c_size_t), + ("hs", ctypes.c_size_t), + ("nh", 
ctypes.c_size_t), + ("nkvh", ctypes.c_size_t), + ("dh", ctypes.c_size_t), + ("di", ctypes.c_size_t), + ("maxseq", ctypes.c_size_t), + ("voc", ctypes.c_size_t), + ("epsilon", ctypes.c_float), + ("theta", ctypes.c_float), + ("end_token", ctypes.c_int64), + ] + +class LlaisysQwen2Weights(ctypes.Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_o_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_gate_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_up_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_down_w", ctypes.POINTER(llaisysTensor_t)), + ] + +class LlaisysQwen2Model(ctypes.Structure): + pass + +llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model) diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py index 5be095ef..c56f4749 100644 --- a/python/llaisys/libllaisys/ops.py +++ b/python/llaisys/libllaisys/ops.py @@ -1,5 +1,5 @@ from .tensor import llaisysTensor_t -from ctypes import c_float +from ctypes import c_float, c_int, c_int64 def load_ops(lib): lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] @@ -14,6 +14,9 @@ def load_ops(lib): lib.llaisysLinear.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] lib.llaisysLinear.restype = None + lib.llaisysSample.argtypes = [llaisysTensor_t, c_int, c_float, c_float] + lib.llaisysSample.restype = c_int64 + lib.llaisysRearrange.argtypes = [llaisysTensor_t, llaisysTensor_t] lib.llaisysRearrange.restype = None diff --git a/python/llaisys/models/__init__.py 
b/python/llaisys/models/__init__.py index af9918b0..5b2b096d 100644 --- a/python/llaisys/models/__init__.py +++ b/python/llaisys/models/__init__.py @@ -1 +1,59 @@ +import json +from pathlib import Path +from typing import Type + +from ..libllaisys import DeviceType +from .llama import Llama from .qwen2 import Qwen2 +from .tensor_parallel import TensorParallelQwen2 + +MODEL_CLASSES: tuple[Type[Qwen2], ...] = (Qwen2, Llama) + + +def detect_model_type(model_path: str) -> str: + config_path = Path(model_path) / "config.json" + with open(config_path, "r") as handle: + config = json.load(handle) + return str(config.get("model_type", "")).lower() + + +def create_model( + model_path: str, + device: DeviceType = DeviceType.CPU, + device_id: int = 0, + *, + tp_size: int = 1, + tp_device_ids=None, +): + model_type = detect_model_type(model_path) + if int(tp_size) > 1: + if device != DeviceType.NVIDIA: + raise ValueError("Tensor parallel inference currently supports only NVIDIA devices") + if TensorParallelQwen2.supports_model_type(model_type): + return TensorParallelQwen2( + model_path, + device, + device_id, + tp_size=int(tp_size), + tp_device_ids=tp_device_ids, + ) + raise ValueError(f"Tensor parallel inference is not supported for model type: {model_type or ''}") + for cls in MODEL_CLASSES: + if cls.supports_model_type(model_type): + return cls(model_path, device, device_id) + raise ValueError(f"Unsupported model type: {model_type or ''}") + + +def default_model_name(model_path: str) -> str: + model_type = detect_model_type(model_path) + return f"llaisys-{model_type or 'model'}" + + +__all__ = [ + "Qwen2", + "Llama", + "TensorParallelQwen2", + "create_model", + "detect_model_type", + "default_model_name", +] diff --git a/python/llaisys/models/llama.py b/python/llaisys/models/llama.py new file mode 100644 index 00000000..888eaaf5 --- /dev/null +++ b/python/llaisys/models/llama.py @@ -0,0 +1,5 @@ +from .qwen2 import Qwen2 + + +class Llama(Qwen2): + SUPPORTED_MODEL_TYPES = 
{"llama", "mistral"} diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..562534ec 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,387 @@ -from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType - +import ctypes +import json +import mmap +import os +import struct from pathlib import Path -import safetensors +from typing import ClassVar, Dict, Iterator, List, Sequence, Tuple + +import numpy as np + +from ..libllaisys import LIB_LLAISYS, llaisysTensor_t, llaisysDataType_t, llaisysDeviceType_t, DataType, DeviceType +from ..libllaisys.models import LlaisysQwen2Meta, LlaisysQwen2Weights, LlaisysQwen2Model, llaisysQwen2ModelHandle +from ..tensor import Tensor + +LIB_LLAISYS.llaisysQwen2ModelCreate.argtypes = [ctypes.POINTER(LlaisysQwen2Meta), llaisysDeviceType_t, ctypes.POINTER(ctypes.c_int), ctypes.c_int] +LIB_LLAISYS.llaisysQwen2ModelCreate.restype = llaisysQwen2ModelHandle + +LIB_LLAISYS.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2ModelHandle] +LIB_LLAISYS.llaisysQwen2ModelDestroy.restype = None + +LIB_LLAISYS.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2ModelHandle] +LIB_LLAISYS.llaisysQwen2ModelWeights.restype = ctypes.POINTER(LlaisysQwen2Weights) + +LIB_LLAISYS.llaisysQwen2ModelReset.argtypes = [llaisysQwen2ModelHandle] +LIB_LLAISYS.llaisysQwen2ModelReset.restype = None + +LIB_LLAISYS.llaisysQwen2ModelTruncate.argtypes = [llaisysQwen2ModelHandle, ctypes.c_size_t] +LIB_LLAISYS.llaisysQwen2ModelTruncate.restype = None + +LIB_LLAISYS.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2ModelHandle, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t] +LIB_LLAISYS.llaisysQwen2ModelInfer.restype = ctypes.c_int64 + +LIB_LLAISYS.llaisysQwen2ModelGenerateNext.argtypes = [ + llaisysQwen2ModelHandle, + ctypes.POINTER(ctypes.c_int64), + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_float, + ctypes.c_float, +] 
+LIB_LLAISYS.llaisysQwen2ModelGenerateNext.restype = ctypes.c_int64 class Qwen2: + SUPPORTED_MODEL_TYPES: ClassVar[set[str]] = {"qwen2"} - def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + @classmethod + def supports_model_type(cls, model_type: str) -> bool: + return str(model_type).lower() in cls.SUPPORTED_MODEL_TYPES + def __init__(self, model_path: str, device: DeviceType = DeviceType.CPU, device_id: int = 0): model_path = Path(model_path) + config_path = model_path / "config.json" + + with open(config_path, "r") as f: + config = json.load(f) + self.config = config + self.model_type = str(config.get("model_type", "")).lower() + + self.device = device + self.device_id = device_id + + self.meta = LlaisysQwen2Meta() + self._configure_meta(config) + + dev_ids = (ctypes.c_int * 1)(device_id) + self.handle = LIB_LLAISYS.llaisysQwen2ModelCreate(ctypes.byref(self.meta), device, dev_ids, 1) + self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.handle) + self.tensors_ref = [] for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + print(f"Loading weights from {file}...") + weights_data = self._load_safetensors(file) + for key, (arr, shape, source_dtype) in weights_data.items(): + raw = self._convert_array_for_model_dtype(arr, source_dtype) + if not raw.flags["C_CONTIGUOUS"]: + raw = np.ascontiguousarray(raw) + + t = Tensor(list(shape), self.meta.dtype, device, device_id) + t.load(ctypes.c_void_p(raw.ctypes.data)) + self._assign_weight(key, t) + + self._finalize_weights() + + def _configure_meta(self, config: Dict[str, object]) -> None: + dtype_str = str(config.get("torch_dtype", "float32")) + self.meta.dtype = self._runtime_dtype(dtype_str, self.device) + self.meta.nlayer = int(config.get("num_hidden_layers", 24)) + self.meta.hs = int(config.get("hidden_size", 2048)) 
+ self.meta.nh = int(config.get("num_attention_heads", 16)) + self.meta.nkvh = int(config.get("num_key_value_heads", self.meta.nh)) + self.meta.dh = self.meta.hs // self.meta.nh + self.meta.di = int(config.get("intermediate_size", 11008)) + self.meta.maxseq = int(config.get("max_position_embeddings", 8192)) + self.meta.voc = int(config.get("vocab_size", 151936)) + self.meta.epsilon = float(config.get("rms_norm_eps", 1e-6)) + self.meta.theta = float(config.get("rope_theta", 1000000.0)) + eos_token = config.get("eos_token_id", 151643) + if isinstance(eos_token, list): + eos_token = eos_token[0] if eos_token else 151643 + self.meta.end_token = int(eos_token) + + @staticmethod + def _llaisys_dtype_from_string(dtype_str: str) -> DataType: + normalized = str(dtype_str).lower() + if normalized in {"bfloat16", "bf16"}: + return DataType.BF16 + if normalized in {"float16", "f16", "half"}: + return DataType.F16 + return DataType.F32 + + @staticmethod + def _runtime_dtype(dtype_str: str, device: DeviceType) -> DataType: + # CPU inference is currently optimized only for F32 kernels. Keep that fast path + # as the default, and allow native 16-bit weights only as an explicit opt-in. 
+        if device == DeviceType.CPU and os.getenv("LLAISYS_CPU_NATIVE_DTYPE", "").lower() not in {"1", "true", "yes", "on"}:
+            return DataType.F32
+        return Qwen2._llaisys_dtype_from_string(dtype_str)
+
+    def _load_safetensors(self, path: Path) -> Dict[str, Tuple[np.ndarray, Sequence[int], str]]:
+        tensors = {}
+        with open(path, "rb") as f:
+            length_bytes = f.read(8)
+            if not length_bytes:
+                return {}
+            # NOTE(review): the parsing block below was reconstructed from the
+            # safetensors format spec after an extraction artifact swallowed the
+            # original text between `struct.unpack("` and the next `->`.
+            # Verify against the committed file.
+            header_size = struct.unpack("<Q", length_bytes)[0]
+            header = json.loads(f.read(header_size).decode("utf-8"))
+            data_start = 8 + header_size
+            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+                for name, info in header.items():
+                    if name == "__metadata__":
+                        continue
+                    dtype_tag = str(info["dtype"]).upper()
+                    shape = tuple(int(dim) for dim in info["shape"])
+                    begin, end = info["data_offsets"]
+                    raw = mm[data_start + begin : data_start + end]
+                    if dtype_tag == "BF16":
+                        arr = np.frombuffer(raw, dtype=np.uint16).reshape(shape)
+                        source_dtype = "bf16"
+                    elif dtype_tag == "F16":
+                        arr = np.frombuffer(raw, dtype=np.float16).reshape(shape)
+                        source_dtype = "f16"
+                    else:
+                        arr = np.frombuffer(raw, dtype=np.float32).reshape(shape)
+                        source_dtype = "f32"
+                    # Copy out of the mmap so the buffer stays valid after close.
+                    tensors[name] = (np.array(arr, copy=True), shape, source_dtype)
+        return tensors
+
+    def _convert_array_for_model_dtype(self, arr: np.ndarray, source_dtype: str) -> np.ndarray:
+        target_dtype = self.meta.dtype
+        if target_dtype == DataType.F32:
+            return self._to_float32_array(arr, source_dtype)
+        if target_dtype == DataType.F16:
+            return self._to_float16_array(arr, source_dtype)
+        if target_dtype == DataType.BF16:
+            return self._to_bfloat16_bytes(arr, source_dtype)
+        raise ValueError(f"Unsupported model dtype: {target_dtype}")
+
+    @staticmethod
+    def _to_float32_array(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+        if source_dtype == "bf16":
+            u32 = arr.astype(np.uint32) << 16
+            return u32.view(np.float32)
+        if source_dtype == "f16":
+            return arr.astype(np.float32)
+        return np.asarray(arr, dtype=np.float32)
+
+    @staticmethod
+    def _to_float16_array(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+        if source_dtype == "f16":
+            return np.asarray(arr, dtype=np.float16)
+        return Qwen2._to_float32_array(arr, source_dtype).astype(np.float16)
+
+    @staticmethod
+    def _to_bfloat16_bytes(arr: np.ndarray, source_dtype: str) -> np.ndarray:
+        if source_dtype == "bf16":
+            return np.asarray(arr, dtype=np.uint16)
+        float32_arr = Qwen2._to_float32_array(arr, source_dtype)
+        u32 = float32_arr.view(np.uint32)
+        return (u32 >> 16).astype(np.uint16)
+
+    def _finalize_weights(self) -> None:
+        weights = self.weights_ptr.contents
+        if weights.out_embed:
+            return
+        if self.config.get("tie_word_embeddings") and weights.in_embed:
+            weights.out_embed = weights.in_embed
+            return
+        raise ValueError("Missing output embedding weights: expected `lm_head.weight` or tied word embeddings")
+
+    def _assign_weight(self, name: str, t: 
Tensor): + w = self.weights_ptr.contents + # Keep Python Tensor wrappers alive because the model stores raw C tensor handles. + self.tensors_ref.append(t) + if name == "model.embed_tokens.weight": + w.in_embed = t.lib_tensor() + elif name == "lm_head.weight": + w.out_embed = t.lib_tensor() + elif name == "model.norm.weight": + w.out_norm_w = t.lib_tensor() + elif name.startswith("model.layers."): + parts = name.split(".") + layer_idx = int(parts[2]) + suffix = ".".join(parts[3:]) + + def set_w(target_ptr): + target_ptr[layer_idx] = t.lib_tensor() + + if suffix == "input_layernorm.weight": + set_w(w.attn_norm_w) + elif suffix == "self_attn.q_proj.weight": + set_w(w.attn_q_w) + elif suffix == "self_attn.q_proj.bias": + set_w(w.attn_q_b) + elif suffix == "self_attn.k_proj.weight": + set_w(w.attn_k_w) + elif suffix == "self_attn.k_proj.bias": + set_w(w.attn_k_b) + elif suffix == "self_attn.v_proj.weight": + set_w(w.attn_v_w) + elif suffix == "self_attn.v_proj.bias": + set_w(w.attn_v_b) + elif suffix == "self_attn.o_proj.weight": + set_w(w.attn_o_w) + elif suffix == "post_attention_layernorm.weight": + set_w(w.mlp_norm_w) + elif suffix == "mlp.gate_proj.weight": + set_w(w.mlp_gate_w) + elif suffix == "mlp.up_proj.weight": + set_w(w.mlp_up_w) + elif suffix == "mlp.down_proj.weight": + set_w(w.mlp_down_w) + + def __del__(self): + if hasattr(self, "handle") and self.handle: + LIB_LLAISYS.llaisysQwen2ModelDestroy(self.handle) + + def reset(self) -> None: + LIB_LLAISYS.llaisysQwen2ModelReset(self.handle) + + def truncate(self, position: int) -> None: + if position < 0: + raise ValueError("position must be non-negative") + LIB_LLAISYS.llaisysQwen2ModelTruncate(self.handle, int(position)) + + def _generate_next( + self, + inputs: Sequence[int], + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + ) -> int: + if not inputs: + raise ValueError("inputs must not be empty") + + arr = (ctypes.c_int64 * len(inputs))(*inputs) + return int( + 
LIB_LLAISYS.llaisysQwen2ModelGenerateNext(
+                self.handle,
+                arr,
+                len(inputs),
+                int(top_k),
+                float(top_p),
+                float(temperature),
+            )
+        )
+
+    def generate_next(
+        self,
+        inputs: Sequence[int],
+        top_k: int = 1,
+        top_p: float = 1.0,
+        temperature: float = 1.0,
+        *,
+        reset_state: bool = False,
+    ) -> int:
+        if reset_state:
+            self.reset()
+        return self._generate_next(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
 
     def generate(
         self,
         inputs: Sequence[int],
-        max_new_tokens: int = None,
+        max_new_tokens: int = 20,
         top_k: int = 1,
         top_p: float = 0.8,
         temperature: float = 0.8,
-    ):
+        *,
+        reset_state: bool = True,
+    ) -> List[int]:
+        if not inputs:
+            raise ValueError("inputs must not be empty")
+
+        if reset_state:
+            self.reset()
+        generated = []
+        tokens = list(inputs)
+
+        next_token = self.generate_next(
+            tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        generated.append(next_token)
+        tokens = [next_token]
+
+        # NOTE(review): stop right away if the very first sampled token is EOS;
+        # the loop below only checks tokens it generates itself, so without this
+        # guard generation would continue past end-of-sequence.
+        if next_token == self.meta.end_token:
+            return list(inputs) + generated
+
+        for _ in range(max_new_tokens - 1):
+            next_token = self.generate_next(
+                tokens,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+            )
+            generated.append(next_token)
+            tokens = [next_token]
+
+            if next_token == self.meta.end_token:
+                break
+
+        return list(inputs) + generated
+
+    def stream_generate(
+        self,
+        inputs: Sequence[int],
+        *,
+        tokenizer=None,
+        max_new_tokens: int = 20,
+        top_k: int = 1,
+        top_p: float = 0.8,
+        temperature: float = 0.8,
+        reset_state: bool = True,
+    ) -> Iterator[Tuple[int, str]]:
+        if not inputs:
+            raise ValueError("inputs must not be empty")
+
+        if reset_state:
+            self.reset()
+        prompt_tokens = list(inputs)
+        generated: List[int] = []
+        last_text = ""
+
+        for step in range(max_new_tokens):
+            token_source = prompt_tokens if step == 0 else [generated[-1]]
+            next_token = self.generate_next(
+                token_source,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+            )
+            generated.append(next_token)
+
+            text_chunk = ""
+            if tokenizer is not None:
+                decoded = tokenizer.decode(
+                    generated,
+                    
skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + text_chunk = decoded[len(last_text):] if decoded.startswith(last_text) else decoded + last_text = decoded - # TODO: Implement generate function + yield next_token, text_chunk - return [] + if next_token == self.meta.end_token: + break diff --git a/python/llaisys/models/tensor_parallel.py b/python/llaisys/models/tensor_parallel.py new file mode 100644 index 00000000..07d43f1f --- /dev/null +++ b/python/llaisys/models/tensor_parallel.py @@ -0,0 +1,803 @@ +import ctypes +import json +import mmap +import os +import queue +import socket +import struct +import math +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from ..libllaisys import DataType, DeviceType, MemcpyKind +from ..ops import Ops +from ..runtime import RuntimeAPI +from ..tensor import Tensor +from .qwen2 import Qwen2 + + +def _tp_log(rank: int, message: str) -> None: + if os.getenv("LLAISYS_TP_DEBUG", "").lower() not in {"1", "true", "yes", "on"}: + return + with open(f"/tmp/llaisys_tp_rank{rank}.log", "a", encoding="utf-8") as handle: + handle.write(f"{message}\n") + + +@dataclass +class _TensorParallelMeta: + dtype: DataType + nlayer: int + hs: int + nh: int + nkvh: int + dh: int + di: int + maxseq: int + voc: int + epsilon: float + theta: float + end_token: int + + +def _divisors(value: int) -> List[int]: + divisors: List[int] = [] + for candidate in range(1, int(math.isqrt(value)) + 1): + if value % candidate != 0: + continue + divisors.append(candidate) + paired = value // candidate + if paired != candidate: + divisors.append(paired) + return sorted(divisors) + + +def _valid_tp_sizes(meta: _TensorParallelMeta) -> List[int]: + common = math.gcd(math.gcd(int(meta.nh), int(meta.nkvh)), int(meta.di)) + return _divisors(common) + + +def 
_validate_tp_size(meta: _TensorParallelMeta, tp_size: int) -> None: + tp_size = int(tp_size) + valid_sizes = _valid_tp_sizes(meta) + if tp_size in valid_sizes: + return + raise ValueError( + "Unsupported tp_size for current tensor-parallel implementation: " + f"tp_size={tp_size}, num_attention_heads={meta.nh}, " + f"num_key_value_heads={meta.nkvh}, intermediate_size={meta.di}, " + f"valid_tp_sizes={valid_sizes}. " + "This implementation shards attention heads, KV heads, and MLP intermediate dims evenly across ranks." + ) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _llaisys_dtype_to_torch(dtype: DataType) -> torch.dtype: + if dtype == DataType.F32: + return torch.float32 + if dtype == DataType.F16: + return torch.float16 + if dtype == DataType.BF16: + return torch.bfloat16 + raise ValueError(f"Unsupported runtime dtype: {dtype}") + + +def _resolve_tp_device_ids(device_id: int, tp_size: int, tp_device_ids: Optional[Sequence[int]]) -> List[int]: + if tp_device_ids is not None: + device_ids = [int(item) for item in tp_device_ids] + else: + device_ids = list(range(int(device_id), int(device_id) + int(tp_size))) + if len(device_ids) != int(tp_size): + raise ValueError("tp_device_ids length must match tp_size") + return device_ids + + +def _normalize_command_payload(command: str, seq: int, **payload) -> Dict[str, object]: + message = {"command": command, "seq": int(seq)} + message.update(payload) + return message + + +class _ShardedQwen2Rank: + def __init__(self, model_path: str, rank: int, world_size: int, device_id: int): + self.model_path = Path(model_path) + self.rank = int(rank) + self.world_size = int(world_size) + self.device_id = int(device_id) + self.device = DeviceType.NVIDIA + self.runtime = RuntimeAPI(DeviceType.NVIDIA) + self.runtime.set_device(self.device_id) + torch.cuda.set_device(self.device_id) + + with open(self.model_path 
/ "config.json", "r") as handle: + config = json.load(handle) + self.config = config + if str(config.get("model_type", "")).lower() != "qwen2": + raise ValueError("TensorParallelQwen2 currently supports only Qwen2 models") + + dtype = Qwen2._runtime_dtype(str(config.get("torch_dtype", "float32")), DeviceType.NVIDIA) + eos_token = config.get("eos_token_id", 151643) + if isinstance(eos_token, list): + eos_token = eos_token[0] if eos_token else 151643 + self.meta = _TensorParallelMeta( + dtype=dtype, + nlayer=int(config.get("num_hidden_layers", 24)), + hs=int(config.get("hidden_size", 2048)), + nh=int(config.get("num_attention_heads", 16)), + nkvh=int(config.get("num_key_value_heads", int(config.get("num_attention_heads", 16)))), + dh=int(config.get("hidden_size", 2048)) // int(config.get("num_attention_heads", 16)), + di=int(config.get("intermediate_size", 11008)), + maxseq=int(config.get("max_position_embeddings", 8192)), + voc=int(config.get("vocab_size", 151936)), + epsilon=float(config.get("rms_norm_eps", 1e-6)), + theta=float(config.get("rope_theta", 1000000.0)), + end_token=int(eos_token), + ) + + _validate_tp_size(self.meta, self.world_size) + + self.local_nh = self.meta.nh // self.world_size + self.local_nkvh = self.meta.nkvh // self.world_size + self.local_q_dim = self.local_nh * self.meta.dh + self.local_kv_dim = self.local_nkvh * self.meta.dh + self.local_di = self.meta.di // self.world_size + self.torch_dtype = _llaisys_dtype_to_torch(self.meta.dtype) + self.cuda_device = torch.device(f"cuda:{self.device_id}") + + self.in_embed: Optional[Tensor] = None + self.out_embed: Optional[Tensor] = None + self.out_norm_w: Optional[Tensor] = None + self.layers: List[Dict[str, Tensor]] = [dict() for _ in range(self.meta.nlayer)] + self._tensors: List[Tensor] = [] + self._load_weights() + + if self.out_embed is None and self.config.get("tie_word_embeddings") and self.in_embed is not None: + self.out_embed = self.in_embed + if self.in_embed is None or self.out_embed is 
None or self.out_norm_w is None: + raise ValueError("Missing required embedding/final norm weights for tensor-parallel Qwen2") + + self.k_caches = [ + Tensor((self.meta.maxseq, self.local_nkvh, self.meta.dh), self.meta.dtype, self.device, self.device_id) + for _ in range(self.meta.nlayer) + ] + self.v_caches = [ + Tensor((self.meta.maxseq, self.local_nkvh, self.meta.dh), self.meta.dtype, self.device, self.device_id) + for _ in range(self.meta.nlayer) + ] + self._cur_pos = 0 + + def reset(self) -> None: + self._cur_pos = 0 + + def truncate(self, position: int) -> None: + position = int(position) + if position < 0 or position > self._cur_pos: + raise ValueError("truncate position exceeds current cache length") + self._cur_pos = position + + def generate_next(self, inputs: Sequence[int], top_k: int, top_p: float, temperature: float) -> int: + if not inputs: + raise ValueError("inputs must not be empty") + if self._cur_pos + len(inputs) > self.meta.maxseq: + raise ValueError("sequence exceeds KV-cache capacity") + + _tp_log(self.rank, f"generate_next start cur_pos={self._cur_pos} ntoken={len(inputs)}") + hidden = self._embedding(list(inputs)) + pos_ids = self._make_int64_tensor(np.arange(self._cur_pos, self._cur_pos + len(inputs), dtype=np.int64)) + scale = 1.0 / float(self.meta.dh) ** 0.5 + + for layer_idx, layer in enumerate(self.layers): + _tp_log(self.rank, f"layer {layer_idx} start") + normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id) + Ops.rms_norm(normed, hidden, layer["attn_norm_w"], self.meta.epsilon) + self._debug_sync(f"layer {layer_idx} attn_norm") + + q = Tensor((len(inputs), self.local_q_dim), self.meta.dtype, self.device, self.device_id) + k = Tensor((len(inputs), self.local_kv_dim), self.meta.dtype, self.device, self.device_id) + v = Tensor((len(inputs), self.local_kv_dim), self.meta.dtype, self.device, self.device_id) + Ops.linear(q, normed, layer["attn_q_w"], layer["attn_q_b"]) + self._debug_sync(f"layer {layer_idx} 
attn_q") + Ops.linear(k, normed, layer["attn_k_w"], layer["attn_k_b"]) + self._debug_sync(f"layer {layer_idx} attn_k") + Ops.linear(v, normed, layer["attn_v_w"], layer["attn_v_b"]) + self._debug_sync(f"layer {layer_idx} attn_v") + + q_heads = q.view(len(inputs), self.local_nh, self.meta.dh) + k_heads = k.view(len(inputs), self.local_nkvh, self.meta.dh) + v_heads = v.view(len(inputs), self.local_nkvh, self.meta.dh) + Ops.rope(q_heads, q_heads, pos_ids, self.meta.theta) + self._debug_sync(f"layer {layer_idx} rope_q") + Ops.rope(k_heads, k_heads, pos_ids, self.meta.theta) + self._debug_sync(f"layer {layer_idx} rope_k") + + k_slot = self.k_caches[layer_idx].slice(0, self._cur_pos, self._cur_pos + len(inputs)) + v_slot = self.v_caches[layer_idx].slice(0, self._cur_pos, self._cur_pos + len(inputs)) + Ops.rearrange(k_slot, k_heads) + self._debug_sync(f"layer {layer_idx} cache_k") + Ops.rearrange(v_slot, v_heads) + self._debug_sync(f"layer {layer_idx} cache_v") + + k_full = self.k_caches[layer_idx].slice(0, 0, self._cur_pos + len(inputs)) + v_full = self.v_caches[layer_idx].slice(0, 0, self._cur_pos + len(inputs)) + attn_local = Tensor((len(inputs), self.local_nh, self.meta.dh), self.meta.dtype, self.device, self.device_id) + Ops.self_attention(attn_local, q_heads, k_full, v_full, scale) + self._debug_sync(f"layer {layer_idx} self_attention") + + attn_flat = attn_local.view(len(inputs), self.local_q_dim) + attn_partial = Tensor((len(inputs), self.meta.hs), self.meta.dtype, self.device, self.device_id) + Ops.linear(attn_partial, attn_flat, layer["attn_o_w"], None) + self._debug_sync(f"layer {layer_idx} attn_o") + self._all_reduce_tensor(attn_partial) + self._debug_sync(f"layer {layer_idx} attn_all_reduce") + Ops.add(hidden, hidden, attn_partial) + self._debug_sync(f"layer {layer_idx} attn_residual") + + mlp_normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id) + Ops.rms_norm(mlp_normed, hidden, layer["mlp_norm_w"], self.meta.epsilon) + 
self._debug_sync(f"layer {layer_idx} mlp_norm") + gate = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id) + up = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id) + Ops.linear(gate, mlp_normed, layer["mlp_gate_w"], None) + self._debug_sync(f"layer {layer_idx} mlp_gate") + Ops.linear(up, mlp_normed, layer["mlp_up_w"], None) + self._debug_sync(f"layer {layer_idx} mlp_up") + swiglu_out = Tensor((len(inputs), self.local_di), self.meta.dtype, self.device, self.device_id) + Ops.swiglu(swiglu_out, gate, up) + self._debug_sync(f"layer {layer_idx} swiglu") + mlp_partial = Tensor((len(inputs), self.meta.hs), self.meta.dtype, self.device, self.device_id) + Ops.linear(mlp_partial, swiglu_out, layer["mlp_down_w"], None) + self._debug_sync(f"layer {layer_idx} mlp_down") + self._all_reduce_tensor(mlp_partial) + self._debug_sync(f"layer {layer_idx} mlp_all_reduce") + Ops.add(hidden, hidden, mlp_partial) + self._debug_sync(f"layer {layer_idx} mlp_residual") + _tp_log(self.rank, f"layer {layer_idx} done") + + self._cur_pos += len(inputs) + _tp_log(self.rank, f"generate_next layers done cur_pos={self._cur_pos}") + + next_token = -1 + if self.rank == 0: + normed = Tensor(hidden.shape(), self.meta.dtype, self.device, self.device_id) + Ops.rms_norm(normed, hidden, self.out_norm_w, self.meta.epsilon) + last_hidden = normed.slice(0, len(inputs) - 1, len(inputs)).view(1, self.meta.hs) + logits = Tensor((1, self.meta.voc), self.meta.dtype, self.device, self.device_id) + Ops.linear(logits, last_hidden, self.out_embed, None) + next_token = Ops.sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + + return self._broadcast_token(next_token) + + def _debug_sync(self, label: str) -> None: + if os.getenv("LLAISYS_TP_DEBUG", "").lower() not in {"1", "true", "yes", "on"}: + return + self.runtime.set_device(self.device_id) + self.runtime.device_synchronize() + _tp_log(self.rank, f"{label} ok") + + def 
_all_reduce_tensor(self, tensor: Tensor) -> None: + buffer = self._to_torch(tensor) + dist.all_reduce(buffer) + self._from_torch(buffer, tensor) + + def _broadcast_token(self, token: int) -> int: + value = torch.tensor([int(token) if self.rank == 0 else 0], dtype=torch.int64, device=self.cuda_device) + dist.broadcast(value, src=0) + return int(value.item()) + + def _embedding(self, token_ids: List[int]) -> Tensor: + tokens = self._make_int64_tensor(np.asarray(token_ids, dtype=np.int64)) + hidden = Tensor((len(token_ids), self.meta.hs), self.meta.dtype, self.device, self.device_id) + Ops.embedding(hidden, tokens, self.in_embed) + return hidden + + def _to_torch(self, tensor: Tensor) -> torch.Tensor: + torch_tensor = torch.empty(tensor.shape(), dtype=self.torch_dtype, device=self.cuda_device) + self.runtime.set_device(self.device_id) + self.runtime.memcpy_sync( + torch_tensor.data_ptr(), + tensor.data_ptr(), + torch_tensor.numel() * torch_tensor.element_size(), + MemcpyKind.D2D, + ) + return torch_tensor + + def _from_torch(self, torch_tensor: torch.Tensor, tensor: Tensor) -> None: + self.runtime.set_device(self.device_id) + self.runtime.memcpy_sync( + tensor.data_ptr(), + torch_tensor.data_ptr(), + torch_tensor.numel() * torch_tensor.element_size(), + MemcpyKind.D2D, + ) + + def _make_int64_tensor(self, values: np.ndarray) -> Tensor: + tensor = Tensor(values.shape, DataType.I64, self.device, self.device_id) + tensor.load(ctypes.c_void_p(values.ctypes.data)) + return tensor + + def _tensor_from_array(self, array: np.ndarray, *, dtype: Optional[DataType] = None) -> Tensor: + target_dtype = self.meta.dtype if dtype is None else dtype + tensor = Tensor(array.shape, target_dtype, self.device, self.device_id) + tensor.load(ctypes.c_void_p(array.ctypes.data)) + self._tensors.append(tensor) + return tensor + + def _load_weights(self) -> None: + for file in sorted(self.model_path.glob("*.safetensors")): + for key, array, source_dtype in self._iter_safetensors(file): + 
self._assign_weight(key, array, source_dtype) + + def _assign_weight(self, name: str, array: np.ndarray, source_dtype: str) -> None: + if name == "model.embed_tokens.weight": + converted = self._convert_array(array, source_dtype) + self.in_embed = self._tensor_from_array(np.ascontiguousarray(converted)) + return + if name == "lm_head.weight": + converted = self._convert_array(array, source_dtype) + self.out_embed = self._tensor_from_array(np.ascontiguousarray(converted)) + return + if name == "model.norm.weight": + converted = self._convert_array(array, source_dtype) + self.out_norm_w = self._tensor_from_array(np.ascontiguousarray(converted)) + return + if not name.startswith("model.layers."): + return + + parts = name.split(".") + layer_idx = int(parts[2]) + suffix = ".".join(parts[3:]) + layer = self.layers[layer_idx] + + if suffix == "input_layernorm.weight": + layer["attn_norm_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(array, source_dtype))) + return + if suffix == "post_attention_layernorm.weight": + layer["mlp_norm_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(array, source_dtype))) + return + if suffix == "self_attn.q_proj.weight": + shard = self._shard_rows(array) + layer["attn_q_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "self_attn.q_proj.bias": + shard = self._shard_rows_bias(array) + layer["attn_q_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "self_attn.k_proj.weight": + shard = self._shard_rows(array, kv=True) + layer["attn_k_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "self_attn.k_proj.bias": + shard = self._shard_rows_bias(array, kv=True) + layer["attn_k_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == 
"self_attn.v_proj.weight": + shard = self._shard_rows(array, kv=True) + layer["attn_v_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "self_attn.v_proj.bias": + shard = self._shard_rows_bias(array, kv=True) + layer["attn_v_b"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "self_attn.o_proj.weight": + shard = self._shard_cols(array, self.local_q_dim) + layer["attn_o_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "mlp.gate_proj.weight": + shard = self._shard_rows(array, mlp=True) + layer["mlp_gate_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "mlp.up_proj.weight": + shard = self._shard_rows(array, mlp=True) + layer["mlp_up_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + if suffix == "mlp.down_proj.weight": + shard = self._shard_cols(array, self.local_di) + layer["mlp_down_w"] = self._tensor_from_array(np.ascontiguousarray(self._convert_array(shard, source_dtype))) + return + + def _shard_rows(self, array: np.ndarray, *, kv: bool = False, mlp: bool = False) -> np.ndarray: + if kv: + chunk = self.local_kv_dim + elif mlp: + chunk = self.local_di + else: + chunk = self.local_q_dim + start = self.rank * chunk + end = start + chunk + return np.asarray(array[start:end, :]) + + def _shard_rows_bias(self, array: np.ndarray, *, kv: bool = False) -> np.ndarray: + chunk = self.local_kv_dim if kv else self.local_q_dim + start = self.rank * chunk + end = start + chunk + return np.asarray(array[start:end]) + + def _shard_cols(self, array: np.ndarray, chunk: int) -> np.ndarray: + start = self.rank * chunk + end = start + chunk + return np.asarray(array[:, start:end]) + + def _iter_safetensors(self, path: Path): + with open(path, "rb") as handle: + 
length_bytes = handle.read(8)
+            if not length_bytes:
+                return
+            # NOTE(review): reconstructed from the safetensors format spec; an
+            # extraction artifact swallowed the original text between
+            # `struct.unpack("` and the next `->`. Verify against the committed file.
+            header_size = struct.unpack("<Q", length_bytes)[0]
+            header = json.loads(handle.read(header_size).decode("utf-8"))
+            data_start = 8 + header_size
+            with mmap.mmap(handle.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+                for name, info in header.items():
+                    if name == "__metadata__":
+                        continue
+                    dtype_tag = str(info["dtype"]).upper()
+                    shape = tuple(int(dim) for dim in info["shape"])
+                    begin, end = info["data_offsets"]
+                    raw = mm[data_start + begin : data_start + end]
+                    if dtype_tag == "BF16":
+                        array = np.frombuffer(raw, dtype=np.uint16).reshape(shape)
+                        source_dtype = "bf16"
+                    elif dtype_tag == "F16":
+                        array = np.frombuffer(raw, dtype=np.float16).reshape(shape)
+                        source_dtype = "f16"
+                    else:
+                        array = np.frombuffer(raw, dtype=np.float32).reshape(shape)
+                        source_dtype = "f32"
+                    # Copy out of the mmap so the buffer outlives the file handle.
+                    yield name, np.array(array, copy=True), source_dtype
+
+    def _convert_array(self, arr: np.ndarray, source_dtype: str) -> np.ndarray:
+        if self.meta.dtype == DataType.F32:
+            return Qwen2._to_float32_array(arr, source_dtype)
+        if self.meta.dtype == DataType.F16:
+            return Qwen2._to_float16_array(arr, source_dtype)
+        if self.meta.dtype == DataType.BF16:
+            return Qwen2._to_bfloat16_bytes(arr, source_dtype)
+        raise ValueError(f"Unsupported model dtype: {self.meta.dtype}")
+
+
+def _tensor_parallel_worker_main(
+    rank: int,
+    world_size: int,
+    model_path: str,
+    device_ids: List[int],
+    master_addr: str,
+    master_port: int,
+    command_queue,
+    result_queue,
+) -> None:
+    os.environ.setdefault("MASTER_ADDR", master_addr)
+    os.environ.setdefault("MASTER_PORT", str(master_port))
+    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
+
+    torch.cuda.set_device(device_ids[rank])
+    _tp_log(rank, f"worker start device={device_ids[rank]}")
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{master_addr}:{master_port}",
+        world_size=world_size,
+        rank=rank,
+        device_id=torch.device(f"cuda:{device_ids[rank]}"),
+    )
+    worker = None
+    try:
+        worker = _ShardedQwen2Rank(model_path, rank, world_size, device_ids[rank])
+        _tp_log(rank, "model loaded")
+        dist.barrier()
+        if rank == 0:
+            result_queue.put({"status": "ready"})
+
+        while True:
+            message = command_queue.get()
+            command = str(message["command"])
+            seq = int(message["seq"])
+            _tp_log(rank, f"command {command} seq={seq}")
+
+            if command == "shutdown":
+                if rank == 0:
+                    result_queue.put({"seq": seq, "status": "ok"})
+                break
+            if command == "reset":
+                worker.reset()
+                if rank == 0:
+                    result_queue.put({"seq": seq, "status": "ok"})
+                continue
+            if command == "truncate":
+                worker.truncate(int(message["position"]))
+                if rank == 0:
+                    result_queue.put({"seq": seq, "status": "ok"})
+                continue
+            if command == "generate_next":
+                if bool(message.get("reset_state", False)):
+                    worker.reset()
+                token = worker.generate_next(
+                    message["inputs"],
+                    
top_k=int(message["top_k"]), + top_p=float(message["top_p"]), + temperature=float(message["temperature"]), + ) + if rank == 0: + result_queue.put({"seq": seq, "status": "ok", "token": int(token)}) + continue + raise ValueError(f"Unknown tensor-parallel command: {command}") + except Exception as exc: # pragma: no cover - defensive worker path + _tp_log(rank, f"worker exception {repr(exc)}") + result_queue.put({"status": "error", "rank": rank, "error": repr(exc)}) + raise + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +class TensorParallelQwen2: + SUPPORTED_MODEL_TYPES: ClassVar[set[str]] = {"qwen2"} + + @classmethod + def supports_model_type(cls, model_type: str) -> bool: + return str(model_type).lower() in cls.SUPPORTED_MODEL_TYPES + + def __init__( + self, + model_path: str, + device: DeviceType = DeviceType.NVIDIA, + device_id: int = 0, + *, + tp_size: int = 2, + tp_device_ids: Optional[Sequence[int]] = None, + ): + if device != DeviceType.NVIDIA: + raise ValueError("TensorParallelQwen2 currently supports only NVIDIA devices") + if int(tp_size) <= 1: + raise ValueError("tp_size must be greater than 1 for tensor parallel inference") + + self.model_path = str(model_path) + self.device = device + self.device_id = int(device_id) + self.device_ids = _resolve_tp_device_ids(device_id, tp_size, tp_device_ids) + self.tp_size = len(self.device_ids) + + with open(Path(model_path) / "config.json", "r") as handle: + config = json.load(handle) + self.config = config + self.model_type = str(config.get("model_type", "")).lower() + if not self.supports_model_type(self.model_type): + raise ValueError(f"Unsupported tensor-parallel model type: {self.model_type or ''}") + + dtype = Qwen2._runtime_dtype(str(config.get("torch_dtype", "float32")), DeviceType.NVIDIA) + eos_token = config.get("eos_token_id", 151643) + if isinstance(eos_token, list): + eos_token = eos_token[0] if eos_token else 151643 + self.meta = _TensorParallelMeta( + dtype=dtype, + 
nlayer=int(config.get("num_hidden_layers", 24)), + hs=int(config.get("hidden_size", 2048)), + nh=int(config.get("num_attention_heads", 16)), + nkvh=int(config.get("num_key_value_heads", int(config.get("num_attention_heads", 16)))), + dh=int(config.get("hidden_size", 2048)) // int(config.get("num_attention_heads", 16)), + di=int(config.get("intermediate_size", 11008)), + maxseq=int(config.get("max_position_embeddings", 8192)), + voc=int(config.get("vocab_size", 151936)), + epsilon=float(config.get("rms_norm_eps", 1e-6)), + theta=float(config.get("rope_theta", 1000000.0)), + end_token=int(eos_token), + ) + _validate_tp_size(self.meta, self.tp_size) + + ctx = mp.get_context("spawn") + self._command_queues = [ctx.Queue() for _ in range(self.tp_size)] + self._result_queue = ctx.Queue() + self._processes = [] + self._closed = False + self._seq = 0 + self._master_addr = "127.0.0.1" + self._master_port = _find_free_port() + + for rank in range(self.tp_size): + process = ctx.Process( + target=_tensor_parallel_worker_main, + args=( + rank, + self.tp_size, + self.model_path, + self.device_ids, + self._master_addr, + self._master_port, + self._command_queues[rank], + self._result_queue, + ), + daemon=True, + ) + process.start() + self._processes.append(process) + + self._await_ready() + + def __del__(self): + try: + self.close() + except Exception: + pass + + def close(self) -> None: + if self._closed: + return + self._closed = True + try: + self._request("shutdown") + except Exception: + pass + for process in self._processes: + process.join(timeout=10) + if process.is_alive(): + process.kill() + self._processes.clear() + + def reset(self) -> None: + self._request("reset") + + def truncate(self, position: int) -> None: + self._request("truncate", position=int(position)) + + def generate_next( + self, + inputs: Sequence[int], + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + *, + reset_state: bool = False, + ) -> int: + if not inputs: + raise 
ValueError("inputs must not be empty") + response = self._request( + "generate_next", + inputs=[int(token) for token in inputs], + top_k=int(top_k), + top_p=float(top_p), + temperature=float(temperature), + reset_state=bool(reset_state), + ) + return int(response["token"]) + + def generate( + self, + inputs: Sequence[int], + max_new_tokens: int = 20, + top_k: int = 1, + top_p: float = 0.8, + temperature: float = 0.8, + *, + reset_state: bool = True, + ) -> List[int]: + if not inputs: + raise ValueError("inputs must not be empty") + if reset_state: + self.reset() + + generated: List[int] = [] + token_source = list(inputs) + for _ in range(int(max_new_tokens)): + next_token = self.generate_next( + token_source, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=False, + ) + generated.append(next_token) + token_source = [next_token] + if next_token == self.meta.end_token: + break + return list(inputs) + generated + + def stream_generate( + self, + inputs: Sequence[int], + *, + tokenizer=None, + max_new_tokens: int = 20, + top_k: int = 1, + top_p: float = 0.8, + temperature: float = 0.8, + reset_state: bool = True, + ) -> Iterator[Tuple[int, str]]: + if reset_state: + self.reset() + generated: List[int] = [] + previous_text = "" + token_source = list(inputs) + + for _ in range(int(max_new_tokens)): + next_token = self.generate_next( + token_source, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=False, + ) + generated.append(next_token) + token_source = [next_token] + + text_chunk = "" + if tokenizer is not None: + decoded = tokenizer.decode( + generated, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + text_chunk = decoded[len(previous_text):] if decoded.startswith(previous_text) else decoded + previous_text = decoded + + yield next_token, text_chunk + if next_token == self.meta.end_token: + break + + def _request(self, command: str, **payload): + self._seq += 1 + message = 
_normalize_command_payload(command, self._seq, **payload) + for queue_ in self._command_queues: + queue_.put(message) + + deadline_s = 600.0 + waited_s = 0.0 + while True: + try: + response = self._result_queue.get(timeout=5) + break + except queue.Empty: + waited_s += 5.0 + dead = { + process.pid: process.exitcode + for process in self._processes + if process.exitcode not in (None, 0) + } + if dead: + raise RuntimeError( + f"Tensor-parallel worker exited while handling `{command}`: {dead}" + ) + if waited_s >= deadline_s: + raise TimeoutError(f"Timed out waiting for tensor-parallel command `{command}`") + + if response.get("status") == "error": + raise RuntimeError(response.get("error", "Unknown tensor-parallel worker error")) + if response.get("seq") != self._seq: + raise RuntimeError(f"Unexpected tensor-parallel response ordering: {response}") + return response + + def _await_ready(self) -> None: + while True: + try: + ready = self._result_queue.get(timeout=5) + except queue.Empty: + dead = [process.pid for process in self._processes if process.exitcode not in (None, 0)] + if dead: + raise RuntimeError(f"Tensor-parallel worker exited before ready: pids={dead}") + continue + if ready.get("status") == "ready": + return + if ready.get("status") == "error": + raise RuntimeError( + f"Tensor-parallel worker rank {ready.get('rank')} failed to start: {ready.get('error')}" + ) + raise RuntimeError(f"Unexpected tensor-parallel startup response: {ready}") diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py index ed0180bc..0e3b3b81 100644 --- a/python/llaisys/ops.py +++ b/python/llaisys/ops.py @@ -21,7 +21,21 @@ def embedding(out: Tensor, index: Tensor, weight: Tensor): @staticmethod def linear(out: Tensor, inp: Tensor, weight: Tensor, bias: Tensor): LIB_LLAISYS.llaisysLinear( - out.lib_tensor(), inp.lib_tensor(), weight.lib_tensor(), bias.lib_tensor() + out.lib_tensor(), + inp.lib_tensor(), + weight.lib_tensor(), + bias.lib_tensor() if bias is not None else None, + 
) + + @staticmethod + def sample(logits: Tensor, top_k: int = 1, top_p: float = 1.0, temperature: float = 1.0) -> int: + return int( + LIB_LLAISYS.llaisysSample( + logits.lib_tensor(), + c_int(top_k), + c_float(top_p), + c_float(temperature), + ) ) @staticmethod diff --git a/python/setup.cfg b/python/setup.cfg index b35fc65f..3aeef424 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -14,6 +14,17 @@ install_requires = transformers accelerate +[options.extras_require] +server = + fastapi + uvicorn + httpx + +[options.entry_points] +console_scripts = + llaisys-chat-server = llaisys.chat_server:main + llaisys-chat-cli = llaisys.chat_cli:main + [options.package_data] llaisys = libllaisys/*.so diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp index 44894b9e..63756fab 100644 --- a/src/core/context/context.cpp +++ b/src/core/context/context.cpp @@ -52,7 +52,7 @@ Context::~Context() { void Context::setDevice(llaisysDeviceType_t device_type, int device_id) { // If doest not match the current runtime. 
if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) { - auto runtimes = _runtime_map[device_type]; + auto &runtimes = _runtime_map[device_type]; CHECK_ARGUMENT((size_t)device_id < runtimes.size() && device_id >= 0, "invalid device id"); if (_current_runtime != nullptr) { _current_runtime->_deactivate(); diff --git a/src/core/runtime/runtime.cpp b/src/core/runtime/runtime.cpp index 7f03a862..7d10fae7 100644 --- a/src/core/runtime/runtime.cpp +++ b/src/core/runtime/runtime.cpp @@ -7,6 +7,7 @@ namespace llaisys::core { Runtime::Runtime(llaisysDeviceType_t device_type, int device_id) : _device_type(device_type), _device_id(device_id), _is_active(false) { _api = llaisys::device::getRuntimeAPI(_device_type); + _api->set_device(_device_id); _stream = _api->create_stream(); _allocator = new allocators::NaiveAllocator(_api); } @@ -17,6 +18,7 @@ Runtime::~Runtime() { } delete _allocator; _allocator = nullptr; + _api->set_device(_device_id); _api->destroy_stream(_stream); _api = nullptr; } diff --git a/src/device/nvidia/cuda_utils.cuh b/src/device/nvidia/cuda_utils.cuh new file mode 100644 index 00000000..5096013c --- /dev/null +++ b/src/device/nvidia/cuda_utils.cuh @@ -0,0 +1,160 @@ +#pragma once + +#include "../../utils.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llaisys::device::nvidia::cuda_utils { + +inline void checkCuda(cudaError_t status, const char *expr, const char *file, int line) { + if (status == cudaSuccess) { + return; + } + std::cerr << "[ERROR] CUDA call failed: " << expr << " -> " << cudaGetErrorString(status) + << " at " << file << ":" << line << std::endl; + throw std::runtime_error("CUDA error"); +} + +inline const char *cublasStatusName(cublasStatus_t status) { + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + 
case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + default: + return "CUBLAS_STATUS_UNKNOWN"; + } +} + +inline void checkCublas(cublasStatus_t status, const char *expr, const char *file, int line) { + if (status == CUBLAS_STATUS_SUCCESS) { + return; + } + std::cerr << "[ERROR] cuBLAS call failed: " << expr << " -> " << cublasStatusName(status) + << " at " << file << ":" << line << std::endl; + throw std::runtime_error("cuBLAS error"); +} + +#define LLAISYS_CUDA_CHECK(EXPR__) ::llaisys::device::nvidia::cuda_utils::checkCuda((EXPR__), #EXPR__, __FILE__, __LINE__) +#define LLAISYS_CUBLAS_CHECK(EXPR__) ::llaisys::device::nvidia::cuda_utils::checkCublas((EXPR__), #EXPR__, __FILE__, __LINE__) + +inline cudaMemcpyKind memcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + throw std::invalid_argument("Unsupported memcpy kind"); + } +} + +inline cudaDataType_t cublasDataType(llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return CUDA_R_32F; + case LLAISYS_DTYPE_F16: + return CUDA_R_16F; + case LLAISYS_DTYPE_BF16: + return CUDA_R_16BF; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +inline cublasComputeType_t cublasComputeType(llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + case LLAISYS_DTYPE_F16: + case LLAISYS_DTYPE_BF16: + return CUBLAS_COMPUTE_32F; + 
default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +template +__device__ __forceinline__ float toFloat(T value) { + return static_cast(value); +} + +template <> +__device__ __forceinline__ float toFloat(float value) { + return value; +} + +template <> +__device__ __forceinline__ float toFloat(llaisys::fp16_t value) { + union { + __half half_value; + __half_raw raw; + } bits{}; + bits.raw.x = value._v; + return __half2float(bits.half_value); +} + +template <> +__device__ __forceinline__ float toFloat(llaisys::bf16_t value) { + union { + __nv_bfloat16 bf16_value; + __nv_bfloat16_raw raw; + } bits{}; + bits.raw.x = value._v; + return __bfloat162float(bits.bf16_value); +} + +template +__device__ __forceinline__ T fromFloat(float value) { + return static_cast(value); +} + +template <> +__device__ __forceinline__ float fromFloat(float value) { + return value; +} + +template <> +__device__ __forceinline__ llaisys::fp16_t fromFloat(float value) { + union { + __half half_value; + __half_raw raw; + } bits{}; + bits.half_value = __float2half_rn(value); + return llaisys::fp16_t{bits.raw.x}; +} + +template <> +__device__ __forceinline__ llaisys::bf16_t fromFloat(float value) { + union { + __nv_bfloat16 bf16_value; + __nv_bfloat16_raw raw; + } bits{}; + bits.bf16_value = __float2bfloat16(value); + return llaisys::bf16_t{bits.raw.x}; +} + +} // namespace llaisys::device::nvidia::cuda_utils diff --git a/src/device/nvidia/nvidia_resource.cu b/src/device/nvidia/nvidia_resource.cu index 2e63647e..84c59c6f 100644 --- a/src/device/nvidia/nvidia_resource.cu +++ b/src/device/nvidia/nvidia_resource.cu @@ -3,5 +3,6 @@ namespace llaisys::device::nvidia { Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) {} +Resource::~Resource() = default; } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab92826..ba95472a 100644 --- 
a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,80 @@ #include "../runtime_api.hpp" - -#include -#include +#include "cuda_utils.cuh" namespace llaisys::device::nvidia { namespace runtime_api { int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count = 0; + LLAISYS_CUDA_CHECK(cudaGetDeviceCount(&count)); + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + LLAISYS_CUDA_CHECK(cudaSetDevice(device_id)); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + LLAISYS_CUDA_CHECK(cudaDeviceSynchronize()); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t stream = nullptr; + LLAISYS_CUDA_CHECK(cudaStreamCreate(&stream)); + return reinterpret_cast(stream); } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + if (stream == nullptr) { + return; + } + LLAISYS_CUDA_CHECK(cudaStreamDestroy(reinterpret_cast(stream))); } void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + if (stream == nullptr) { + return; + } + LLAISYS_CUDA_CHECK(cudaStreamSynchronize(reinterpret_cast(stream))); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + LLAISYS_CUDA_CHECK(cudaMalloc(&ptr, size)); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + if (ptr == nullptr) { + return; + } + LLAISYS_CUDA_CHECK(cudaFree(ptr)); } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + LLAISYS_CUDA_CHECK(cudaMallocHost(&ptr, size)); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + if (ptr == nullptr) { + return; + } + LLAISYS_CUDA_CHECK(cudaFreeHost(ptr)); } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + LLAISYS_CUDA_CHECK(cudaDeviceSynchronize()); + LLAISYS_CUDA_CHECK(cudaMemcpy(dst, src, size, cuda_utils::memcpyKind(kind))); } -void memcpyAsync(void *dst, const void *src, size_t size, 
llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + LLAISYS_CUDA_CHECK(cudaMemcpyAsync( + dst, + src, + size, + cuda_utils::memcpyKind(kind), + reinterpret_cast(stream))); } static const LlaisysRuntimeAPI RUNTIME_API = { diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc new file mode 100644 index 00000000..209ea8b1 --- /dev/null +++ b/src/llaisys/models/qwen2.cc @@ -0,0 +1,31 @@ +#include "llaisys/models/qwen2.h" +#include "../../models/qwen2/model.hpp" + +using namespace llaisys::models::qwen2; + +extern "C" { + +struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + int dev_id = (ndevice > 0 && device_ids != nullptr) ? device_ids[0] : 0; + Qwen2Model* model = new Qwen2Model(*meta, device, dev_id); + return reinterpret_cast(model); +} + +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) { + if (model) { + delete reinterpret_cast(model); + } +} + +struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) { + if (!model) return nullptr; + return reinterpret_cast(model)->getWeightsStruct(); +} + +int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) { + if (!model) return -1; + std::vector tokens(token_ids, token_ids + ntoken); + return reinterpret_cast(model)->infer(tokens); +} + +} diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc index c99fbc32..924978ba 100644 --- a/src/llaisys/ops.cc +++ b/src/llaisys/ops.cc @@ -6,6 +6,7 @@ #include "../ops/argmax/op.hpp" #include "../ops/embedding/op.hpp" #include "../ops/linear/op.hpp" +#include "../ops/sample/op.hpp" #include "../ops/rearrange/op.hpp" #include "../ops/rms_norm/op.hpp" #include "../ops/rope/op.hpp" @@ -23,7 +24,10 @@ __C { llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor); } void 
llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) { - llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor); + llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias ? bias->tensor : nullptr); + } + int64_t llaisysSample(llaisysTensor_t logits, int top_k, float top_p, float temperature) { + return llaisys::ops::sample(logits->tensor, top_k, top_p, temperature); } void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) { llaisys::ops::rearrange(out->tensor, in->tensor); diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 00000000..24936310 --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,54 @@ +#include "llaisys/models/qwen2.h" +#include "../models/qwen2/qwen2.hpp" + +extern "C" { + +struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + if (!meta || ndevice < 1) return nullptr; + // For now support single device + int device_id = device_ids ? 
device_ids[0] : 0; + + // Copy meta + LlaisysQwen2Meta cpp_meta = *meta; + + auto* model = new llaisys::Qwen2Model(cpp_meta, device, device_id); + return reinterpret_cast(model); +} + +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) { + if (model) { + delete reinterpret_cast(model); + } +} + +struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) { + if (!model) return nullptr; + auto* cpp_model = reinterpret_cast(model); + return cpp_model->getWeights(); +} + +void llaisysQwen2ModelReset(struct LlaisysQwen2Model * model) { + if (!model) return; + auto* cpp_model = reinterpret_cast(model); + cpp_model->reset(); +} + +void llaisysQwen2ModelTruncate(struct LlaisysQwen2Model * model, size_t position) { + if (!model) return; + auto* cpp_model = reinterpret_cast(model); + cpp_model->truncate(position); +} + +int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) { + if (!model) return -1; + auto* cpp_model = reinterpret_cast(model); + return cpp_model->infer(token_ids, ntoken); +} + +int64_t llaisysQwen2ModelGenerateNext(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken, int top_k, float top_p, float temperature) { + if (!model) return -1; + auto* cpp_model = reinterpret_cast(model); + return cpp_model->generateNext(token_ids, ntoken, top_k, top_p, temperature); +} + +} diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp new file mode 100644 index 00000000..60340e7b --- /dev/null +++ b/src/models/qwen2/qwen2.cpp @@ -0,0 +1,328 @@ +#include "qwen2.hpp" +#include "llaisys/ops.h" +#include "../../ops/add/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/sample/op.hpp" +#include "../../ops/rearrange/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" 
+#include "../../core/llaisys_core.hpp" +#include "../../utils/check.hpp" +#include +#include +#include +#include "../../llaisys/llaisys_tensor.hpp" + +namespace llaisys { + +inline tensor_t to_cpp(llaisysTensor_t t) { + if (!t) return nullptr; + return reinterpret_cast(t)->tensor; +} + +inline tensor_t to_cpp(llaisysTensor_t* t_array, size_t idx) { + if (!t_array || !t_array[idx]) return nullptr; + return reinterpret_cast(t_array[idx])->tensor; +} + +Qwen2Model::Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id) + : _meta(meta), _device_type(device), _device_id(device_id) { + + _weights.attn_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_q_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_q_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_k_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_k_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_v_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_v_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_o_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_gate_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_up_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_down_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + + init_buffers(); +} + +Qwen2Model::~Qwen2Model() { + free(_weights.attn_norm_w); + free(_weights.attn_q_w); + free(_weights.attn_q_b); + free(_weights.attn_k_w); + free(_weights.attn_k_b); + free(_weights.attn_v_w); + free(_weights.attn_v_b); + free(_weights.attn_o_w); + free(_weights.mlp_norm_w); + free(_weights.mlp_gate_w); + 
free(_weights.mlp_up_w); + free(_weights.mlp_down_w); +} + +void Qwen2Model::init_buffers() { + core::context().setDevice(_device_type, _device_id); + + const size_t q_dim = _meta.nh * _meta.dh; + const size_t k_dim = _meta.nkvh * _meta.dh; + const float scale = 1.0f / std::sqrt(static_cast(_meta.dh)); + (void)scale; + + for(size_t i=0; i<_meta.nlayer; ++i) { + _kv_caches.push_back({ + Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id), + Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id) + }); + + auto q_flat = Tensor::create({1, q_dim}, _meta.dtype, _device_type, _device_id); + auto k_flat = Tensor::create({1, k_dim}, _meta.dtype, _device_type, _device_id); + auto v_flat = Tensor::create({1, k_dim}, _meta.dtype, _device_type, _device_id); + auto attn_val_3d = Tensor::create({1, _meta.nh, _meta.dh}, _meta.dtype, _device_type, _device_id); + + _decode_layers.push_back({ + q_flat, + k_flat, + v_flat, + q_flat->view({1, _meta.nh, _meta.dh}), + k_flat->view({1, _meta.nkvh, _meta.dh}), + v_flat->view({1, _meta.nkvh, _meta.dh}), + attn_val_3d, + attn_val_3d->view({1, _meta.hs}), + Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id), + Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id), + Tensor::create({1, _meta.di}, _meta.dtype, _device_type, _device_id) + }); + } + + _hidden_states = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _ln_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _attn_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _mlp_out = Tensor::create({1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _logits = Tensor::create({1, _meta.voc}, _meta.dtype, _device_type, _device_id); + _tokens_tensor = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id); + _single_pos_id = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id); + 
_pos_ids = Tensor::create({_meta.maxseq}, LLAISYS_DTYPE_I64, _device_type, _device_id); + _argmax_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id); + _argmax_val = Tensor::create({1}, _meta.dtype, _device_type, _device_id); +} + +void Qwen2Model::ensure_prefill_buffers(size_t ntoken) { + if (_prefill.capacity >= ntoken) { + return; + } + + const size_t next_capacity = std::max(ntoken, _prefill.capacity == 0 ? ntoken : _prefill.capacity * 2); + const size_t q_dim = _meta.nh * _meta.dh; + const size_t k_dim = _meta.nkvh * _meta.dh; + + _prefill.capacity = next_capacity; + _prefill.tokens = Tensor::create({next_capacity}, LLAISYS_DTYPE_I64, _device_type, _device_id); + _prefill.hidden_states = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id); + _prefill.normed = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id); + _prefill.q_flat = Tensor::create({next_capacity, q_dim}, _meta.dtype, _device_type, _device_id); + _prefill.k_flat = Tensor::create({next_capacity, k_dim}, _meta.dtype, _device_type, _device_id); + _prefill.v_flat = Tensor::create({next_capacity, k_dim}, _meta.dtype, _device_type, _device_id); + _prefill.attn_val_3d = Tensor::create({next_capacity, _meta.nh, _meta.dh}, _meta.dtype, _device_type, _device_id); + _prefill.attn_out = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id); + _prefill.gate = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id); + _prefill.up = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id); + _prefill.swiglu = Tensor::create({next_capacity, _meta.di}, _meta.dtype, _device_type, _device_id); + _prefill.mlp_out = Tensor::create({next_capacity, _meta.hs}, _meta.dtype, _device_type, _device_id); +} + +tensor_t Qwen2Model::forward(const int64_t* token_ids, size_t ntoken) { + CHECK_ARGUMENT(token_ids != nullptr, "token_ids must not be null"); + CHECK_ARGUMENT(ntoken 
> 0, "ntoken must be greater than zero"); + CHECK_ARGUMENT(_cur_pos + ntoken <= _meta.maxseq, "sequence exceeds KV-cache capacity"); + + core::context().setDevice(_device_type, _device_id); + + if (ntoken == 1) { + return forward_single_token(token_ids); + } + + ensure_prefill_buffers(ntoken); + + tensor_t input_tokens = _prefill.tokens->slice(0, 0, ntoken); + input_tokens->load(token_ids); + + tensor_t current_pos_ids = _pos_ids->slice(0, 0, ntoken); + std::vector pos_data(ntoken); + for(size_t i=0; iload(pos_data.data()); + + std::vector seq_shape = {ntoken, _meta.hs}; + + tensor_t hidden_states = _prefill.hidden_states->slice(0, 0, ntoken); + hidden_states = hidden_states->view(seq_shape); + ops::embedding(hidden_states, input_tokens, to_cpp(_weights.in_embed)); + + const size_t q_dim = _meta.nh * _meta.dh; + const size_t k_dim = _meta.nkvh * _meta.dh; + const float scale = 1.0f / std::sqrt(static_cast(_meta.dh)); + + for(size_t i=0; i<_meta.nlayer; ++i) { + tensor_t normed = _prefill.normed->slice(0, 0, ntoken); + normed = normed->view(seq_shape); + ops::rms_norm(normed, hidden_states, to_cpp(_weights.attn_norm_w, i), _meta.epsilon); + + tensor_t q = _prefill.q_flat->slice(0, 0, ntoken); + q = q->view({ntoken, q_dim}); + tensor_t k = _prefill.k_flat->slice(0, 0, ntoken); + k = k->view({ntoken, k_dim}); + tensor_t v = _prefill.v_flat->slice(0, 0, ntoken); + v = v->view({ntoken, k_dim}); + + ops::linear(q, normed, to_cpp(_weights.attn_q_w, i), to_cpp(_weights.attn_q_b, i)); + ops::linear(k, normed, to_cpp(_weights.attn_k_w, i), to_cpp(_weights.attn_k_b, i)); + ops::linear(v, normed, to_cpp(_weights.attn_v_w, i), to_cpp(_weights.attn_v_b, i)); + + q = q->view({ntoken, _meta.nh, _meta.dh}); + k = k->view({ntoken, _meta.nkvh, _meta.dh}); + v = v->view({ntoken, _meta.nkvh, _meta.dh}); + + ops::rope(q, q, current_pos_ids, _meta.theta); + ops::rope(k, k, current_pos_ids, _meta.theta); + + tensor_t k_cache_slot = _kv_caches[i].k->slice(0, _cur_pos, _cur_pos + ntoken); + 
tensor_t v_cache_slot = _kv_caches[i].v->slice(0, _cur_pos, _cur_pos + ntoken); + + ops::rearrange(k_cache_slot, k); + ops::rearrange(v_cache_slot, v); + + tensor_t k_full = _kv_caches[i].k->slice(0, 0, _cur_pos + ntoken); + tensor_t v_full = _kv_caches[i].v->slice(0, 0, _cur_pos + ntoken); + + tensor_t attn_val = _prefill.attn_val_3d->slice(0, 0, ntoken); + attn_val = attn_val->view({ntoken, _meta.nh, _meta.dh}); + + ops::self_attention(attn_val, q, k_full, v_full, scale); + + attn_val = attn_val->view({ntoken, _meta.hs}); + + tensor_t attn_output = _prefill.attn_out->slice(0, 0, ntoken); + attn_output = attn_output->view(seq_shape); + ops::linear(attn_output, attn_val, to_cpp(_weights.attn_o_w, i), nullptr); + + ops::add(hidden_states, hidden_states, attn_output); + + normed = _prefill.normed->slice(0, 0, ntoken); + normed = normed->view(seq_shape); + ops::rms_norm(normed, hidden_states, to_cpp(_weights.mlp_norm_w, i), _meta.epsilon); + + tensor_t gate = _prefill.gate->slice(0, 0, ntoken); + gate = gate->view({ntoken, _meta.di}); + tensor_t up = _prefill.up->slice(0, 0, ntoken); + up = up->view({ntoken, _meta.di}); + + ops::linear(gate, normed, to_cpp(_weights.mlp_gate_w, i), nullptr); + ops::linear(up, normed, to_cpp(_weights.mlp_up_w, i), nullptr); + + tensor_t swiglu_out = _prefill.swiglu->slice(0, 0, ntoken); + swiglu_out = swiglu_out->view({ntoken, _meta.di}); + ops::swiglu(swiglu_out, gate, up); + + tensor_t mlp_output = _prefill.mlp_out->slice(0, 0, ntoken); + mlp_output = mlp_output->view(seq_shape); + ops::linear(mlp_output, swiglu_out, to_cpp(_weights.mlp_down_w, i), nullptr); + + ops::add(hidden_states, hidden_states, mlp_output); + } + + tensor_t final_normed = _prefill.normed->slice(0, 0, ntoken); + final_normed = final_normed->view(seq_shape); + ops::rms_norm(final_normed, hidden_states, to_cpp(_weights.out_norm_w), _meta.epsilon); + + tensor_t last_hidden = final_normed->slice(0, ntoken-1, ntoken); + last_hidden = last_hidden->view({1, _meta.hs}); 
+ + ops::linear(_logits, last_hidden, to_cpp(_weights.out_embed), nullptr); + + _cur_pos += ntoken; + + return _logits; +} + +tensor_t Qwen2Model::forward_single_token(const int64_t* token_id) { + CHECK_ARGUMENT(token_id != nullptr, "token_id must not be null"); + + _tokens_tensor->load(token_id); + const int64_t current_pos = static_cast(_cur_pos); + _single_pos_id->load(¤t_pos); + + ops::embedding(_hidden_states, _tokens_tensor, to_cpp(_weights.in_embed)); + + const float scale = 1.0f / std::sqrt(static_cast(_meta.dh)); + for (size_t i = 0; i < _meta.nlayer; ++i) { + auto &layer = _decode_layers[i]; + + ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.attn_norm_w, i), _meta.epsilon); + + ops::linear(layer.q_flat, _ln_out, to_cpp(_weights.attn_q_w, i), to_cpp(_weights.attn_q_b, i)); + ops::linear(layer.k_flat, _ln_out, to_cpp(_weights.attn_k_w, i), to_cpp(_weights.attn_k_b, i)); + ops::linear(layer.v_flat, _ln_out, to_cpp(_weights.attn_v_w, i), to_cpp(_weights.attn_v_b, i)); + + ops::rope(layer.q_view, layer.q_view, _single_pos_id, _meta.theta); + ops::rope(layer.k_view, layer.k_view, _single_pos_id, _meta.theta); + + tensor_t k_cache_slot = _kv_caches[i].k->slice(0, _cur_pos, _cur_pos + 1); + tensor_t v_cache_slot = _kv_caches[i].v->slice(0, _cur_pos, _cur_pos + 1); + ops::rearrange(k_cache_slot, layer.k_view); + ops::rearrange(v_cache_slot, layer.v_view); + + tensor_t k_full = _kv_caches[i].k->slice(0, 0, _cur_pos + 1); + tensor_t v_full = _kv_caches[i].v->slice(0, 0, _cur_pos + 1); + ops::self_attention(layer.attn_val_3d, layer.q_view, k_full, v_full, scale); + + ops::linear(_attn_out, layer.attn_val_2d, to_cpp(_weights.attn_o_w, i), nullptr); + ops::add(_hidden_states, _hidden_states, _attn_out); + + ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.mlp_norm_w, i), _meta.epsilon); + ops::linear(layer.gate, _ln_out, to_cpp(_weights.mlp_gate_w, i), nullptr); + ops::linear(layer.up, _ln_out, to_cpp(_weights.mlp_up_w, i), nullptr); + 
ops::swiglu(layer.swiglu, layer.gate, layer.up); + ops::linear(_mlp_out, layer.swiglu, to_cpp(_weights.mlp_down_w, i), nullptr); + ops::add(_hidden_states, _hidden_states, _mlp_out); + } + + ops::rms_norm(_ln_out, _hidden_states, to_cpp(_weights.out_norm_w), _meta.epsilon); + ops::linear(_logits, _ln_out, to_cpp(_weights.out_embed), nullptr); + + _cur_pos += 1; + return _logits; +} + +int64_t Qwen2Model::infer(const int64_t* token_ids, size_t ntoken) { + tensor_t logits = forward(token_ids, ntoken); + ops::argmax(_argmax_idx, _argmax_val, logits); + + if (_device_type == LLAISYS_DEVICE_CPU) { + return *reinterpret_cast(_argmax_idx->data()); + } + int64_t result = 0; + core::context().runtime().api()->memcpy_sync( + &result, + _argmax_idx->data(), + sizeof(result), + LLAISYS_MEMCPY_D2H); + return result; +} + +int64_t Qwen2Model::generateNext(const int64_t* token_ids, size_t ntoken, int top_k, float top_p, float temperature) { + tensor_t logits = forward(token_ids, ntoken); + return ops::sample(logits, top_k, top_p, temperature); +} + +void Qwen2Model::reset() { + _cur_pos = 0; +} + +void Qwen2Model::truncate(size_t position) { + CHECK_ARGUMENT(position <= _cur_pos, "truncate position exceeds current cache length"); + _cur_pos = position; +} + +} // namespace llaisys diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp new file mode 100644 index 00000000..894d1b0d --- /dev/null +++ b/src/models/qwen2/qwen2.hpp @@ -0,0 +1,83 @@ +#pragma once +#include "llaisys/models/qwen2.h" +#include "../../tensor/tensor.hpp" +#include +#include + +namespace llaisys { + +class Qwen2Model { +public: + Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id); + ~Qwen2Model(); + + LlaisysQwen2Weights* getWeights() { return &_weights; } + + int64_t infer(const int64_t* token_ids, size_t ntoken); + int64_t generateNext(const int64_t* token_ids, size_t ntoken, int top_k, float top_p, float temperature); + void reset(); + void truncate(size_t 
position); + +private: + LlaisysQwen2Meta _meta; + LlaisysQwen2Weights _weights; + + llaisysDeviceType_t _device_type; + int _device_id; + + struct KVCache { + tensor_t k; + tensor_t v; + }; + struct DecodeLayerBuffers { + tensor_t q_flat; + tensor_t k_flat; + tensor_t v_flat; + tensor_t q_view; + tensor_t k_view; + tensor_t v_view; + tensor_t attn_val_3d; + tensor_t attn_val_2d; + tensor_t gate; + tensor_t up; + tensor_t swiglu; + }; + struct PrefillBuffers { + size_t capacity = 0; + tensor_t tokens; + tensor_t hidden_states; + tensor_t normed; + tensor_t q_flat; + tensor_t k_flat; + tensor_t v_flat; + tensor_t attn_val_3d; + tensor_t attn_out; + tensor_t gate; + tensor_t up; + tensor_t swiglu; + tensor_t mlp_out; + }; + std::vector _kv_caches; + std::vector _decode_layers; + PrefillBuffers _prefill; + + size_t _cur_pos = 0; + + tensor_t _hidden_states; + tensor_t _ln_out; + tensor_t _attn_out; + tensor_t _mlp_out; + tensor_t _logits; + tensor_t _tokens_tensor; + tensor_t _single_pos_id; + tensor_t _pos_ids; + tensor_t _argmax_idx; + tensor_t _argmax_val; + + void init_buffers(); + void ensure_prefill_buffers(size_t ntoken); + tensor_t forward(const int64_t* token_ids, size_t ntoken); + tensor_t forward_single_token(const int64_t* token_id); +}; + +} // namespace llaisys diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp index 47f6a3d4..90c332a8 100644 --- a/src/ops/add/cpu/add_cpu.cpp +++ b/src/ops/add/cpu/add_cpu.cpp @@ -1,12 +1,58 @@ #include "add_cpu.hpp" +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) +#ifdef __C +#pragma push_macro("__C") +#undef __C +#define LLAISYS_RESTORE_C_MACRO +#endif +#include +#ifdef LLAISYS_RESTORE_C_MACRO +#pragma pop_macro("__C") +#undef LLAISYS_RESTORE_C_MACRO +#endif +#endif + #include "../../../utils.hpp" -#include +#include +#include + +namespace { + +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || 
defined(__clang__)) +inline bool has_avx2() { + return __builtin_cpu_supports("avx2"); +} + +__attribute__((target("avx2"))) +void add_f32_avx2(float *c, const float *a, const float *b, std::ptrdiff_t numel) { + const std::ptrdiff_t simd_numel = numel - (numel % 8); + +#pragma omp parallel for schedule(static) if (numel >= 4096) + for (std::ptrdiff_t i = 0; i < simd_numel; i += 8) { + const __m256 va = _mm256_loadu_ps(a + i); + const __m256 vb = _mm256_loadu_ps(b + i); + _mm256_storeu_ps(c + i, _mm256_add_ps(va, vb)); + } + + for (std::ptrdiff_t i = simd_numel; i < numel; ++i) { + c[i] = a[i] + b[i]; + } +} +#endif + +void add_f32(float *c, const float *a, const float *b, std::ptrdiff_t numel) { +#pragma omp parallel for schedule(static) if (numel >= 4096) + for (std::ptrdiff_t i = 0; i < numel; ++i) { + c[i] = a[i] + b[i]; + } +} template -void add_(T *c, const T *a, const T *b, size_t numel) { - for (size_t i = 0; i < numel; i++) { +void add_generic(T *c, const T *a, const T *b, std::ptrdiff_t numel) { +#pragma omp parallel for schedule(static) if (numel >= 4096) + for (std::ptrdiff_t i = 0; i < numel; ++i) { if constexpr (std::is_same_v || std::is_same_v) { c[i] = llaisys::utils::cast(llaisys::utils::cast(a[i]) + llaisys::utils::cast(b[i])); } else { @@ -15,17 +61,25 @@ void add_(T *c, const T *a, const T *b, size_t numel) { } } +} // namespace + namespace llaisys::ops::cpu { void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { + const auto elem_count = static_cast(numel); switch (type) { case LLAISYS_DTYPE_F32: - return add_(reinterpret_cast(c), reinterpret_cast(a), reinterpret_cast(b), numel); +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) + if (has_avx2()) { + return add_f32_avx2(reinterpret_cast(c), reinterpret_cast(a), reinterpret_cast(b), elem_count); + } +#endif + return add_f32(reinterpret_cast(c), reinterpret_cast(a), 
reinterpret_cast(b), elem_count); case LLAISYS_DTYPE_BF16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); + return add_generic(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), elem_count); case LLAISYS_DTYPE_F16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); + return add_generic(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), elem_count); default: EXCEPTION_UNSUPPORTED_DATATYPE(type); } diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu new file mode 100644 index 00000000..8485030e --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cu @@ -0,0 +1,47 @@ +#include "add_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" + +namespace { + +template +__global__ void add_kernel(T *c, const T *a, const T *b, size_t numel) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x); + if (idx >= numel) { + return; + } + + const float lhs = llaisys::device::nvidia::cuda_utils::toFloat(a[idx]); + const float rhs = llaisys::device::nvidia::cuda_utils::toFloat(b[idx]); + c[idx] = llaisys::device::nvidia::cuda_utils::fromFloat(lhs + rhs); +} + +template +void add_impl(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) { + constexpr int threads = 256; + add_kernel<<>>( + reinterpret_cast(c), + reinterpret_cast(a), + reinterpret_cast(b), + numel); + LLAISYS_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t dtype, size_t numel) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return add_impl(c, a, b, numel); + case LLAISYS_DTYPE_F16: + return add_impl(c, a, b, numel); + case LLAISYS_DTYPE_BF16: + return add_impl(c, a, b, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git 
a/src/ops/add/nvidia/add_nvidia.cuh b/src/ops/add/nvidia/add_nvidia.cuh new file mode 100644 index 00000000..cd7e4584 --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d..6c973fcc 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -4,6 +4,7 @@ #include "../../utils.hpp" #include "cpu/add_cpu.hpp" +#include "nvidia/add_nvidia.cuh" namespace llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { @@ -25,8 +26,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..c29bc532 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,142 @@ +#include "argmax_cpu.hpp" + +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) +#ifdef __C +#pragma push_macro("__C") +#undef __C +#define LLAISYS_RESTORE_C_MACRO +#endif +#include +#ifdef LLAISYS_RESTORE_C_MACRO +#pragma pop_macro("__C") +#undef LLAISYS_RESTORE_C_MACRO +#endif +#endif + +#include "../../../utils.hpp" + +#include +#include +#include + +namespace { + +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) +inline bool has_avx2() { + return __builtin_cpu_supports("avx2"); +} + +__attribute__((target("avx2"))) +void argmax_f32_avx2(size_t *max_idx, float *max_val, const float *vals, std::ptrdiff_t numel) { + if (numel <= 0) 
{ + *max_idx = 0; + *max_val = 0.0f; + return; + } + + std::ptrdiff_t i = 0; + float best_val = vals[0]; + size_t best_idx = 0; + + if (numel >= 8) { + __m256 max_vals = _mm256_loadu_ps(vals); + __m256i max_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i cur_indices = _mm256_setr_epi32(8, 9, 10, 11, 12, 13, 14, 15); + const __m256i step = _mm256_set1_epi32(8); + i = 8; + + for (; i + 8 <= numel; i += 8) { + const __m256 cur_vals = _mm256_loadu_ps(vals + i); + const __m256 mask = _mm256_cmp_ps(cur_vals, max_vals, _CMP_GT_OQ); + max_vals = _mm256_blendv_ps(max_vals, cur_vals, mask); + max_indices = _mm256_blendv_epi8(max_indices, cur_indices, _mm256_castps_si256(mask)); + cur_indices = _mm256_add_epi32(cur_indices, step); + } + + alignas(32) float lane_vals[8]; + alignas(32) int lane_indices[8]; + _mm256_store_ps(lane_vals, max_vals); + _mm256_store_si256(reinterpret_cast<__m256i *>(lane_indices), max_indices); + for (int lane = 0; lane < 8; ++lane) { + if (lane_vals[lane] > best_val) { + best_val = lane_vals[lane]; + best_idx = static_cast(lane_indices[lane]); + } + } + } + + for (; i < numel; ++i) { + if (vals[i] > best_val) { + best_val = vals[i]; + best_idx = static_cast(i); + } + } + + *max_idx = best_idx; + *max_val = best_val; +} +#endif + +template +void argmax_generic(size_t *max_idx, T *max_val, const T *vals, std::ptrdiff_t numel){ + size_t max_index = 0; + float max_value = llaisys::utils::cast(vals[0]); + + std::ptrdiff_t i = 1; + for (; i + 3 < numel; i += 4) { + const float v0 = llaisys::utils::cast(vals[i + 0]); + const float v1 = llaisys::utils::cast(vals[i + 1]); + const float v2 = llaisys::utils::cast(vals[i + 2]); + const float v3 = llaisys::utils::cast(vals[i + 3]); + if (v0 > max_value) { + max_value = v0; + max_index = static_cast(i + 0); + } + if (v1 > max_value) { + max_value = v1; + max_index = static_cast(i + 1); + } + if (v2 > max_value) { + max_value = v2; + max_index = static_cast(i + 2); + } + if (v3 > max_value) { + 
max_value = v3; + max_index = static_cast(i + 3); + } + } + + for (; i < numel; ++i) { + const float current_value = llaisys::utils::cast(vals[i]); + if (current_value > max_value) { + max_value = current_value; + max_index = static_cast(i); + } + } + + *max_idx = max_index; + *max_val = llaisys::utils::cast(max_value); +} + +} // namespace + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel){ + const auto elem_count = static_cast(numel); + switch (type) { + case LLAISYS_DTYPE_F32: +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) + if (has_avx2()) { + return argmax_f32_avx2(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), elem_count); + } +#endif + return argmax_generic(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), elem_count); + case LLAISYS_DTYPE_BF16: + return argmax_generic(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), elem_count); + case LLAISYS_DTYPE_F16: + return argmax_generic(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), elem_count); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..5f58c207 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { + void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu new file mode 100644 index 00000000..337f24da --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cu @@ -0,0 +1,40 @@ +#include "argmax_nvidia.cuh" + 
+#include "../../nvidia/nvidia_common.cuh" +#include "../cpu/argmax_cpu.hpp" + +#include + +namespace llaisys::ops::nvidia { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel) { + std::vector host_vals(numel * utils::dsize(dtype)); + std::vector host_idx(sizeof(size_t)); + std::vector host_max_val(utils::dsize(dtype)); + + LLAISYS_CUDA_CHECK(cudaMemcpyAsync( + host_vals.data(), + vals, + host_vals.size(), + cudaMemcpyDeviceToHost, + current_stream())); + LLAISYS_CUDA_CHECK(cudaStreamSynchronize(current_stream())); + + cpu::argmax(host_idx.data(), host_max_val.data(), host_vals.data(), dtype, numel); + + LLAISYS_CUDA_CHECK(cudaMemcpyAsync( + max_idx, + host_idx.data(), + host_idx.size(), + cudaMemcpyHostToDevice, + current_stream())); + LLAISYS_CUDA_CHECK(cudaMemcpyAsync( + max_val, + host_max_val.data(), + host_max_val.size(), + cudaMemcpyHostToDevice, + current_stream())); + LLAISYS_CUDA_CHECK(cudaStreamSynchronize(current_stream())); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cuh b/src/ops/argmax/nvidia/argmax_nvidia.cuh new file mode 100644 index 00000000..710a2498 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..c2b5ae5e 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,32 @@ #include "op.hpp" + +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" +#include "nvidia/argmax_nvidia.cuh" + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + CHECK_SAME_DTYPE(max_val->dtype(), 
vals->dtype()); + ASSERT(vals->isContiguous(), "Argmax: vals tensor must be contiguous."); + + if(vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..d0010892 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,31 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +template +void embedding_(T *out, const int64_t *index, const T* weight, size_t numel, size_t embedding_dim){ +#pragma omp parallel for schedule(static) if (numel >= 16) + for(std::ptrdiff_t i = 0; i < static_cast(numel); ++i){ + T* out_row_dst = out + i * embedding_dim; + const T* weight_row_src = weight + index[i] * embedding_dim; + std::memcpy(out_row_dst, weight_row_src, embedding_dim * sizeof(T)); + } +} + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t numel, size_t embedding_dim){ + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), 
reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..7d4f0d82 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { + void embedding(std::byte *out, const std::byte *index, const std::byte *wight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim); +} diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 00000000..9dadd983 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,50 @@ +#include "embedding_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" + +namespace { + +template +__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, size_t numel, size_t embedding_dim) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x); + const size_t total = numel * embedding_dim; + if (idx >= total) { + return; + } + + const size_t row = idx / embedding_dim; + const size_t col = idx % embedding_dim; + out[idx] = weight[static_cast(index[row]) * embedding_dim + col]; +} + +template +void embedding_impl(std::byte *out, const std::byte *index, const std::byte *weight, size_t numel, size_t embedding_dim) { + constexpr int threads = 256; + const size_t total = numel * embedding_dim; + embedding_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(index), + reinterpret_cast(weight), + numel, + embedding_dim); + LLAISYS_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim) { + 
switch (dtype) { + case LLAISYS_DTYPE_F32: + return embedding_impl(out, index, weight, numel, embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_impl(out, index, weight, numel, embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_impl(out, index, weight, numel, embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cuh b/src/ops/embedding/nvidia/embedding_nvidia.cuh new file mode 100644 index 00000000..cd3fab17 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim); +} diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..2864563c 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" +#include "nvidia/embedding_nvidia.cuh" namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + CHECK_SAME_DTYPE(out->dtype(), weight->dtype()); + ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index tensor must be int64."); + ASSERT(index->isContiguous(), "Embedding: index tensor must be contiguous."); + size_t embedding_dim = weight->shape().back(); + ASSERT(out->shape().size() == 2 && out->shape()[1] == embedding_dim, + "Embedding: output tensor shape is invalid."); + ASSERT(index->shape().size() == 1 && index->shape()[0] == out->shape()[0], + "Embedding: index tensor shape is invalid."); + if(out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), 
index->numel(), embedding_dim); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..86c65e5c --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,504 @@ +#include "linear_cpu.hpp" + +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) +#ifdef __C +#pragma push_macro("__C") +#undef __C +#define LLAISYS_RESTORE_C_MACRO +#endif +#include +#ifdef LLAISYS_RESTORE_C_MACRO +#pragma pop_macro("__C") +#undef LLAISYS_RESTORE_C_MACRO +#endif +#endif + +#ifdef LLAISYS_USE_OPENBLAS +extern "C" { +#include +} +#endif + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +namespace { + +constexpr std::ptrdiff_t OUTPUT_TILE = 32; +constexpr std::ptrdiff_t OUTPUT_UNROLL = 4; +constexpr std::ptrdiff_t ROW_UNROLL = 2; +constexpr std::ptrdiff_t REDUCTION_TILE = 256; + +#ifdef LLAISYS_USE_OPENBLAS +inline bool should_use_openblas(const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + return dimi >= 16 && dimk >= 1024 && dimj >= 1024 && dimk >= dimj; +} +#endif + +inline bool should_use_gemm_kernel(const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + return dimi >= 16 && dimk >= 
1024 && dimj >= 1024; +} + +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) +inline bool has_avx2_fma() { + return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma"); +} + +__attribute__((target("avx2,fma,sse3"))) +inline float hsum256_ps(__m256 v) { + const __m128 low = _mm256_castps256_ps128(v); + const __m128 high = _mm256_extractf128_ps(v, 1); + __m128 sum = _mm_add_ps(low, high); + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, sum); + return _mm_cvtss_f32(sum); +} + +__attribute__((target("avx2,fma"))) +void linear_f32_avx2_matvec(float *out, const float *in, const float *weight, const float *bias, const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + const std::ptrdiff_t ksimd = dimk - (dimk % 8); + +#pragma omp parallel for collapse(2) schedule(static) + for (std::ptrdiff_t i = 0; i < dimi; ++i) { + for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) { + const float *xrow = in + i * dimk; + float *outrow = out + i * dimj; + const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj); + std::ptrdiff_t j = j0; + + for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) { + const float *w0 = weight + (j + 0) * dimk; + const float *w1 = weight + (j + 1) * dimk; + const float *w2 = weight + (j + 2) * dimk; + const float *w3 = weight + (j + 3) * dimk; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + const __m256 x = _mm256_loadu_ps(xrow + k); + acc0 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w0 + k), acc0); + acc1 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w1 + k), acc1); + acc2 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w2 + k), acc2); + acc3 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w3 + 
k), acc3); + } + + float sum0 = hsum256_ps(acc0) + (bias ? bias[j + 0] : 0.0f); + float sum1 = hsum256_ps(acc1) + (bias ? bias[j + 1] : 0.0f); + float sum2 = hsum256_ps(acc2) + (bias ? bias[j + 2] : 0.0f); + float sum3 = hsum256_ps(acc3) + (bias ? bias[j + 3] : 0.0f); + + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + const float x = xrow[k]; + sum0 += x * w0[k]; + sum1 += x * w1[k]; + sum2 += x * w2[k]; + sum3 += x * w3[k]; + } + + outrow[j + 0] = sum0; + outrow[j + 1] = sum1; + outrow[j + 2] = sum2; + outrow[j + 3] = sum3; + } + + for (; j < jend; ++j) { + const float *wrow = weight + j * dimk; + __m256 acc = _mm256_setzero_ps(); + + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + const __m256 x = _mm256_loadu_ps(xrow + k); + acc = _mm256_fmadd_ps(x, _mm256_loadu_ps(wrow + k), acc); + } + + float sum = hsum256_ps(acc) + (bias ? bias[j] : 0.0f); + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + sum += xrow[k] * wrow[k]; + } + outrow[j] = sum; + } + } + } +} + +__attribute__((target("avx2,fma"))) +void linear_f32_avx2_gemm(float *out, const float *in, const float *weight, const float *bias, const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + const std::ptrdiff_t ksimd = dimk - (dimk % 8); + +#pragma omp parallel for collapse(2) schedule(static) + for (std::ptrdiff_t i0 = 0; i0 < dimi; i0 += ROW_UNROLL) { + for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) { + const std::ptrdiff_t iend = std::min(i0 + ROW_UNROLL, dimi); + const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj); + + if (iend - i0 < ROW_UNROLL) { + for (std::ptrdiff_t i = i0; i < iend; ++i) { + const float *xrow = in + i * dimk; + float *outrow = out + i * dimj; + std::ptrdiff_t j = j0; + for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) { + const float *w0 = weight + (j + 0) * dimk; + const float *w1 = weight + (j + 1) * dimk; + const 
float *w2 = weight + (j + 2) * dimk; + const float *w3 = weight + (j + 3) * dimk; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + const __m256 x = _mm256_loadu_ps(xrow + k); + acc0 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w0 + k), acc0); + acc1 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w1 + k), acc1); + acc2 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w2 + k), acc2); + acc3 = _mm256_fmadd_ps(x, _mm256_loadu_ps(w3 + k), acc3); + } + + float sum0 = hsum256_ps(acc0) + (bias ? bias[j + 0] : 0.0f); + float sum1 = hsum256_ps(acc1) + (bias ? bias[j + 1] : 0.0f); + float sum2 = hsum256_ps(acc2) + (bias ? bias[j + 2] : 0.0f); + float sum3 = hsum256_ps(acc3) + (bias ? bias[j + 3] : 0.0f); + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + const float x = xrow[k]; + sum0 += x * w0[k]; + sum1 += x * w1[k]; + sum2 += x * w2[k]; + sum3 += x * w3[k]; + } + + outrow[j + 0] = sum0; + outrow[j + 1] = sum1; + outrow[j + 2] = sum2; + outrow[j + 3] = sum3; + } + + for (; j < jend; ++j) { + const float *wrow = weight + j * dimk; + __m256 acc = _mm256_setzero_ps(); + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + acc = _mm256_fmadd_ps(_mm256_loadu_ps(xrow + k), _mm256_loadu_ps(wrow + k), acc); + } + float sum = hsum256_ps(acc) + (bias ? 
bias[j] : 0.0f); + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + sum += xrow[k] * wrow[k]; + } + outrow[j] = sum; + } + } + continue; + } + + const float *xrow0 = in + i0 * dimk; + const float *xrow1 = in + (i0 + 1) * dimk; + float *outrow0 = out + i0 * dimj; + float *outrow1 = out + (i0 + 1) * dimj; + std::ptrdiff_t j = j0; + + for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) { + const float *w0 = weight + (j + 0) * dimk; + const float *w1 = weight + (j + 1) * dimk; + const float *w2 = weight + (j + 2) * dimk; + const float *w3 = weight + (j + 3) * dimk; + + __m256 acc00 = _mm256_setzero_ps(); + __m256 acc01 = _mm256_setzero_ps(); + __m256 acc02 = _mm256_setzero_ps(); + __m256 acc03 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(); + __m256 acc11 = _mm256_setzero_ps(); + __m256 acc12 = _mm256_setzero_ps(); + __m256 acc13 = _mm256_setzero_ps(); + + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + const __m256 x0 = _mm256_loadu_ps(xrow0 + k); + const __m256 x1 = _mm256_loadu_ps(xrow1 + k); + const __m256 wv0 = _mm256_loadu_ps(w0 + k); + const __m256 wv1 = _mm256_loadu_ps(w1 + k); + const __m256 wv2 = _mm256_loadu_ps(w2 + k); + const __m256 wv3 = _mm256_loadu_ps(w3 + k); + + acc00 = _mm256_fmadd_ps(x0, wv0, acc00); + acc01 = _mm256_fmadd_ps(x0, wv1, acc01); + acc02 = _mm256_fmadd_ps(x0, wv2, acc02); + acc03 = _mm256_fmadd_ps(x0, wv3, acc03); + acc10 = _mm256_fmadd_ps(x1, wv0, acc10); + acc11 = _mm256_fmadd_ps(x1, wv1, acc11); + acc12 = _mm256_fmadd_ps(x1, wv2, acc12); + acc13 = _mm256_fmadd_ps(x1, wv3, acc13); + } + + float sum00 = hsum256_ps(acc00) + (bias ? bias[j + 0] : 0.0f); + float sum01 = hsum256_ps(acc01) + (bias ? bias[j + 1] : 0.0f); + float sum02 = hsum256_ps(acc02) + (bias ? bias[j + 2] : 0.0f); + float sum03 = hsum256_ps(acc03) + (bias ? bias[j + 3] : 0.0f); + float sum10 = hsum256_ps(acc10) + (bias ? bias[j + 0] : 0.0f); + float sum11 = hsum256_ps(acc11) + (bias ? bias[j + 1] : 0.0f); + float sum12 = hsum256_ps(acc12) + (bias ? 
bias[j + 2] : 0.0f); + float sum13 = hsum256_ps(acc13) + (bias ? bias[j + 3] : 0.0f); + + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + const float x0 = xrow0[k]; + const float x1 = xrow1[k]; + const float wv0 = w0[k]; + const float wv1 = w1[k]; + const float wv2 = w2[k]; + const float wv3 = w3[k]; + sum00 += x0 * wv0; + sum01 += x0 * wv1; + sum02 += x0 * wv2; + sum03 += x0 * wv3; + sum10 += x1 * wv0; + sum11 += x1 * wv1; + sum12 += x1 * wv2; + sum13 += x1 * wv3; + } + + outrow0[j + 0] = sum00; + outrow0[j + 1] = sum01; + outrow0[j + 2] = sum02; + outrow0[j + 3] = sum03; + outrow1[j + 0] = sum10; + outrow1[j + 1] = sum11; + outrow1[j + 2] = sum12; + outrow1[j + 3] = sum13; + } + + for (; j < jend; ++j) { + const float *wrow = weight + j * dimk; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + for (std::ptrdiff_t k = 0; k < ksimd; k += 8) { + const __m256 wv = _mm256_loadu_ps(wrow + k); + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(xrow0 + k), wv, acc0); + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(xrow1 + k), wv, acc1); + } + float sum0 = hsum256_ps(acc0) + (bias ? bias[j] : 0.0f); + float sum1 = hsum256_ps(acc1) + (bias ? 
bias[j] : 0.0f); + for (std::ptrdiff_t k = ksimd; k < dimk; ++k) { + const float wv = wrow[k]; + sum0 += xrow0[k] * wv; + sum1 += xrow1[k] * wv; + } + outrow0[j] = sum0; + outrow1[j] = sum1; + } + } + } +} +#endif + +#ifdef LLAISYS_USE_OPENBLAS +void linear_f32_openblas(float *out, const float *in, const float *weight, const float *bias, const std::vector &shapes) { + const blasint dimi = static_cast(shapes[0]); + const blasint dimk = static_cast(shapes[1]); + const blasint dimj = static_cast(shapes[2]); + + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasTrans, + dimi, + dimj, + dimk, + 1.0f, + in, + dimk, + weight, + dimk, + 0.0f, + out, + dimj); + + if (bias == nullptr) { + return; + } + + const std::ptrdiff_t dimi_p = static_cast(dimi); + const std::ptrdiff_t dimj_p = static_cast(dimj); +#pragma omp parallel for schedule(static) + for (std::ptrdiff_t i = 0; i < dimi_p; ++i) { + float *outrow = out + i * dimj_p; + for (std::ptrdiff_t j0 = 0; j0 < dimj_p; j0 += OUTPUT_TILE) { + const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj_p); +#pragma omp simd + for (std::ptrdiff_t j = j0; j < jend; ++j) { + outrow[j] += bias[j]; + } + } + } +} +#endif + +void linear_f32(float *out, const float *in, const float *weight, const float *bias, const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + +#pragma omp parallel for collapse(2) schedule(static) + for (std::ptrdiff_t i = 0; i < dimi; ++i) { + for (std::ptrdiff_t j0 = 0; j0 < dimj; j0 += OUTPUT_TILE) { + const float *xrow = in + i * dimk; + float *outrow = out + i * dimj; + const std::ptrdiff_t jend = std::min(j0 + OUTPUT_TILE, dimj); + std::ptrdiff_t j = j0; + + for (; j + OUTPUT_UNROLL <= jend; j += OUTPUT_UNROLL) { + const float *w0 = weight + (j + 0) * dimk; + const float *w1 = weight + (j + 1) * dimk; + const float *w2 = weight + (j + 2) * dimk; + const float *w3 = 
weight + (j + 3) * dimk; + + float acc0 = bias ? bias[j + 0] : 0.0f; + float acc1 = bias ? bias[j + 1] : 0.0f; + float acc2 = bias ? bias[j + 2] : 0.0f; + float acc3 = bias ? bias[j + 3] : 0.0f; + + for (std::ptrdiff_t k0 = 0; k0 < dimk; k0 += REDUCTION_TILE) { + const std::ptrdiff_t kend = std::min(k0 + REDUCTION_TILE, dimk); + float part0 = 0.0f; + float part1 = 0.0f; + float part2 = 0.0f; + float part3 = 0.0f; + +#pragma omp simd reduction(+ : part0, part1, part2, part3) + for (std::ptrdiff_t k = k0; k < kend; ++k) { + const float x = xrow[k]; + part0 += x * w0[k]; + part1 += x * w1[k]; + part2 += x * w2[k]; + part3 += x * w3[k]; + } + + acc0 += part0; + acc1 += part1; + acc2 += part2; + acc3 += part3; + } + + outrow[j + 0] = acc0; + outrow[j + 1] = acc1; + outrow[j + 2] = acc2; + outrow[j + 3] = acc3; + } + + for (; j < jend; ++j) { + const float *wrow = weight + j * dimk; + float acc = bias ? bias[j] : 0.0f; + + for (std::ptrdiff_t k0 = 0; k0 < dimk; k0 += REDUCTION_TILE) { + const std::ptrdiff_t kend = std::min(k0 + REDUCTION_TILE, dimk); + float part = 0.0f; + +#pragma omp simd reduction(+ : part) + for (std::ptrdiff_t k = k0; k < kend; ++k) { + part += xrow[k] * wrow[k]; + } + + acc += part; + } + + outrow[j] = acc; + } + } + } +} + +template +void linear_generic(T *out, const T *in, const T *weight, const T *bias, const std::vector &shapes) { + const std::ptrdiff_t dimi = static_cast(shapes[0]); + const std::ptrdiff_t dimk = static_cast(shapes[1]); + const std::ptrdiff_t dimj = static_cast(shapes[2]); + +#pragma omp parallel for collapse(2) schedule(static) + for (std::ptrdiff_t i = 0; i < dimi; ++i) { + for (std::ptrdiff_t j = 0; j < dimj; ++j) { + float sum = 0.0f; + for (std::ptrdiff_t k = 0; k < dimk; ++k) { + sum += llaisys::utils::cast(in[i * dimk + k]) * llaisys::utils::cast(weight[j * dimk + k]); + } + if (bias != nullptr) { + sum += llaisys::utils::cast(bias[j]); + } + out[i * dimj + j] = llaisys::utils::cast(sum); + } + } +} + +} // namespace + 
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, std::vector<size_t> shapes){
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+#ifdef LLAISYS_USE_OPENBLAS
+        if (should_use_openblas(shapes)) {
+            return linear_f32_openblas(
+                reinterpret_cast<float *>(out),
+                reinterpret_cast<const float *>(in),
+                reinterpret_cast<const float *>(weight),
+                bias ? reinterpret_cast<const float *>(bias) : nullptr,
+                shapes);
+        }
+#endif
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+        if (has_avx2_fma()) {
+            if (should_use_gemm_kernel(shapes)) {
+                return linear_f32_avx2_gemm(
+                    reinterpret_cast<float *>(out),
+                    reinterpret_cast<const float *>(in),
+                    reinterpret_cast<const float *>(weight),
+                    bias ? reinterpret_cast<const float *>(bias) : nullptr,
+                    shapes);
+            }
+            return linear_f32_avx2_matvec(
+                reinterpret_cast<float *>(out),
+                reinterpret_cast<const float *>(in),
+                reinterpret_cast<const float *>(weight),
+                bias ? reinterpret_cast<const float *>(bias) : nullptr,
+                shapes);
+        }
+#endif
+        return linear_f32(
+            reinterpret_cast<float *>(out),
+            reinterpret_cast<const float *>(in),
+            reinterpret_cast<const float *>(weight),
+            bias ? reinterpret_cast<const float *>(bias) : nullptr,
+            shapes);
+    case LLAISYS_DTYPE_BF16:
+        return linear_generic(
+            reinterpret_cast<llaisys::bf16_t *>(out),
+            reinterpret_cast<const llaisys::bf16_t *>(in),
+            reinterpret_cast<const llaisys::bf16_t *>(weight),
+            bias ? reinterpret_cast<const llaisys::bf16_t *>(bias) : nullptr,
+            shapes);
+    case LLAISYS_DTYPE_F16:
+        return linear_generic(
+            reinterpret_cast<llaisys::fp16_t *>(out),
+            reinterpret_cast<const llaisys::fp16_t *>(in),
+            reinterpret_cast<const llaisys::fp16_t *>(weight),
+            bias ? reinterpret_cast<const llaisys::fp16_t *>(bias) : nullptr,
+            shapes);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 00000000..a9a98ed0
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::cpu {
+    void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t dtype, std::vector<size_t> shapes);
+}
diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu
new file mode 100644
index 00000000..154d9c03
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_nvidia.cu
@@ -0,0 +1,87 @@
+#include "linear_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+template <typename T>
+__global__ void add_bias_kernel(T *out, const T *bias, size_t dimi, size_t dimj) {
+    const size_t idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) + static_cast<size_t>(threadIdx.x);
+    const size_t total = dimi * dimj;
+    if (idx >= total) {
+        return;
+    }
+
+    const size_t col = idx % dimj;
+    const float out_v = llaisys::device::nvidia::cuda_utils::toFloat(out[idx]);
+    const float bias_v = llaisys::device::nvidia::cuda_utils::toFloat(bias[col]);
+    out[idx] = llaisys::device::nvidia::cuda_utils::fromFloat<T>(out_v + bias_v);
+}
+
+template <typename T>
+void add_bias(std::byte *out, const std::byte *bias, size_t dimi, size_t dimj) {
+    constexpr int threads = 256;
+    const size_t total = dimi * dimj;
+    add_bias_kernel<T><<<llaisys::ops::nvidia::blocks_for(total, threads), threads, 0, llaisys::ops::nvidia::current_stream()>>>(
+        reinterpret_cast<T *>(out),
+        reinterpret_cast<const T *>(bias),
+        dimi,
+        dimj);
+    LLAISYS_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace
+
+namespace llaisys::ops::nvidia {
+
+void linear(std::byte *out,
+            const std::byte *in,
+            const std::byte *weight,
+            const std::byte *bias,
+            llaisysDataType_t dtype,
+            size_t dimi,
+            size_t dimk,
+            size_t dimj) {
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const cudaDataType_t data_type = device::nvidia::cuda_utils::cublasDataType(dtype);
+    const cublasComputeType_t compute_type = device::nvidia::cuda_utils::cublasComputeType(dtype);
+
+    LLAISYS_CUBLAS_CHECK(cublasGemmEx(
+        current_cublas_handle(),
+        CUBLAS_OP_T,
+        CUBLAS_OP_N,
+        static_cast<int>(dimj),
+        static_cast<int>(dimi),
+        static_cast<int>(dimk),
+        &alpha,
+        weight,
+        data_type,
+        static_cast<int>(dimk),
+        in,
+        data_type,
+        static_cast<int>(dimk),
+        &beta,
+        out,
+        data_type,
+        static_cast<int>(dimj),
+        compute_type,
+        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    if (bias == nullptr) {
+        return;
+    }
+
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return add_bias<float>(out, bias, dimi, dimj);
+    case LLAISYS_DTYPE_F16:
+        return add_bias<half>(out, bias, dimi, dimj);
+    case LLAISYS_DTYPE_BF16:
+        return add_bias<__nv_bfloat16>(out, bias, dimi, dimj);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/linear/nvidia/linear_nvidia.cuh b/src/ops/linear/nvidia/linear_nvidia.cuh
new file mode 100644
index 00000000..8e491dd1
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_nvidia.cuh
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void linear(std::byte *out,
+            const std::byte *in,
+            const std::byte *weight,
+            const std::byte *bias,
+            llaisysDataType_t dtype,
+            size_t dimi,
+            size_t dimk,
+            size_t dimj);
+}
diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp
index 97d1f865..119d0c92 100644
--- a/src/ops/linear/op.cpp
+++ b/src/ops/linear/op.cpp
@@ -1,7 +1,38 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/linear_cpu.hpp"
+#include "nvidia/linear_nvidia.cuh"
+
 namespace llaisys::ops {
 void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) {
-    TO_BE_IMPLEMENTED();
+    ASSERT(in->shape().size() == 2, "Linear: input tensor must be 2-D.");
+    ASSERT(weight->shape().size() == 2, "Linear: weight tensor must be 2-D.");
+    ASSERT(out->shape().size() == 2, "Linear: output tensor must be 2-D.");
+    size_t dimi = in->shape()[0];
+    size_t dimk = in->shape()[1];
+    size_t dimj = weight->shape()[0];
+    ASSERT(weight->shape()[1] == dimk, "Linear: weight tensor shape is invalid.");
+    ASSERT(out->shape()[0] == dimi && out->shape()[1] == dimj, "Linear: output tensor shape is invalid.");
+    if(bias != nullptr){
+        ASSERT(bias->shape().size() == 1 && bias->shape()[0] == dimj, "Linear: bias tensor shape is invalid.");
+    }
+
+    if(out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return llaisys::ops::cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), {dimi, dimk, dimj});
+    }
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), {dimi, dimk, dimj});
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), dimi, dimk, dimj);
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/nvidia/nvidia_common.cuh b/src/ops/nvidia/nvidia_common.cuh
new file mode 100644
index 00000000..c72acb00
--- /dev/null
+++ b/src/ops/nvidia/nvidia_common.cuh
@@ -0,0 +1,91 @@
+#pragma once
+
+#include "../../core/llaisys_core.hpp"
+#include "../../device/nvidia/cuda_utils.cuh"
+#include "../../utils.hpp"
+
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <unordered_map>
+
+namespace llaisys::ops::nvidia {
+
+inline cudaStream_t current_stream() {
+    return reinterpret_cast<cudaStream_t>(core::context().runtime().stream());
+}
+
+inline int current_device_id() {
+    return core::context().runtime().deviceId();
+}
+
+class CublasHandlePool {
+public:
+    ~CublasHandlePool() {
+        for (auto &entry : handles) {
+            if (entry.second != nullptr) {
+                cublasDestroy(entry.second);
+            }
+        }
+    }
+
+    std::unordered_map<int, cublasHandle_t> handles;
+};
+
+inline cublasHandle_t current_cublas_handle() {
+    thread_local CublasHandlePool pool;
+    cublasHandle_t &handle = pool.handles[current_device_id()];
+    if (handle == nullptr) {
+        LLAISYS_CUBLAS_CHECK(cublasCreate(&handle));
+    }
+    LLAISYS_CUBLAS_CHECK(cublasSetStream(handle, current_stream()));
+    return handle;
+}
+
+struct TensorDescriptor {
+    int ndim = 0;
+    int64_t shape[8]{};
+    int64_t strides[8]{};
+};
+
+inline TensorDescriptor make_descriptor(const std::vector<size_t> &shape, const std::vector<size_t> &strides) {
+    ASSERT(shape.size() == strides.size(), "Tensor descriptor shape/stride rank mismatch.");
+    ASSERT(shape.size() <= 8, "Tensor descriptor only supports up to 8 dimensions.");
+
+    TensorDescriptor desc{};
+    desc.ndim = static_cast<int>(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i) {
+        desc.shape[i] = static_cast<int64_t>(shape[i]);
+        desc.strides[i] = static_cast<int64_t>(strides[i]);
+    }
+    return desc;
+}
+
+inline size_t numel(const std::vector<size_t> &shape) {
+    size_t total = 1;
+    for (size_t dim : shape) {
+        total *= dim;
+    }
+    return total;
+}
+
+inline bool is_contiguous(const std::vector<size_t> &shape, const std::vector<size_t> &stride) {
+    if (shape.size() != stride.size()) {
+        return false;
+    }
+
+    size_t expected = 1;
+    for (std::ptrdiff_t dim = static_cast<std::ptrdiff_t>(shape.size()) - 1; dim >= 0; --dim) {
+        if (stride[static_cast<size_t>(dim)] != expected) {
+            return false;
+        }
+        expected *= shape[static_cast<size_t>(dim)];
+    }
+    return true;
+}
+
+inline int blocks_for(size_t total, int threads = 256) {
+    return static_cast<int>((total + static_cast<size_t>(threads) - 1) / static_cast<size_t>(threads));
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp
new file mode 100644
index 00000000..d238e69f
--- /dev/null
+++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp
@@ -0,0 +1,113 @@
+#include "rearrange_cpu.hpp"
+#include "../../../utils.hpp"
+
+#include <cstring>
+#include <vector>
+
+namespace {
+
+size_t shape_numel(const std::vector<size_t> &shape) {
+    size_t numel = 1;
+    for (size_t dim : shape) {
+        numel *= dim;
+    }
+    return numel;
+}
+
+bool is_contiguous(const std::vector<size_t> &shape, const std::vector<size_t> &stride) {
+    if (shape.size() != stride.size()) {
+        return false;
+    }
+
+    size_t expected = 1;
+    for (std::ptrdiff_t dim = static_cast<std::ptrdiff_t>(shape.size()) - 1; dim >= 0; --dim) {
+        if (stride[dim] != expected) {
+            return false;
+        }
+        expected *= shape[dim];
+    }
+    return true;
+}
+
+template <typename T>
+void rearrange_inner(T *out_base, const T *in_base, const std::vector<size_t> &shape, const std::vector<size_t> &stride_in, const std::vector<size_t> &stride_out, size_t dim, size_t offset_in, size_t offset_out) {
+    const size_t len = shape[dim];
+    const size_t s_in = stride_in[dim];
+    const size_t s_out = stride_out[dim];
+
+    if (dim == shape.size() - 1) {
+        if (s_in == 1 && s_out == 1) {
+            std::memcpy(out_base + offset_out, in_base + offset_in, len * sizeof(T));
+        } else {
+            for (size_t i = 0; i < len; ++i) {
+                out_base[offset_out + i * s_out] = in_base[offset_in + i * s_in];
+            }
+        }
+    } else {
+        for (size_t i = 0; i < len; ++i) {
+            rearrange_inner(out_base, in_base, shape, stride_in, stride_out, dim + 1, offset_in + i * s_in, offset_out + i * s_out);
+        }
+    }
+}
+
+template <typename T>
+void rearrange_dispatch(T *out_base, const T *in_base, const std::vector<size_t> &shape, const std::vector<size_t> &stride_in, const std::vector<size_t> &stride_out) {
+    if (shape.empty()) {
+        return;
+    }
+    if (is_contiguous(shape, stride_in) && is_contiguous(shape, stride_out)) {
+        std::memcpy(out_base, in_base, shape_numel(shape) * sizeof(T));
+        return;
+    }
+    if (shape.size() == 1) {
+        rearrange_inner(out_base, in_base, shape, stride_in, stride_out, 0, 0, 0);
+        return;
+    }
+
+    const size_t len0 = shape[0];
+    const size_t s_in0 = stride_in[0];
+    const size_t s_out0 = stride_out[0];
+
+#pragma omp parallel for schedule(static) if (len0 >= 4)
+    for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(len0); ++i) {
+        rearrange_inner(out_base, in_base, shape, stride_in, stride_out, 1, static_cast<size_t>(i) * s_in0, static_cast<size_t>(i) * s_out0);
+    }
+}
+
+} // namespace
+
+namespace llaisys::ops::cpu {
+
+void rearrange(std::byte *out, const std::byte *in, llaisysDataType_t dtype, const std::vector<size_t> &shape, const std::vector<size_t> &stride_in, const std::vector<size_t> &stride_out) {
+
+    if (shape.empty()) {
+        size_t size = 0;
+        switch (dtype) {
+        case LLAISYS_DTYPE_F32: size = 4; break;
+        case LLAISYS_DTYPE_BF16: size = 2; break;
+        case LLAISYS_DTYPE_F16: size = 2; break;
+        case LLAISYS_DTYPE_I64: size = 8; break;
+        default: EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+        }
+        std::memcpy(out, in, size);
+        return;
+    }
+
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        rearrange_dispatch(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), shape, stride_in, stride_out);
+        break;
+    case LLAISYS_DTYPE_BF16:
+        rearrange_dispatch(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in), shape, stride_in, stride_out);
+        break;
+    case LLAISYS_DTYPE_F16:
+        rearrange_dispatch(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in), shape, stride_in, stride_out);
+        break;
+    case LLAISYS_DTYPE_I64:
+        rearrange_dispatch(reinterpret_cast<int64_t *>(out), reinterpret_cast<const int64_t *>(in), shape, stride_in, stride_out);
+        break;
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp
new file mode 100644
index 00000000..f15927a4
--- /dev/null
+++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::cpu {
+    void rearrange(std::byte *out, const std::byte *in, llaisysDataType_t dtype, const std::vector<size_t> &shape, const std::vector<size_t> &stride_in, const std::vector<size_t> &stride_out);
+}
\ No newline at end of file
diff --git a/src/ops/rearrange/nvidia/rearrange_nvidia.cu b/src/ops/rearrange/nvidia/rearrange_nvidia.cu
new file mode 100644
index 00000000..9c624dc3
--- /dev/null
+++ b/src/ops/rearrange/nvidia/rearrange_nvidia.cu
@@ -0,0 +1,89 @@
+#include "rearrange_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+template <typename T>
+__global__ void rearrange_kernel(T *out,
+                                 const T *in,
+                                 llaisys::ops::nvidia::TensorDescriptor desc_in,
+                                 llaisys::ops::nvidia::TensorDescriptor desc_out,
+                                 size_t total) {
+    const size_t linear_idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) + static_cast<size_t>(threadIdx.x);
+    if (linear_idx >= total) {
+        return;
+    }
+
+    int64_t remaining = static_cast<int64_t>(linear_idx);
+    int64_t in_offset = 0;
+    int64_t out_offset = 0;
+    for (int dim = desc_in.ndim - 1; dim >= 0; --dim) {
+        const int64_t dim_idx = remaining % desc_in.shape[static_cast<size_t>(dim)];
+        remaining /= desc_in.shape[static_cast<size_t>(dim)];
+        in_offset += dim_idx * desc_in.strides[static_cast<size_t>(dim)];
+        out_offset += dim_idx * desc_out.strides[static_cast<size_t>(dim)];
+    }
+
+    out[out_offset] = in[in_offset];
+}
+
+template <typename T>
+void rearrange_impl(std::byte *out,
+                    const std::byte *in,
+                    const std::vector<size_t> &shape,
+                    const std::vector<size_t> &stride_in,
+                    const std::vector<size_t> &stride_out) {
+    const auto desc_in = llaisys::ops::nvidia::make_descriptor(shape, stride_in);
+    const auto desc_out = llaisys::ops::nvidia::make_descriptor(shape, stride_out);
+    constexpr int threads = 256;
+    const size_t total = llaisys::ops::nvidia::numel(shape);
+    rearrange_kernel<T><<<llaisys::ops::nvidia::blocks_for(total, threads), threads, 0, llaisys::ops::nvidia::current_stream()>>>(
+        reinterpret_cast<T *>(out),
+        reinterpret_cast<const T *>(in),
+        desc_in,
+        desc_out,
+        total);
+    LLAISYS_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace
+
+namespace llaisys::ops::nvidia {
+
+void rearrange(std::byte *out,
+               const std::byte *in,
+               llaisysDataType_t dtype,
+               const std::vector<size_t> &shape,
+               const std::vector<size_t> &stride_in,
+               const std::vector<size_t> &stride_out) {
+    if (shape.empty()) {
+        LLAISYS_CUDA_CHECK(cudaMemcpyAsync(out, in, utils::dsize(dtype), cudaMemcpyDeviceToDevice, current_stream()));
+        return;
+    }
+
+    if (is_contiguous(shape, stride_in) && is_contiguous(shape, stride_out)) {
+        LLAISYS_CUDA_CHECK(cudaMemcpyAsync(
+            out,
+            in,
+            numel(shape) * utils::dsize(dtype),
+            cudaMemcpyDeviceToDevice,
+            current_stream()));
+        return;
+    }
+
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return rearrange_impl<float>(out, in, shape, stride_in, stride_out);
+    case LLAISYS_DTYPE_F16:
+        return rearrange_impl<half>(out, in, shape, stride_in, stride_out);
+    case LLAISYS_DTYPE_BF16:
+        return rearrange_impl<__nv_bfloat16>(out, in, shape, stride_in, stride_out);
+    case LLAISYS_DTYPE_I64:
+        return rearrange_impl<int64_t>(out, in, shape, stride_in, stride_out);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/rearrange/nvidia/rearrange_nvidia.cuh b/src/ops/rearrange/nvidia/rearrange_nvidia.cuh
new file mode 100644
index 00000000..60be7e43
--- /dev/null
+++ b/src/ops/rearrange/nvidia/rearrange_nvidia.cuh
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::nvidia {
+void rearrange(std::byte *out,
+               const std::byte *in,
+               llaisysDataType_t dtype,
+               const std::vector<size_t> &shape,
+               const std::vector<size_t> &stride_in,
+               const std::vector<size_t> &stride_out);
+}
diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp
index 017a6ae5..75ba9c78 100644
--- a/src/ops/rearrange/op.cpp
+++ b/src/ops/rearrange/op.cpp
@@ -1,7 +1,31 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/rearrange_cpu.hpp"
+#include "nvidia/rearrange_nvidia.cuh"
 
 namespace llaisys::ops {
 void rearrange(tensor_t out, tensor_t in) {
-    TO_BE_IMPLEMENTED();
+    ASSERT(out->shape() == in->shape(), "Rearrange: input and output tensors must have the same shape.");
+    ASSERT(out->dtype() == in->dtype(), "Rearrange: input and output tensors must have the same dtype.");
+
+    std::vector<size_t> stride_in(in->strides().begin(), in->strides().end());
+    std::vector<size_t> stride_out(out->strides().begin(), out->strides().end());
+
+    if(out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return llaisys::ops::cpu::rearrange(out->data(), in->data(), out->dtype(), out->shape(), stride_in, stride_out);
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::rearrange(out->data(), in->data(), out->dtype(), out->shape(), stride_in, stride_out);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::rearrange(out->data(), in->data(), out->dtype(), out->shape(), stride_in, stride_out);
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
new file mode 100644
index 00000000..53f55ad0
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
@@ -0,0 +1,148 @@
+#include "rms_norm_cpu.hpp"
+
+#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#ifdef __C
+#pragma push_macro("__C")
+#undef __C
+#define LLAISYS_RESTORE_C_MACRO
+#endif
+#include <immintrin.h>
+#ifdef LLAISYS_RESTORE_C_MACRO
+#pragma pop_macro("__C")
+#undef LLAISYS_RESTORE_C_MACRO
+#endif
+#endif
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+#include <cstddef>
+
+namespace {
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+inline bool has_avx2() {
+    return __builtin_cpu_supports("avx2");
+}
+
+__attribute__((target("avx2,sse3")))
+inline float hsum256_ps(__m256 v) {
+    const __m128 low = _mm256_castps256_ps128(v);
+    const __m128 high = _mm256_extractf128_ps(v, 1);
+    __m128 sum = _mm_add_ps(low, high);
+    sum = _mm_hadd_ps(sum, sum);
+    sum = _mm_hadd_ps(sum, sum);
+    return _mm_cvtss_f32(sum);
+}
+
+__attribute__((target("avx2")))
+void rms_norm_f32_avx2(float *out, const float *in, const float *weight, float eps, const std::vector<size_t> &shapes) {
+    const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+    const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[1]);
+    const std::ptrdiff_t simd_dimj = dimj - (dimj % 8);
+    const bool parallel_rows = dimi > 1 && dimi * dimj >= 65536;
+
+#pragma omp parallel for schedule(static) if (parallel_rows)
+    for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+        const float *in_row = in + i * dimj;
+        float *out_row = out + i * dimj;
+        __m256 sum_acc = _mm256_setzero_ps();
+
+        for (std::ptrdiff_t j = 0; j < simd_dimj; j += 8) {
+            const __m256 x = _mm256_loadu_ps(in_row + j);
+            sum_acc = _mm256_add_ps(sum_acc, _mm256_mul_ps(x, x));
+        }
+
+        float sum_sq = hsum256_ps(sum_acc);
+        for (std::ptrdiff_t j = simd_dimj; j < dimj; ++j) {
+            sum_sq += in_row[j] * in_row[j];
+        }
+
+        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dimj) + eps);
+        const __m256 inv = _mm256_set1_ps(inv_rms);
+
+        std::ptrdiff_t j = 0;
+        for (; j < simd_dimj; j += 8) {
+            const __m256 x = _mm256_loadu_ps(in_row + j);
+            const __m256 w = _mm256_loadu_ps(weight + j);
+            _mm256_storeu_ps(out_row + j, _mm256_mul_ps(_mm256_mul_ps(x, inv), w));
+        }
+
+        for (; j < dimj; ++j) {
+            out_row[j] = in_row[j] * inv_rms * weight[j];
+        }
+    }
+}
+#endif
+
+void rms_norm_f32(float *out, const float *in, const float *weight, float eps, const std::vector<size_t> &shapes) {
+    const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+    const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[1]);
+    const bool parallel_rows = dimi > 1 && dimi * dimj >= 65536;
+
+#pragma omp parallel for schedule(static) if (parallel_rows)
+    for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+        const float *in_row = in + i * dimj;
+        float *out_row = out + i * dimj;
+        float sum_sq = 0.0f;
+
+#pragma omp simd reduction(+ : sum_sq)
+        for (std::ptrdiff_t j = 0; j < dimj; ++j) {
+            sum_sq += in_row[j] * in_row[j];
+        }
+
+        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dimj) + eps);
+
+#pragma omp simd
+        for (std::ptrdiff_t j = 0; j < dimj; ++j) {
+            out_row[j] = in_row[j] * inv_rms * weight[j];
+        }
+    }
+}
+
+template <typename T>
+void rms_norm_generic(T *out, const T *in, const T *weight, float eps, const std::vector<size_t> &shapes) {
+    const std::ptrdiff_t dimi = static_cast<std::ptrdiff_t>(shapes[0]);
+    const std::ptrdiff_t dimj = static_cast<std::ptrdiff_t>(shapes[1]);
+    const bool parallel_rows = dimi > 1 && dimi * dimj >= 65536;
+
+#pragma omp parallel for schedule(static) if (parallel_rows)
+    for (std::ptrdiff_t i = 0; i < dimi; ++i) {
+        const T *in_row = in + i * dimj;
+        T *out_row = out + i * dimj;
+        float sum_sq = 0.0f;
+        for (std::ptrdiff_t j = 0; j < dimj; ++j) {
+            const float val = llaisys::utils::cast<float>(in_row[j]);
+            sum_sq += val * val;
+        }
+
+        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dimj) + eps);
+        for (std::ptrdiff_t j = 0; j < dimj; ++j) {
+            const float val = llaisys::utils::cast<float>(in_row[j]);
+            const float w = llaisys::utils::cast<float>(weight[j]);
+            out_row[j] = llaisys::utils::cast<T>(val * inv_rms * w);
+        }
+    }
+}
+
+} // namespace
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, llaisysDataType_t type, std::vector<size_t> shapes){
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+        if (has_avx2()) {
+            return rms_norm_f32_avx2(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), reinterpret_cast<const float *>(weight), eps, shapes);
+        }
+#endif
+        return rms_norm_f32(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), reinterpret_cast<const float *>(weight), eps, shapes);
+    case LLAISYS_DTYPE_BF16:
+        return rms_norm_generic(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in), reinterpret_cast<const llaisys::bf16_t *>(weight), eps, shapes);
+    case LLAISYS_DTYPE_F16:
+        return rms_norm_generic(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in), reinterpret_cast<const llaisys::fp16_t *>(weight), eps, shapes);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
new file mode 100644
index 00000000..61afb8d7
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::cpu {
+    void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, llaisysDataType_t dtype, std::vector<size_t> shapes);
+}
\ No newline at end of file
diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu
new file mode 100644
index 00000000..f70a9fb7
--- /dev/null
+++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu
@@ -0,0 +1,92 @@
+#include "rms_norm_nvidia.cuh"
+
+#include "../../nvidia/nvidia_common.cuh"
+
+namespace {
+
+int choose_threads(size_t dimj) {
+    int threads = 32;
+    while (static_cast<size_t>(threads) < dimj && threads < 256) {
+        threads <<= 1;
+    }
+    return threads;
+}
+
+template <typename T>
+__global__ void rms_norm_kernel(T *out, const T *in, const T *weight, float eps, size_t dimj) {
+    extern __shared__ float shared_sum[];
+
+    const size_t row = static_cast<size_t>(blockIdx.x);
+    const size_t tid = static_cast<size_t>(threadIdx.x);
+    const T *in_row = in + row * dimj;
+    T *out_row = out + row * dimj;
+
+    float local_sum = 0.0f;
+    for (size_t col = tid; col < dimj; col += static_cast<size_t>(blockDim.x)) {
+        const float value = llaisys::device::nvidia::cuda_utils::toFloat(in_row[col]);
+        local_sum += value * value;
+    }
+    shared_sum[tid] = local_sum;
+    __syncthreads();
+
+    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
+        if (tid < stride) {
+            shared_sum[tid] += shared_sum[tid + stride];
+        }
+        __syncthreads();
+    }
+
+    __shared__ float inv_rms;
+    if (tid == 0) {
+        inv_rms = rsqrtf(shared_sum[0] / static_cast<float>(dimj) + eps);
+    }
+    __syncthreads();
+
+    for (size_t col = tid; col < dimj; col += static_cast<size_t>(blockDim.x)) {
+        const float value = llaisys::device::nvidia::cuda_utils::toFloat(in_row[col]);
+        const float scale = llaisys::device::nvidia::cuda_utils::toFloat(weight[col]);
+        out_row[col] = llaisys::device::nvidia::cuda_utils::fromFloat<T>(value * inv_rms * scale);
+    }
+}
+
+template <typename T>
+void rms_norm_impl(std::byte *out,
+                   const std::byte *in,
+                   const std::byte *weight,
+                   float eps,
+                   const std::vector<size_t> &shape) {
+    const size_t dimi = shape[0];
+    const size_t dimj = shape[1];
+    const int threads = choose_threads(dimj);
+    rms_norm_kernel<T><<<static_cast<unsigned int>(dimi), threads, static_cast<size_t>(threads) * sizeof(float), llaisys::ops::nvidia::current_stream()>>>(
+        reinterpret_cast<T *>(out),
+        reinterpret_cast<const T *>(in),
+        reinterpret_cast<const T *>(weight),
+        eps,
+        dimj);
+    LLAISYS_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace
+
+namespace llaisys::ops::nvidia {
+
+void rms_norm(std::byte *out,
+              const std::byte *in,
+              const std::byte *weight,
+              float eps,
+              llaisysDataType_t dtype,
+              const std::vector<size_t> &shape) {
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return rms_norm_impl<float>(out, in, weight, eps, shape);
+    case LLAISYS_DTYPE_F16:
+        return rms_norm_impl<half>(out, in, weight, eps, shape);
+    case LLAISYS_DTYPE_BF16:
+        return rms_norm_impl<__nv_bfloat16>(out, in, weight, eps, shape);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh
new file mode 100644
index 00000000..903142c2
--- /dev/null
+++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+#include <vector>
+
+namespace llaisys::ops::nvidia {
+void rms_norm(std::byte *out,
+              const std::byte *in,
+              const std::byte *weight,
+              float eps,
+              llaisysDataType_t dtype,
+              const std::vector<size_t> &shape);
+}
diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp
index 529553d9..ddaef306 100644
--- a/src/ops/rms_norm/op.cpp
+++ b/src/ops/rms_norm/op.cpp
@@ -1,7 +1,35 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rms_norm_cpu.hpp"
+#include "nvidia/rms_norm_nvidia.cuh"
 
 namespace llaisys::ops {
 void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) {
-    TO_BE_IMPLEMENTED();
+    ASSERT(in->shape().size() == 2, "RMSNorm: input tensor must be 2-D.");
+    ASSERT(weight->shape().size() == 1, "RMSNorm: weight tensor must be 1-D.");
+    ASSERT(out->shape().size() == 2, "RMSNorm: output tensor must be 2-D.");
+    size_t dimi = in->shape()[0];
+    size_t dimj = in->shape()[1];
+
+    ASSERT(weight->shape()[0] == dimj, "RMSNorm: weight tensor shape is invalid.");
+    ASSERT(out->shape()[0] == dimi && out->shape()[1] == dimj, "RMSNorm: output tensor shape is invalid.");
+
+    if(out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return llaisys::ops::cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), {dimi, dimj});
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), {dimi, dimj});
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), {dimi, dimj});
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp
new file mode 100644
index
00000000..c87cd0f2
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.cpp
@@ -0,0 +1,255 @@
+#include "rope_cpu.hpp"
+
+#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)
+#ifdef __C
+#pragma push_macro("__C")
+#undef __C
+#define LLAISYS_RESTORE_C_MACRO
+#endif
+#include <immintrin.h>
+#ifdef LLAISYS_RESTORE_C_MACRO
+#pragma pop_macro("__C")
+#undef LLAISYS_RESTORE_C_MACRO
+#endif
+#endif
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+
+#include <vector>
+
+namespace {
+
+std::vector<double> build_inv_freq(float theta, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    std::vector<double> inv_freq(static_cast<size_t>(half_dim));
+    for (std::ptrdiff_t j = 0; j < half_dim; ++j) {
+        const double exponent = (2.0 * static_cast<double>(j)) / static_cast<double>(head_dim);
+        inv_freq[static_cast<size_t>(j)] = std::pow(static_cast<double>(theta), -exponent);
+    }
+    return inv_freq;
+}
+
+const std::vector<double> &get_inv_freq(float theta, std::ptrdiff_t head_dim) {
+    static thread_local std::vector<double> cached_inv_freq;
+    static thread_local float cached_theta = 0.0f;
+    static thread_local std::ptrdiff_t cached_head_dim = 0;
+    if (cached_inv_freq.empty() || cached_theta != theta || cached_head_dim != head_dim) {
+        cached_inv_freq = build_inv_freq(theta, head_dim);
+        cached_theta = theta;
+        cached_head_dim = head_dim;
+    }
+    return cached_inv_freq;
+}
+
+void build_trig_tables(std::vector<float> &sin_table, std::vector<float> &cos_table, const int64_t *pos_ids, const std::vector<double> &inv_freq, std::ptrdiff_t seq_len, std::ptrdiff_t half_dim) {
+    const bool parallel_trig = seq_len * half_dim >= 4096;
+#pragma omp parallel for schedule(static) if (parallel_trig)
+    for (std::ptrdiff_t s = 0; s < seq_len; ++s) {
+        const double pos = static_cast<double>(pos_ids[s]);
+        float *sin_row = sin_table.data() + s * half_dim;
+        float *cos_row = cos_table.data() + s * half_dim;
+        for (std::ptrdiff_t j = 0; j < half_dim; ++j) {
+            const double angle = pos * inv_freq[static_cast<size_t>(j)];
+            sin_row[j] = static_cast<float>(std::sin(angle));
+            cos_row[j] = static_cast<float>(std::cos(angle));
+        }
+    }
+}
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__))
+inline bool has_avx2_fma() {
+    return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma");
+}
+
+__attribute__((target("avx2,fma")))
+void apply_rope_f32_avx2(float *out, const float *in, const std::vector<float> &sin_table, const std::vector<float> &cos_table, std::ptrdiff_t seq_len, std::ptrdiff_t num_heads, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    const std::ptrdiff_t simd_half_dim = half_dim - (half_dim % 8);
+    const bool parallel_apply = seq_len * num_heads * half_dim >= 4096;
+
+#pragma omp parallel for collapse(2) schedule(static) if (parallel_apply)
+    for (std::ptrdiff_t s = 0; s < seq_len; ++s) {
+        for (std::ptrdiff_t h = 0; h < num_heads; ++h) {
+            const std::ptrdiff_t offset = s * num_heads * head_dim + h * head_dim;
+            const float *src_vec = in + offset;
+            float *dst_vec = out + offset;
+            const float *sin_row = sin_table.data() + s * half_dim;
+            const float *cos_row = cos_table.data() + s * half_dim;
+
+            std::ptrdiff_t j = 0;
+            for (; j < simd_half_dim; j += 8) {
+                const __m256 a = _mm256_loadu_ps(src_vec + j);
+                const __m256 b = _mm256_loadu_ps(src_vec + half_dim + j);
+                const __m256 sin_v = _mm256_loadu_ps(sin_row + j);
+                const __m256 cos_v = _mm256_loadu_ps(cos_row + j);
+                const __m256 a_rot = _mm256_fmsub_ps(a, cos_v, _mm256_mul_ps(b, sin_v));
+                const __m256 b_rot = _mm256_fmadd_ps(a, sin_v, _mm256_mul_ps(b, cos_v));
+                _mm256_storeu_ps(dst_vec + j, a_rot);
+                _mm256_storeu_ps(dst_vec + half_dim + j, b_rot);
+            }
+
+            for (; j < half_dim; ++j) {
+                const float sin_v = sin_row[j];
+                const float cos_v = cos_row[j];
+                const float a = src_vec[j];
+                const float b = src_vec[half_dim + j];
+                dst_vec[j] = a * cos_v - b * sin_v;
+                dst_vec[half_dim + j] = b * cos_v + a * sin_v;
+            }
+        }
+    }
+}
+#endif
+
+void rope_f32(float *out, const float *in, const int64_t *pos_ids, float theta, std::ptrdiff_t seq_len, std::ptrdiff_t num_heads, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    std::vector<float> sin_table(static_cast<size_t>(seq_len * half_dim));
+    std::vector<float> cos_table(static_cast<size_t>(seq_len * half_dim));
+    const auto &inv_freq = get_inv_freq(theta, head_dim);
+    build_trig_tables(sin_table, cos_table, pos_ids, inv_freq, seq_len, half_dim);
+    const bool parallel_apply = seq_len * num_heads * half_dim >= 4096;
+
+#pragma omp parallel for collapse(2) schedule(static) if (parallel_apply)
+    for (std::ptrdiff_t s = 0; s < seq_len; ++s) {
+        for (std::ptrdiff_t h = 0; h < num_heads; ++h) {
+            const std::ptrdiff_t offset = s * num_heads * head_dim + h * head_dim;
+            const float *src_vec = in + offset;
+            float *dst_vec = out + offset;
+            const float *sin_row = sin_table.data() + s * half_dim;
+            const float *cos_row = cos_table.data() + s * half_dim;
+
+#pragma omp simd
+            for (std::ptrdiff_t j = 0; j < half_dim; ++j) {
+                const float a = src_vec[j];
+                const float b = src_vec[half_dim + j];
+                dst_vec[j] = a * cos_row[j] - b * sin_row[j];
+                dst_vec[half_dim + j] = b * cos_row[j] + a * sin_row[j];
+            }
+        }
+    }
+}
+
+void rope_single_f32(float *out, const float *in, const int64_t *pos_ids, float theta, std::ptrdiff_t num_heads, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    const auto &inv_freq = get_inv_freq(theta, head_dim);
+    const double pos = static_cast<double>(pos_ids[0]);
+    const bool parallel_heads = num_heads * half_dim >= 4096;
+
+#pragma omp parallel for schedule(static) if (parallel_heads)
+    for (std::ptrdiff_t h = 0; h < num_heads; ++h) {
+        const std::ptrdiff_t offset = h * head_dim;
+        const float *src_vec = in + offset;
+        float *dst_vec = out + offset;
+        for (std::ptrdiff_t j = 0; j < half_dim; ++j) {
+            const double angle = pos * inv_freq[static_cast<size_t>(j)];
+            const float sin_v = static_cast<float>(std::sin(angle));
+            const float cos_v = static_cast<float>(std::cos(angle));
+            const float a = src_vec[j];
+            const float b = src_vec[half_dim + j];
+            dst_vec[j] = a * cos_v - b * sin_v;
+            dst_vec[half_dim + j] = b * cos_v + a * sin_v;
+        }
+    }
+}
+
+template <typename T>
+void rope_generic(T *out, const T *in, const int64_t *pos_ids, float theta, std::ptrdiff_t seq_len, std::ptrdiff_t num_heads, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    std::vector<float> sin_table(static_cast<size_t>(seq_len * half_dim));
+    std::vector<float> cos_table(static_cast<size_t>(seq_len * half_dim));
+    const auto &inv_freq = get_inv_freq(theta, head_dim);
+    build_trig_tables(sin_table, cos_table, pos_ids, inv_freq, seq_len, half_dim);
+    const bool parallel_apply = seq_len * num_heads * half_dim >= 4096;
+
+#pragma omp parallel for collapse(2) schedule(static) if (parallel_apply)
+    for (std::ptrdiff_t s = 0; s < seq_len; ++s) {
+        for (std::ptrdiff_t h = 0; h < num_heads; ++h) {
+            const std::ptrdiff_t offset = s * num_heads * head_dim + h * head_dim;
+            const T *src_vec = in + offset;
+            T *dst_vec = out + offset;
+            const float *sin_row = sin_table.data() + s * half_dim;
+            const float *cos_row = cos_table.data() + s * half_dim;
+
+            for (std::ptrdiff_t j = 0; j < half_dim; ++j) {
+                const float a = llaisys::utils::cast<float>(src_vec[j]);
+                const float b = llaisys::utils::cast<float>(src_vec[half_dim + j]);
+                dst_vec[j] = llaisys::utils::cast<T>(a * cos_row[j] - b * sin_row[j]);
+                dst_vec[half_dim + j] = llaisys::utils::cast<T>(b * cos_row[j] + a * sin_row[j]);
+            }
+        }
+    }
+}
+
+template <typename T>
+void rope_single_generic(T *out, const T *in, const int64_t *pos_ids, float theta, std::ptrdiff_t num_heads, std::ptrdiff_t head_dim) {
+    const std::ptrdiff_t half_dim = head_dim / 2;
+    const auto &inv_freq = get_inv_freq(theta, head_dim);
+    const double pos = static_cast<double>(pos_ids[0]);
+    const bool parallel_heads = num_heads * half_dim >= 4096;
+
+#pragma omp parallel for schedule(static) if (parallel_heads)
+    for (std::ptrdiff_t h = 0; h < num_heads; ++h) {
+        const std::ptrdiff_t offset = h * head_dim;
+        const T *src_vec = in +
offset; + T *dst_vec = out + offset; + for (std::ptrdiff_t j = 0; j < half_dim; ++j) { + const double angle = pos * inv_freq[static_cast(j)]; + const float sin_v = static_cast(std::sin(angle)); + const float cos_v = static_cast(std::cos(angle)); + const float a = llaisys::utils::cast(src_vec[j]); + const float b = llaisys::utils::cast(src_vec[half_dim + j]); + dst_vec[j] = llaisys::utils::cast(a * cos_v - b * sin_v); + dst_vec[half_dim + j] = llaisys::utils::cast(b * cos_v + a * sin_v); + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, float theta, llaisysDataType_t dtype, const std::vector &shape) { + const std::ptrdiff_t seq_len = static_cast(shape[0]); + const std::ptrdiff_t num_heads = static_cast(shape[1]); + const std::ptrdiff_t head_dim = static_cast(shape[2]); + + const int64_t* pos_ptr = reinterpret_cast(pos_ids); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + if (seq_len == 1) { + return rope_single_f32(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, num_heads, head_dim); + } +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) + if (has_avx2_fma()) { + std::vector sin_table(static_cast(seq_len * (head_dim / 2))); + std::vector cos_table(static_cast(seq_len * (head_dim / 2))); + const auto &inv_freq = get_inv_freq(theta, head_dim); + build_trig_tables(sin_table, cos_table, pos_ptr, inv_freq, seq_len, head_dim / 2); + return apply_rope_f32_avx2(reinterpret_cast(out), reinterpret_cast(in), sin_table, cos_table, seq_len, num_heads, head_dim); + } +#endif + rope_f32(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + case LLAISYS_DTYPE_BF16: + if (seq_len == 1) { + return rope_single_generic(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, num_heads, head_dim); + } + rope_generic(reinterpret_cast(out), 
reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + case LLAISYS_DTYPE_F16: + if (seq_len == 1) { + return rope_single_generic(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, num_heads, head_dim); + } + rope_generic(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..8a0dae47 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" +#include +#include +#include + +namespace llaisys::ops::cpu { + void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, float theta, llaisysDataType_t dtype, const std::vector &shape); +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/nvidia/rope_nvidia.cu b/src/ops/rope/nvidia/rope_nvidia.cu new file mode 100644 index 00000000..5341b081 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cu @@ -0,0 +1,86 @@ +#include "rope_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" + +#include + +namespace { + +template +__global__ void rope_kernel(T *out, + const T *in, + const int64_t *pos_ids, + double log_theta, + int64_t num_heads, + int64_t head_dim, + size_t total_pairs) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x); + if (idx >= total_pairs) { + return; + } + + const int64_t half_dim = head_dim / 2; + const int64_t pairs_per_seq = num_heads * half_dim; + const int64_t seq = static_cast(idx / static_cast(pairs_per_seq)); + const int64_t rem = static_cast(idx % static_cast(pairs_per_seq)); + const int64_t head = rem / half_dim; + const int64_t dim = rem % half_dim; + const int64_t base = (seq * num_heads + head) * head_dim; + + const float a = 
llaisys::device::nvidia::cuda_utils::toFloat(in[base + dim]); + const float b = llaisys::device::nvidia::cuda_utils::toFloat(in[base + half_dim + dim]); + const double exponent = -2.0 * static_cast(dim) / static_cast(head_dim); + const double freq = exp(log_theta * exponent); + const double angle = static_cast(pos_ids[seq]) * freq; + const float sin_v = static_cast(sin(angle)); + const float cos_v = static_cast(cos(angle)); + + out[base + dim] = llaisys::device::nvidia::cuda_utils::fromFloat(a * cos_v - b * sin_v); + out[base + half_dim + dim] = llaisys::device::nvidia::cuda_utils::fromFloat(b * cos_v + a * sin_v); +} + +template +void rope_impl(std::byte *out, + const std::byte *in, + const int64_t *pos_ids, + float theta, + const std::vector &shape) { + const int64_t seq_len = static_cast(shape[0]); + const int64_t num_heads = static_cast(shape[1]); + const int64_t head_dim = static_cast(shape[2]); + const size_t total_pairs = static_cast(seq_len * num_heads * (head_dim / 2)); + constexpr int threads = 256; + rope_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(in), + pos_ids, + log(static_cast(theta)), + num_heads, + head_dim, + total_pairs); + LLAISYS_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void rope(std::byte *out, + const std::byte *in, + const int64_t *pos_ids, + float theta, + llaisysDataType_t dtype, + const std::vector &shape) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return rope_impl(out, in, pos_ids, theta, shape); + case LLAISYS_DTYPE_F16: + return rope_impl(out, in, pos_ids, theta, shape); + case LLAISYS_DTYPE_BF16: + return rope_impl(out, in, pos_ids, theta, shape); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/nvidia/rope_nvidia.cuh b/src/ops/rope/nvidia/rope_nvidia.cuh new file mode 100644 index 00000000..2c35b2a6 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cuh @@ -0,0 +1,16 @@ +#pragma once + 
+#include "llaisys.h" + +#include +#include +#include + +namespace llaisys::ops::nvidia { +void rope(std::byte *out, + const std::byte *in, + const int64_t *pos_ids, + float theta, + llaisysDataType_t dtype, + const std::vector &shape); +} diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..542cc671 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rope_cpu.hpp" +#include "nvidia/rope_nvidia.cuh" namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, pos_ids); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64."); + + ASSERT(in->shape().size() == 3, "RoPE: input tensor must be 3-D [seqlen, nhead, head_dim]."); + ASSERT(out->shape().size() == 3, "RoPE: output tensor must be 3-D."); + ASSERT(pos_ids->shape().size() == 1, "RoPE: pos_ids tensor must be 1-D [seqlen]."); + + size_t seq_len = in->shape()[0]; + size_t head_dim = in->shape()[2]; + + ASSERT(pos_ids->shape()[0] == seq_len, "RoPE: pos_ids length mismatch with input seqlen."); + ASSERT(out->shape() == in->shape(), "RoPE: output shape mismatch with input."); + ASSERT(head_dim % 2 == 0, "RoPE: head_dim must be even."); + + ASSERT(in->isContiguous() && out->isContiguous() && pos_ids->isContiguous(), "RoPE: inputs must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), in->shape()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), in->shape()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rope(out->data(), 
in->data(), reinterpret_cast(pos_ids->data()), theta, out->dtype(), in->shape()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/sample/cpu/sample_cpu.cpp b/src/ops/sample/cpu/sample_cpu.cpp new file mode 100644 index 00000000..1c494772 --- /dev/null +++ b/src/ops/sample/cpu/sample_cpu.cpp @@ -0,0 +1,120 @@ +#include "sample_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include +#include +#include + +namespace { + +template +int64_t argmax_index(const T *vals, size_t numel) { + size_t max_index = 0; + float max_value = llaisys::utils::cast(vals[0]); + + for (size_t i = 1; i < numel; ++i) { + const float current = llaisys::utils::cast(vals[i]); + if (current > max_value) { + max_value = current; + max_index = i; + } + } + + return static_cast(max_index); +} + +template +int64_t sample_impl(const T *vals, size_t numel, int top_k, float top_p, float temperature) { + if (numel == 0) { + return 0; + } + + if (!std::isfinite(temperature) || temperature <= 0.0f || top_k == 1) { + return argmax_index(vals, numel); + } + + struct Candidate { + size_t index; + float logit; + float weight; + }; + + const size_t k = top_k <= 0 ? numel : std::min(numel, static_cast(top_k)); + const float safe_top_p = (!std::isfinite(top_p) || top_p <= 0.0f || top_p > 1.0f) ? 
1.0f : top_p; + const float inv_temperature = 1.0f / temperature; + + std::vector candidates; + candidates.reserve(numel); + for (size_t i = 0; i < numel; ++i) { + candidates.push_back({i, llaisys::utils::cast(vals[i]) * inv_temperature, 0.0f}); + } + + auto by_logit_desc = [](const Candidate &lhs, const Candidate &rhs) { + return lhs.logit > rhs.logit; + }; + std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), by_logit_desc); + candidates.resize(k); + + const float max_logit = candidates.front().logit; + float total_mass = 0.0f; + for (auto &candidate : candidates) { + candidate.weight = std::exp(candidate.logit - max_logit); + total_mass += candidate.weight; + } + + if (safe_top_p < 1.0f && total_mass > 0.0f) { + float cumulative = 0.0f; + size_t keep = 0; + for (; keep < candidates.size(); ++keep) { + cumulative += candidates[keep].weight / total_mass; + if (cumulative >= safe_top_p) { + ++keep; + break; + } + } + candidates.resize(std::max(1, std::min(keep, candidates.size()))); + } + + float kept_mass = 0.0f; + for (const auto &candidate : candidates) { + kept_mass += candidate.weight; + } + if (!(kept_mass > 0.0f)) { + return static_cast(candidates.front().index); + } + + thread_local std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution dist(0.0f, kept_mass); + const float draw = dist(rng); + + float running = 0.0f; + for (const auto &candidate : candidates) { + running += candidate.weight; + if (draw <= running) { + return static_cast(candidate.index); + } + } + + return static_cast(candidates.back().index); +} + +} // namespace + +namespace llaisys::ops::cpu { +int64_t sample(const std::byte *logits, llaisysDataType_t type, size_t numel, int top_k, float top_p, float temperature) { + switch (type) { + case LLAISYS_DTYPE_F32: + return sample_impl(reinterpret_cast(logits), numel, top_k, top_p, temperature); + case LLAISYS_DTYPE_BF16: + return sample_impl(reinterpret_cast(logits), numel, top_k, top_p, 
temperature); + case LLAISYS_DTYPE_F16: + return sample_impl(reinterpret_cast(logits), numel, top_k, top_p, temperature); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/sample/cpu/sample_cpu.hpp b/src/ops/sample/cpu/sample_cpu.hpp new file mode 100644 index 00000000..05b1b5ec --- /dev/null +++ b/src/ops/sample/cpu/sample_cpu.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include "../../../tensor/tensor.hpp" + +namespace llaisys::ops::cpu { +int64_t sample(const std::byte *logits, llaisysDataType_t type, size_t numel, int top_k, float top_p, float temperature); +} diff --git a/src/ops/sample/nvidia/sample_nvidia.cu b/src/ops/sample/nvidia/sample_nvidia.cu new file mode 100644 index 00000000..3e0147e3 --- /dev/null +++ b/src/ops/sample/nvidia/sample_nvidia.cu @@ -0,0 +1,22 @@ +#include "sample_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" +#include "../cpu/sample_cpu.hpp" + +#include + +namespace llaisys::ops::nvidia { + +int64_t sample(const std::byte *logits, llaisysDataType_t dtype, size_t numel, int top_k, float top_p, float temperature) { + std::vector host_logits(numel * utils::dsize(dtype)); + LLAISYS_CUDA_CHECK(cudaMemcpyAsync( + host_logits.data(), + logits, + host_logits.size(), + cudaMemcpyDeviceToHost, + current_stream())); + LLAISYS_CUDA_CHECK(cudaStreamSynchronize(current_stream())); + return cpu::sample(host_logits.data(), dtype, numel, top_k, top_p, temperature); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/sample/nvidia/sample_nvidia.cuh b/src/ops/sample/nvidia/sample_nvidia.cuh new file mode 100644 index 00000000..73b1365c --- /dev/null +++ b/src/ops/sample/nvidia/sample_nvidia.cuh @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::nvidia { +int64_t sample(const std::byte *logits, llaisysDataType_t dtype, size_t numel, int top_k, float top_p, float temperature); +} diff --git a/src/ops/sample/op.cpp 
b/src/ops/sample/op.cpp new file mode 100644 index 00000000..29652aac --- /dev/null +++ b/src/ops/sample/op.cpp @@ -0,0 +1,30 @@ +#include "op.hpp" + +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/sample_cpu.hpp" +#include "nvidia/sample_nvidia.cuh" + +namespace llaisys::ops { +int64_t sample(tensor_t logits, int top_k, float top_p, float temperature) { + ASSERT(logits->isContiguous(), "Sample: logits tensor must be contiguous."); + ASSERT(logits->numel() > 0, "Sample: logits tensor must not be empty."); + + if (logits->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::sample(logits->data(), logits->dtype(), logits->numel(), top_k, top_p, temperature); + } + + llaisys::core::context().setDevice(logits->deviceType(), logits->deviceId()); + switch (logits->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::sample(logits->data(), logits->dtype(), logits->numel(), top_k, top_p, temperature); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::sample(logits->data(), logits->dtype(), logits->numel(), top_k, top_p, temperature); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } +} +} // namespace llaisys::ops diff --git a/src/ops/sample/op.hpp b/src/ops/sample/op.hpp new file mode 100644 index 00000000..7d3f7a45 --- /dev/null +++ b/src/ops/sample/op.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include "../../tensor/tensor.hpp" + +namespace llaisys::ops { +int64_t sample(tensor_t logits, int top_k, float top_p, float temperature); +} diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..e2385362 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,508 @@ +#include "self_attention_cpu.hpp" + +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) +#ifdef __C +#pragma push_macro("__C") +#undef __C +#define LLAISYS_RESTORE_C_MACRO +#endif +#include +#ifdef 
LLAISYS_RESTORE_C_MACRO +#pragma pop_macro("__C") +#undef LLAISYS_RESTORE_C_MACRO +#endif +#endif + +#include "../../../utils.hpp" + +#include +#include +#include +#include +#include + +namespace { + +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) +inline bool has_avx2_fma() { + return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma"); +} + +__attribute__((target("avx2,fma,sse3"))) +inline float hsum256_ps(__m256 v) { + const __m128 low = _mm256_castps256_ps128(v); + const __m128 high = _mm256_extractf128_ps(v, 1); + __m128 sum = _mm_add_ps(low, high); + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, sum); + return _mm_cvtss_f32(sum); +} + +__attribute__((target("avx2,fma"))) +float dot_f32_avx2(const float *a, const float *b, std::ptrdiff_t numel) { + const std::ptrdiff_t simd_numel = numel - (numel % 8); + __m256 acc = _mm256_setzero_ps(); + for (std::ptrdiff_t i = 0; i < simd_numel; i += 8) { + acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc); + } + + float sum = hsum256_ps(acc); + for (std::ptrdiff_t i = simd_numel; i < numel; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +__attribute__((target("avx2,fma"))) +void fill_zero_f32_avx2(float *dst, std::ptrdiff_t numel) { + const std::ptrdiff_t simd_numel = numel - (numel % 8); + const __m256 zero = _mm256_setzero_ps(); + std::ptrdiff_t i = 0; + for (; i < simd_numel; i += 8) { + _mm256_storeu_ps(dst + i, zero); + } + for (; i < numel; ++i) { + dst[i] = 0.0f; + } +} + +__attribute__((target("avx2,fma"))) +void axpy_f32_avx2(float *dst, const float *src, float alpha, std::ptrdiff_t numel) { + const std::ptrdiff_t simd_numel = numel - (numel % 8); + const __m256 scale = _mm256_set1_ps(alpha); + std::ptrdiff_t i = 0; + for (; i < simd_numel; i += 8) { + const __m256 cur = _mm256_loadu_ps(dst + i); + const __m256 x = _mm256_loadu_ps(src + i); + _mm256_storeu_ps(dst + i, 
_mm256_fmadd_ps(scale, x, cur)); + } + for (; i < numel; ++i) { + dst[i] += alpha * src[i]; + } +} + +__attribute__((target("avx2,fma"))) +void scale_f32_avx2(float *dst, float alpha, std::ptrdiff_t numel) { + const std::ptrdiff_t simd_numel = numel - (numel % 8); + const __m256 scale = _mm256_set1_ps(alpha); + std::ptrdiff_t i = 0; + for (; i < simd_numel; i += 8) { + _mm256_storeu_ps(dst + i, _mm256_mul_ps(_mm256_loadu_ps(dst + i), scale)); + } + for (; i < numel; ++i) { + dst[i] *= alpha; + } +} + +__attribute__((target("avx2,fma"))) +void self_attention_decode_f32_avx2(float *attn_val, const float *q, const float *k, const float *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_heads = nh * sk * d >= 32768; + +#pragma omp parallel for schedule(static) if (parallel_heads) + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const float *q_vec = q + h * d; + float *out_ptr = attn_val + h * dv; + fill_zero_f32_avx2(out_ptr, dv); + + float max_score = -std::numeric_limits::infinity(); + float sum_exp = 0.0f; + for (std::ptrdiff_t j = 0; j < sk; ++j) { + const float *k_vec = k + (j * nh_kv + h_kv) * d; + const float *v_vec = v + (j * nh_kv + h_kv) * dv; + const float score = dot_f32_avx2(q_vec, k_vec, d) * scale; + + if (score > max_score) { + const float rescale = std::exp(max_score - score); + if (sum_exp > 0.0f) { + scale_f32_avx2(out_ptr, rescale, dv); + } + sum_exp = sum_exp * rescale + 1.0f; + max_score = score; + axpy_f32_avx2(out_ptr, v_vec, 1.0f, dv); + } else { + const float weight = std::exp(score - max_score); + sum_exp += weight; + 
axpy_f32_avx2(out_ptr, v_vec, weight, dv); + } + } + + if (sum_exp > 0.0f) { + scale_f32_avx2(out_ptr, 1.0f / sum_exp, dv); + } + } +} + +__attribute__((target("avx2,fma"))) +void self_attention_f32_avx2(float *attn_val, const float *q, const float *k, const float *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t sq = static_cast(q_shape[0]); + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_work = sq * nh * sk * d >= 32768; + +#pragma omp parallel for collapse(2) schedule(static) if (parallel_work) + for (std::ptrdiff_t i = 0; i < sq; ++i) { + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const std::ptrdiff_t q_abs_pos = sk - sq + i; + const std::ptrdiff_t valid_len = std::min(sk, q_abs_pos + 1); + const float *q_vec = q + (i * nh + h) * d; + float *out_ptr = attn_val + (i * nh + h) * dv; + + if (valid_len == 1) { + const float *v_vec = v + h_kv * dv; + std::copy(v_vec, v_vec + dv, out_ptr); + continue; + } + + std::vector scores(static_cast(valid_len)); + float max_score = -std::numeric_limits::infinity(); + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const float *k_vec = k + (j * nh_kv + h_kv) * d; + const float score = dot_f32_avx2(q_vec, k_vec, d) * scale; + scores[static_cast(j)] = score; + if (score > max_score) { + max_score = score; + } + } + + float sum_exp = 0.0f; + for (float &score : scores) { + score = std::exp(score - max_score); + sum_exp += score; + } + const float inv_sum = 1.0f / (sum_exp + 1e-10f); + + fill_zero_f32_avx2(out_ptr, dv); + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const float weight = scores[static_cast(j)] * inv_sum; + 
if (weight < 1e-10f) { + continue; + } + const float *v_vec = v + (j * nh_kv + h_kv) * dv; + axpy_f32_avx2(out_ptr, v_vec, weight, dv); + } + } + } +} +#endif + +void scale_f32(float *dst, float alpha, std::ptrdiff_t numel) { +#pragma omp simd + for (std::ptrdiff_t i = 0; i < numel; ++i) { + dst[i] *= alpha; + } +} + +void self_attention_decode_f32(float *attn_val, const float *q, const float *k, const float *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_heads = nh * sk * d >= 32768; + +#pragma omp parallel for schedule(static) if (parallel_heads) + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const float *q_vec = q + h * d; + float *out_ptr = attn_val + h * dv; + +#pragma omp simd + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] = 0.0f; + } + + float max_score = -std::numeric_limits::infinity(); + float sum_exp = 0.0f; + for (std::ptrdiff_t j = 0; j < sk; ++j) { + const float *k_vec = k + (j * nh_kv + h_kv) * d; + const float *v_vec = v + (j * nh_kv + h_kv) * dv; + float dot = 0.0f; +#pragma omp simd reduction(+ : dot) + for (std::ptrdiff_t l = 0; l < d; ++l) { + dot += q_vec[l] * k_vec[l]; + } + const float score = dot * scale; + + if (score > max_score) { + const float rescale = std::exp(max_score - score); + if (sum_exp > 0.0f) { + scale_f32(out_ptr, rescale, dv); + } + sum_exp = sum_exp * rescale + 1.0f; + max_score = score; +#pragma omp simd + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] += v_vec[l]; + } + } else { + const float weight = std::exp(score - max_score); + sum_exp += weight; +#pragma omp simd + for 
(std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] += weight * v_vec[l]; + } + } + } + + if (sum_exp > 0.0f) { + scale_f32(out_ptr, 1.0f / sum_exp, dv); + } + } +} + +void self_attention_f32(float *attn_val, const float *q, const float *k, const float *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t sq = static_cast(q_shape[0]); + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_work = sq * nh * sk * d >= 32768; + +#pragma omp parallel for collapse(2) schedule(static) if (parallel_work) + for (std::ptrdiff_t i = 0; i < sq; ++i) { + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const std::ptrdiff_t q_abs_pos = sk - sq + i; + const std::ptrdiff_t valid_len = std::min(sk, q_abs_pos + 1); + const float *q_vec = q + (i * nh + h) * d; + float *out_ptr = attn_val + (i * nh + h) * dv; + + if (valid_len == 1) { + const float *v_vec = v + h_kv * dv; + std::copy(v_vec, v_vec + dv, out_ptr); + continue; + } + + std::vector scores(static_cast(valid_len)); + float max_score = -std::numeric_limits::infinity(); + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const float *k_vec = k + (j * nh_kv + h_kv) * d; + float dot = 0.0f; +#pragma omp simd reduction(+ : dot) + for (std::ptrdiff_t l = 0; l < d; ++l) { + dot += q_vec[l] * k_vec[l]; + } + const float score = dot * scale; + scores[static_cast(j)] = score; + if (score > max_score) { + max_score = score; + } + } + + float sum_exp = 0.0f; + for (float &score : scores) { + score = std::exp(score - max_score); + sum_exp += score; + } + const float inv_sum = 1.0f / (sum_exp + 1e-10f); + +#pragma omp simd + for (std::ptrdiff_t l = 0; l < 
dv; ++l) { + out_ptr[l] = 0.0f; + } + + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const float weight = scores[static_cast(j)] * inv_sum; + if (weight < 1e-10f) { + continue; + } + const float *v_vec = v + (j * nh_kv + h_kv) * dv; +#pragma omp simd + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] += weight * v_vec[l]; + } + } + } + } +} + +template +void self_attention_decode_generic(T *attn_val, const T *q, const T *k, const T *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_heads = nh * sk * d >= 32768; + +#pragma omp parallel for schedule(static) if (parallel_heads) + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const T *q_vec = q + h * d; + std::vector out_accum(static_cast(dv), 0.0f); + float max_score = -std::numeric_limits::infinity(); + float sum_exp = 0.0f; + + for (std::ptrdiff_t j = 0; j < sk; ++j) { + const T *k_vec = k + (j * nh_kv + h_kv) * d; + const T *v_vec = v + (j * nh_kv + h_kv) * dv; + float dot = 0.0f; + for (std::ptrdiff_t l = 0; l < d; ++l) { + dot += llaisys::utils::cast(q_vec[l]) * llaisys::utils::cast(k_vec[l]); + } + const float score = dot * scale; + + if (score > max_score) { + const float rescale = std::exp(max_score - score); + if (sum_exp > 0.0f) { + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_accum[static_cast(l)] *= rescale; + } + } + sum_exp = sum_exp * rescale + 1.0f; + max_score = score; + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_accum[static_cast(l)] += llaisys::utils::cast(v_vec[l]); + } + } else { + const float weight = std::exp(score - max_score); + sum_exp += weight; + for 
(std::ptrdiff_t l = 0; l < dv; ++l) { + out_accum[static_cast(l)] += weight * llaisys::utils::cast(v_vec[l]); + } + } + } + + T *out_ptr = attn_val + h * dv; + const float inv_sum = sum_exp > 0.0f ? 1.0f / sum_exp : 0.0f; + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] = llaisys::utils::cast(out_accum[static_cast(l)] * inv_sum); + } + } +} + +template +void self_attention_generic(T *attn_val, const T *q, const T *k, const T *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const std::ptrdiff_t sq = static_cast(q_shape[0]); + const std::ptrdiff_t nh = static_cast(q_shape[1]); + const std::ptrdiff_t d = static_cast(q_shape[2]); + const std::ptrdiff_t sk = static_cast(k_shape[0]); + const std::ptrdiff_t nh_kv = static_cast(k_shape[1]); + const std::ptrdiff_t dv = static_cast(v_shape[2]); + const std::ptrdiff_t n_rep = nh / nh_kv; + const bool parallel_work = sq * nh * sk * d >= 32768; + +#pragma omp parallel for collapse(2) schedule(static) if (parallel_work) + for (std::ptrdiff_t i = 0; i < sq; ++i) { + for (std::ptrdiff_t h = 0; h < nh; ++h) { + const std::ptrdiff_t h_kv = h / n_rep; + const std::ptrdiff_t q_abs_pos = sk - sq + i; + const std::ptrdiff_t valid_len = std::min(sk, q_abs_pos + 1); + const T *q_vec = q + (i * nh + h) * d; + std::vector scores(static_cast(valid_len)); + float max_score = -std::numeric_limits::infinity(); + + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const T *k_vec = k + (j * nh_kv + h_kv) * d; + float dot = 0.0f; + for (std::ptrdiff_t l = 0; l < d; ++l) { + dot += llaisys::utils::cast(q_vec[l]) * llaisys::utils::cast(k_vec[l]); + } + const float score = dot * scale; + scores[static_cast(j)] = score; + if (score > max_score) { + max_score = score; + } + } + + float sum_exp = 0.0f; + for (float &score : scores) { + score = std::exp(score - max_score); + sum_exp += score; + } + const float inv_sum = 1.0f / (sum_exp + 1e-10f); + + std::vector 
out_accum(static_cast(dv), 0.0f); + for (std::ptrdiff_t j = 0; j < valid_len; ++j) { + const float weight = scores[static_cast(j)] * inv_sum; + if (weight < 1e-10f) { + continue; + } + + const T *v_vec = v + (j * nh_kv + h_kv) * dv; + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_accum[static_cast(l)] += weight * llaisys::utils::cast(v_vec[l]); + } + } + + T *out_ptr = attn_val + (i * nh + h) * dv; + for (std::ptrdiff_t l = 0; l < dv; ++l) { + out_ptr[l] = llaisys::utils::cast(out_accum[static_cast(l)]); + } + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, float scale, llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const bool single_query = q_shape[0] == 1; + switch (dtype) { + case LLAISYS_DTYPE_F32: +#if (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) && (defined(__GNUC__) || defined(__clang__)) + if (has_avx2_fma()) { + if (single_query) { + return self_attention_decode_f32_avx2(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + } + return self_attention_f32_avx2(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + } +#endif + if (single_query) { + return self_attention_decode_f32(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + } + return self_attention_f32(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + case LLAISYS_DTYPE_BF16: + if (single_query) { + return self_attention_decode_generic(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + } + return 
self_attention_generic(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + case LLAISYS_DTYPE_F16: + if (single_query) { + return self_attention_decode_generic(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + } + return self_attention_generic(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..e73a7df3 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" +#include +#include + +namespace llaisys::ops::cpu { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, float scale, llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape); + +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cu b/src/ops/self_attention/nvidia/self_attention_nvidia.cu new file mode 100644 index 00000000..2d7c1e29 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cu @@ -0,0 +1,145 @@ +#include "self_attention_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" + +namespace { + +constexpr size_t kMaxValueDim = 1024; + +template +__global__ void self_attention_kernel(T *attn_val, + const T *q, + const T *k, + const T *v, + float scale, + int64_t sq, + int64_t nh, + int64_t d, + int64_t sk, + int64_t nh_kv, + int64_t dv) { + if (threadIdx.x != 0) { + return; + } + + const int64_t work_idx = static_cast(blockIdx.x); + const int64_t qi = work_idx / nh; + const int64_t h = work_idx % 
nh; + const int64_t n_rep = nh / nh_kv; + const int64_t h_kv = h / n_rep; + const int64_t q_abs_pos = sk - sq + qi; + const int64_t valid_len = q_abs_pos + 1 < sk ? q_abs_pos + 1 : sk; + const T *q_vec = q + (qi * nh + h) * d; + T *out_ptr = attn_val + (qi * nh + h) * dv; + + float out_local[kMaxValueDim]; + for (int64_t l = 0; l < dv; ++l) { + out_local[l] = 0.0f; + } + + if (valid_len <= 0) { + for (int64_t l = 0; l < dv; ++l) { + out_ptr[l] = llaisys::device::nvidia::cuda_utils::fromFloat(0.0f); + } + return; + } + + float max_score = -1.0e30f; + float sum_exp = 0.0f; + for (int64_t j = 0; j < valid_len; ++j) { + const T *k_vec = k + (j * nh_kv + h_kv) * d; + const T *v_vec = v + (j * nh_kv + h_kv) * dv; + + float dot = 0.0f; + for (int64_t l = 0; l < d; ++l) { + dot += llaisys::device::nvidia::cuda_utils::toFloat(q_vec[l]) * + llaisys::device::nvidia::cuda_utils::toFloat(k_vec[l]); + } + const float score = dot * scale; + + if (score > max_score) { + const float rescale = expf(max_score - score); + if (sum_exp > 0.0f) { + for (int64_t l = 0; l < dv; ++l) { + out_local[l] *= rescale; + } + } + sum_exp = sum_exp * rescale + 1.0f; + max_score = score; + for (int64_t l = 0; l < dv; ++l) { + out_local[l] += llaisys::device::nvidia::cuda_utils::toFloat(v_vec[l]); + } + } else { + const float weight = expf(score - max_score); + sum_exp += weight; + for (int64_t l = 0; l < dv; ++l) { + out_local[l] += weight * llaisys::device::nvidia::cuda_utils::toFloat(v_vec[l]); + } + } + } + + const float inv_sum = 1.0f / sum_exp; + for (int64_t l = 0; l < dv; ++l) { + out_ptr[l] = llaisys::device::nvidia::cuda_utils::fromFloat(out_local[l] * inv_sum); + } +} + +template +void self_attention_impl(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + const int64_t sq = static_cast(q_shape[0]); + const int64_t nh = static_cast(q_shape[1]); + 
const int64_t d = static_cast(q_shape[2]); + const int64_t sk = static_cast(k_shape[0]); + const int64_t nh_kv = static_cast(k_shape[1]); + const int64_t dv = static_cast(v_shape[2]); + + CHECK_ARGUMENT(static_cast(dv) <= kMaxValueDim, "NVIDIA self_attention only supports value head dim up to 1024 in this implementation."); + + self_attention_kernel<<(sq * nh), 1, 0, llaisys::ops::nvidia::current_stream()>>>( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + scale, + sq, + nh, + d, + sk, + nh_kv, + dv); + LLAISYS_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + float scale, + llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return self_attention_impl(attn_val, q, k, v, scale, q_shape, k_shape, v_shape); + case LLAISYS_DTYPE_F16: + return self_attention_impl(attn_val, q, k, v, scale, q_shape, k_shape, v_shape); + case LLAISYS_DTYPE_BF16: + return self_attention_impl(attn_val, q, k, v, scale, q_shape, k_shape, v_shape); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cuh b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh new file mode 100644 index 00000000..2d722780 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::nvidia { +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + float scale, + llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape); +} diff --git a/src/ops/self_attention/op.cpp 
b/src/ops/self_attention/op.cpp index 43d62014..d103aba4 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,49 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/self_attention_cpu.hpp" +#include "nvidia/self_attention_nvidia.cuh" namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype()); + CHECK_SAME_DTYPE(q->dtype(), k->dtype()); + CHECK_SAME_DTYPE(k->dtype(), v->dtype()); + + ASSERT(q->shape().size() == 3, "SelfAttention: q must be 3-D [seqlen, nhead, d]."); + ASSERT(k->shape().size() == 3, "SelfAttention: k must be 3-D [total_len, nkvhead, d]."); + ASSERT(v->shape().size() == 3, "SelfAttention: v must be 3-D [total_len, nkvhead, dv]."); + ASSERT(attn_val->shape().size() == 3, "SelfAttention: attn_val must be 3-D [seqlen, nhead, dv]."); + + size_t nh = q->shape()[1]; + size_t nh_kv = k->shape()[1]; + size_t d = q->shape()[2]; + + // GQA Check + ASSERT(nh % nh_kv == 0, "SelfAttention: nhead must be divisible by nkvhead (GQA constraint)."); + ASSERT(k->shape()[2] == d, "SelfAttention: Q and K head_dim mismatch."); + ASSERT(attn_val->shape()[0] == q->shape()[0], "SelfAttention: Output seqlen mismatch."); + ASSERT(attn_val->shape()[1] == nh, "SelfAttention: Output nhead mismatch."); + ASSERT(attn_val->shape()[2] == v->shape()[2], "SelfAttention: Output head_dim mismatch with V."); + + ASSERT(q->isContiguous() && k->isContiguous() && v->isContiguous() && attn_val->isContiguous(), + "SelfAttention: Inputs must be contiguous."); + + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), scale, attn_val->dtype(), q->shape(), k->shape(), v->shape()); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); 
+ switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), scale, attn_val->dtype(), q->shape(), k->shape(), v->shape()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::self_attention(attn_val->data(), q->data(), k->data(), v->data(), scale, attn_val->dtype(), q->shape(), k->shape(), v->shape()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..7df83cc6 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,49 @@ + +#include "swiglu_cpu.hpp" +#include "../../../utils.hpp" + +#include +#include + +namespace { + +void swiglu_f32(float *out, const float *gate, const float *up, std::ptrdiff_t numel) { +#pragma omp parallel for schedule(static) if (numel >= 4096) + for (std::ptrdiff_t i = 0; i < numel; ++i) { + const float g_val = gate[i]; + out[i] = up[i] * (g_val / (1.0f + std::exp(-g_val))); + } +} + +template +void swiglu_generic(T *out, const T *gate, const T *up, std::ptrdiff_t numel) { +#pragma omp parallel for schedule(static) if (numel >= 4096) + for (std::ptrdiff_t i = 0; i < numel; ++i) { + const float g_val = llaisys::utils::cast(gate[i]); + const float u_val = llaisys::utils::cast(up[i]); + out[i] = llaisys::utils::cast(u_val * (g_val / (1.0f + std::exp(-g_val)))); + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t dtype, size_t numel) { + const auto elem_count = static_cast(numel); + switch (dtype) { + case LLAISYS_DTYPE_F32: + swiglu_f32(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), elem_count); + break; + case LLAISYS_DTYPE_BF16: + swiglu_generic(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), elem_count); + break; + case LLAISYS_DTYPE_F16: 
+ swiglu_generic(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), elem_count); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..85d84006 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cu b/src/ops/swiglu/nvidia/swiglu_nvidia.cu new file mode 100644 index 00000000..049520a5 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cu @@ -0,0 +1,48 @@ +#include "swiglu_nvidia.cuh" + +#include "../../nvidia/nvidia_common.cuh" + +namespace { + +template +__global__ void swiglu_kernel(T *out, const T *gate, const T *up, size_t numel) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x); + if (idx >= numel) { + return; + } + + const float g = llaisys::device::nvidia::cuda_utils::toFloat(gate[idx]); + const float u = llaisys::device::nvidia::cuda_utils::toFloat(up[idx]); + const float silu = g / (1.0f + expf(-g)); + out[idx] = llaisys::device::nvidia::cuda_utils::fromFloat(u * silu); +} + +template +void swiglu_impl(std::byte *out, const std::byte *gate, const std::byte *up, size_t numel) { + constexpr int threads = 256; + swiglu_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + LLAISYS_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t dtype, size_t numel) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return swiglu_impl(out, gate, up, numel); + case LLAISYS_DTYPE_F16: + return swiglu_impl(out, 
gate, up, numel); + case LLAISYS_DTYPE_BF16: + return swiglu_impl(out, gate, up, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cuh b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh new file mode 100644 index 00000000..fb3ee838 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..5e24e015 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/swiglu_cpu.hpp" +#include "nvidia/swiglu_nvidia.cuh" namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype()); + CHECK_SAME_DTYPE(gate->dtype(), up->dtype()); + + ASSERT(gate->shape() == up->shape(), "SwiGLU: gate and up tensor shapes must match."); + ASSERT(out->shape() == gate->shape(), "SwiGLU: output tensor shape must match input."); + ASSERT(gate->isContiguous() && up->isContiguous() && out->isContiguous(), "SwiGLU: Inputs/Output tensors must be contiguous."); + + size_t numel = out->numel(); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); 
+#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..0c68dc98 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,27 +164,91 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t ndim_ = this->ndim(); + ptrdiff_t stride = 1; + const auto &shape = this->shape(); + const auto &strides = this->strides(); + for (size_t i = 1; i <= ndim_; i++) { + if(strides[ndim_ - i] != stride) return false; + stride *= shape[ndim_ - i]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim_ = order.size(); + if(ndim_ != this->ndim()){ + throw std::runtime_error("permute: order size does not match tensor ndim"); + } + std::vector seen(ndim_, false); + std::vector shape(ndim_); + std::vector strides(ndim_); + const auto &old_shape = this->shape(); + const auto &old_strides = this->strides(); + for(size_t i = 0; i < ndim_; ++i){ + size_t idx = order[i]; + if(idx < 0 || idx >= ndim_){ + throw std::runtime_error("permute: order index out of range"); + } + if(seen[idx]){ + throw std::runtime_error("permute: duplicate indices in order"); + } + seen[idx] = true; + shape[i] = old_shape[idx]; + strides[i] = old_strides[idx]; + } + TensorMeta meta = {this->dtype(), shape, strides}; + return std::shared_ptr(new Tensor(meta, _storage, _offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t numel = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); + if(numel != this->numel()){ + throw std::runtime_error("view: new shape size does not match tensor size"); + } + if (!this->isContiguous()) { + throw std::runtime_error("view: input tensor must be contiguous. 
call .contiguous() first."); + } + size_t ndim_ = shape.size(); + std::vector strides(ndim_); + size_t stride = 1; + for (size_t i = 1; i <= ndim_; i++) { + strides[ndim_ - i] = stride; + stride *= shape[ndim_ - i]; + } + TensorMeta meta = {this->dtype(), shape, strides}; + return std::shared_ptr(new Tensor(meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim_ = this->ndim(); + if(dim >= ndim_){ + throw std::runtime_error("slice: dim out of range"); + } + const auto& old_shape = this->shape(); + if(start >= end || end > old_shape[dim]){ + throw std::runtime_error("slice: invalid start or end"); + } + std::vector shape = old_shape;; + shape[dim] = end - start; + TensorMeta meta = {this->dtype(), shape, this->strides()}; + size_t offset = _offset + start * this->strides()[dim] * this->elementSize(); + return std::shared_ptr(new Tensor(meta, _storage, offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + size_t total_size = this->numel() * this->elementSize(); + void *dst_ = this->data(); + if (this->deviceType() == LLAISYS_DEVICE_NVIDIA) { + core::context().setDevice(this->deviceType(), this->deviceId()); + core::context().runtime().api()->memcpy_sync( + dst_, + src_, + total_size, + LLAISYS_MEMCPY_H2D); + } else { + std::memcpy(dst_, src_, total_size); + } } tensor_t Tensor::contiguous() const { diff --git a/test/chat_test_utils.py b/test/chat_test_utils.py new file mode 100644 index 00000000..005f3cd4 --- /dev/null +++ b/test/chat_test_utils.py @@ -0,0 +1,167 @@ +import socket +import time +from itertools import count +from threading import Lock +from typing import Iterator, Tuple + + +class FakeTokenizer: + def apply_chat_template(self, conversation, add_generation_prompt=True, tokenize=False): + rendered = [] + for message in conversation: + role = message["role"] + content = message["content"] + 
rendered.append(f"{role}:{content}" if content else f"{role}:") + prompt = "\n".join(rendered) + if add_generation_prompt: + prompt += "\nassistant:" + return prompt + + def encode(self, text, return_tensors=None): + tokens = [ord(ch) for ch in text] + if return_tensors == "pt": + raise NotImplementedError("FakeTokenizer does not support tensor outputs") + return tokens or [0] + + def decode(self, token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False): + return "".join(chr(token_id) for token_id in token_ids if 0 <= token_id < 256) + + +class _Meta: + end_token = 33 + + +class FakeModel: + _instance_ids = count() + + def __init__(self, *, delay_s: float = 0.0, calls_store=None): + self.meta = _Meta() + self._reply = [72, 101, 108, 108, 111, 33] + self.delay_s = float(delay_s) + self.instance_id = next(self._instance_ids) + self.calls = calls_store if calls_store is not None else [] + self._calls_lock = Lock() + self._base_prompt_len = 0 + self._cur_pos = 0 + + def _record(self, payload): + with self._calls_lock: + self.calls.append(payload) + + def generate_next(self, inputs, top_k=1, top_p=1.0, temperature=1.0, reset_state=False): + if reset_state: + self.reset() + if not inputs: + raise ValueError("inputs must not be empty") + + self._record( + { + "mode": "generate_next", + "inputs": list(inputs), + "reset_state": reset_state, + "instance_id": self.instance_id, + } + ) + if self.delay_s > 0: + time.sleep(self.delay_s) + + if reset_state or self._base_prompt_len == 0 or len(inputs) > 1 or self._cur_pos < self._base_prompt_len: + self._base_prompt_len = self._cur_pos + len(inputs) + + self._cur_pos += len(inputs) + index = max(0, self._cur_pos - self._base_prompt_len) + if index >= len(self._reply): + index = len(self._reply) - 1 + return self._reply[index] + + def generate(self, inputs, max_new_tokens=20, top_k=1, top_p=1.0, temperature=1.0, reset_state=True): + self._record( + { + "mode": "generate", + "inputs": list(inputs), + "reset_state": 
reset_state, + "instance_id": self.instance_id, + } + ) + generated = [] + token_source = list(inputs) + for step in range(max_new_tokens): + token = self.generate_next( + token_source, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=reset_state if step == 0 else False, + ) + generated.append(token) + token_source = [token] + if token == self.meta.end_token: + break + return list(inputs) + generated + + def stream_generate( + self, + inputs, + *, + tokenizer=None, + max_new_tokens=20, + top_k=1, + top_p=1.0, + temperature=1.0, + reset_state=True, + ) -> Iterator[Tuple[int, str]]: + self._record( + { + "mode": "stream_generate", + "inputs": list(inputs), + "reset_state": reset_state, + "instance_id": self.instance_id, + } + ) + produced = [] + token_source = list(inputs) + for step in range(max_new_tokens): + token = self.generate_next( + token_source, + top_k=top_k, + top_p=top_p, + temperature=temperature, + reset_state=reset_state if step == 0 else False, + ) + produced.append(token) + text = tokenizer.decode(produced, skip_special_tokens=True, clean_up_tokenization_spaces=False) if tokenizer else "" + prev = tokenizer.decode(produced[:-1], skip_special_tokens=True, clean_up_tokenization_spaces=False) if tokenizer else "" + yield token, text[len(prev):] if text.startswith(prev) else text + token_source = [token] + if token == self.meta.end_token: + break + + def truncate(self, position: int) -> None: + self._cur_pos = int(position) + if self._cur_pos == 0: + self._base_prompt_len = 0 + self._record({"mode": "truncate", "position": int(position), "instance_id": self.instance_id}) + + def reset(self) -> None: + self._cur_pos = 0 + self._base_prompt_len = 0 + self._record({"mode": "reset", "instance_id": self.instance_id}) + + +def serve_fake_app(port: int) -> None: + import uvicorn + from llaisys.chat_server import create_app + + app = create_app(tokenizer=FakeTokenizer(), model=FakeModel(), model_name="fake-qwen") + uvicorn.run(app, 
host="127.0.0.1", port=port, log_level="error") + + +def wait_for_server(port: int, timeout: float = 10.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + if sock.connect_ex(("127.0.0.1", port)) == 0: + return + time.sleep(0.1) + raise TimeoutError(f"Server on port {port} did not start in time") diff --git a/test/ops/rope.py b/test/ops/rope.py index fe59dd11..fac351da 100644 --- a/test/ops/rope.py +++ b/test/ops/rope.py @@ -67,6 +67,7 @@ def test_op_rope( parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ + ((1, 4, 8), (7, 8)), ((2, 1, 4), (0, 2)), ((512, 4, 4096), (512, 1024))] testDtypePrec = [ diff --git a/test/ops/rope_debug.py b/test/ops/rope_debug.py new file mode 100644 index 00000000..54b5352c --- /dev/null +++ b/test/ops/rope_debug.py @@ -0,0 +1,113 @@ +import sys +import os +import torch +import numpy as np + +# Adjust path to find llaisys package +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python")) +sys.path.insert(0, parent_dir) +import llaisys + +def torch_rope(y: torch.Tensor, x: torch.Tensor, pos_ids: torch.Tensor, theta: float): + seq_len, n_heads, head_dim = y.shape + x_a, x_b = x[..., : head_dim // 2], x[..., head_dim // 2 :] + positions = pos_ids.to(torch.float32).unsqueeze(1) + i = torch.arange(0, head_dim // 2, dtype=torch.float32, device=y.device) + freqs = positions / (theta ** (2 * i / head_dim)) + sin, cos = freqs.sin(), freqs.cos() + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + y[..., : head_dim // 2] = x_a * cos - x_b * sin + y[..., head_dim // 2 :] = x_b * cos + x_a * sin + +def debug_rope(): + # Configuration matching the failing case + shape = (512, 4, 4096) + start_pos = 512 + end_pos = 1024 + dtype = torch.float32 + theta = 10000.0 + + print(f"Debugging RoPE with shape={shape}, range=[{start_pos}, {end_pos}), dtype={dtype}") + + # 1. 
Setup Data + torch.manual_seed(42) + x = torch.randn(shape, dtype=dtype) + pos_ids = torch.arange(start_pos, end_pos, dtype=torch.int64) + y_torch = torch.zeros_like(x) + + # 2. Run PyTorch + torch_rope(y_torch, x, pos_ids, theta) + + # 3. Setup LLAISYS + # Helpers + device_enum = llaisys.DeviceType.CPU + dt_enum = llaisys.DataType.F32 + api = llaisys.RuntimeAPI(device_enum) + + # Create LLAISYS tensors + x_ll = llaisys.Tensor(shape, dtype=dt_enum, device=device_enum) + y_ll = llaisys.Tensor(shape, dtype=dt_enum, device=device_enum) + pos_ll = llaisys.Tensor((len(pos_ids),), dtype=llaisys.DataType.I64, device=device_enum) + + # Copy Input Data (x, pos_ids) + # Using HostToHost since we are on CPU + kind = llaisys.MemcpyKind.HostToHost + + api.memcpy_sync(x_ll.data_ptr(), x.data_ptr(), x.numel() * x.element_size(), kind) + api.memcpy_sync(pos_ll.data_ptr(), pos_ids.data_ptr(), pos_ids.numel() * pos_ids.element_size(), kind) + + # Run Op + llaisys.Ops.rope(y_ll, x_ll, pos_ll, theta) + + # Copy Output Data back + y_llaisys = torch.zeros_like(x) + api.memcpy_sync(y_llaisys.data_ptr(), y_ll.data_ptr(), y_ll.numel() * y_ll.element_size(), kind) + + # 4. Analyze Error + diff = (y_torch - y_llaisys).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + print(f"Max Diff: {max_diff:.2e}") + print(f"Mean Diff: {mean_diff:.2e}") + + # 5. 
Detailed Breakdown + if max_diff > 1e-5: # Only show details if significant error + max_indices = torch.nonzero(diff == max_diff) + if len(max_indices) > 0: + idx = max_indices[0] + seq_idx, head_idx, dim_idx = idx.tolist() + print(f"Max error at index: seq={seq_idx}, head={head_idx}, dim={dim_idx}") + curr_pos = pos_ids[seq_idx].item() + print(f"Pos ID at failure: {curr_pos}") + + # Theoretical calc + head_dim = shape[2] + freq_idx = dim_idx if dim_idx < head_dim // 2 else dim_idx - head_dim // 2 + + freq_exponent_f = (2.0 * freq_idx) / head_dim + denom_f = theta ** freq_exponent_f + angle_f = curr_pos / denom_f + + # Double precision check + freq_exponent_d = (2.0 * freq_idx) / float(head_dim) + denom_d = theta ** freq_exponent_d + angle_d = curr_pos / denom_d + + print(f"Angle(float) approx: {angle_f}") + print(f"Angle(double) approx: {angle_d}") + + val_t = y_torch[seq_idx, head_idx, dim_idx].item() + val_l = y_llaisys[seq_idx, head_idx, dim_idx].item() + print(f"Values: Torch={val_t:.8f}, LLAISYS={val_l:.8f}") + print(f"Diff: {abs(val_t - val_l):.8f}") + + if max_diff > 5e-4: + print("\n\033[91mFAILED: Error exceeds 5e-4\033[0m") + sys.exit(1) + else: + print("\n\033[92mPASSED\033[0m") + +if __name__ == "__main__": + debug_rope() diff --git a/test/ops/sample.py b/test/ops/sample.py new file mode 100644 index 00000000..4f54f3af --- /dev/null +++ b/test/ops/sample.py @@ -0,0 +1,83 @@ +import os +import sys + +import torch + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import llaisys +from test_utils import check_equal, llaisys_device, random_tensor, torch_dtype, torch_device + + +def tensor_from_torch(torch_tensor, device_name="cpu"): + tensor = llaisys.Tensor( + torch_tensor.shape, + dtype=llaisys.DataType.F32, + device=llaisys_device(device_name), + ) + api = llaisys.RuntimeAPI(llaisys_device(device_name)) + api.memcpy_sync( + tensor.data_ptr(), + torch_tensor.data_ptr(), + torch_tensor.numel() 
* torch_tensor.element_size(), + llaisys.MemcpyKind.D2D, + ) + return tensor + + +def allowed_indices(logits: torch.Tensor, top_k: int, top_p: float, temperature: float): + scaled = logits / temperature + k = min(top_k, logits.numel()) + values, indices = torch.topk(scaled, k) + probs = torch.softmax(values, dim=0) + if top_p >= 1.0: + return set(indices.tolist()) + cumulative = torch.cumsum(probs, dim=0) + keep = int((cumulative < top_p).sum().item()) + 1 + return set(indices[:keep].tolist()) + + +def test_top_k_one_is_argmax(device_name="cpu"): + logits = torch.tensor([0.25, 1.5, -2.0, 0.9], dtype=torch.float32, device=torch_device(device_name)) + token = llaisys.Ops.sample(tensor_from_torch(logits, device_name), top_k=1, top_p=1.0, temperature=1.0) + assert token == int(torch.argmax(logits).item()) + + +def test_temperature_fallback(device_name="cpu"): + logits = torch.tensor([0.1, 0.4, 0.3], dtype=torch.float32, device=torch_device(device_name)) + token = llaisys.Ops.sample(tensor_from_torch(logits, device_name), top_k=5, top_p=1.0, temperature=0.0) + assert token == int(torch.argmax(logits).item()) + + +def test_top_k_membership(device_name="cpu"): + logits = torch.tensor([0.1, 2.0, 1.5, 3.0], dtype=torch.float32, device=torch_device(device_name)) + allowed = set(torch.topk(logits, 2).indices.tolist()) + tensor = tensor_from_torch(logits, device_name) + for _ in range(100): + token = llaisys.Ops.sample(tensor, top_k=2, top_p=1.0, temperature=0.8) + assert token in allowed + + +def test_top_p_membership(device_name="cpu"): + logits = torch.tensor([4.0, 3.0, 2.0, 1.0], dtype=torch.float32, device=torch_device(device_name)) + allowed = allowed_indices(logits, top_k=4, top_p=0.7, temperature=0.9) + tensor = tensor_from_torch(logits, device_name) + for _ in range(100): + token = llaisys.Ops.sample(tensor, top_k=4, top_p=0.7, temperature=0.9) + assert token in allowed + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + 
parser.add_argument("--device", default="cpu", choices=["cpu"], type=str) + args = parser.parse_args() + + test_top_k_one_is_argmax(args.device) + test_temperature_fallback(args.device) + test_top_k_membership(args.device) + test_top_p_membership(args.device) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51b..cfee7224 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -71,6 +71,7 @@ def test_op_self_attention( testShapes = [ # qlen, kvlen, nh, nkvh, hd (2, 2, 1, 1, 4), + (1, 11, 4, 2, 8), (5, 11, 4, 2, 8), ] testDtypePrec = [ diff --git a/test/test_chat_api.py b/test/test_chat_api.py new file mode 100644 index 00000000..ccfd6f88 --- /dev/null +++ b/test/test_chat_api.py @@ -0,0 +1,376 @@ +import sys +import threading +import time + + +def _has_optional_deps() -> bool: + try: + import fastapi # noqa: F401 + import httpx # noqa: F401 + except ImportError: + return False + return True + + +if not _has_optional_deps(): + print("Skipped: optional server dependencies are not installed.") + sys.exit(0) + +from fastapi.testclient import TestClient + +from chat_test_utils import FakeModel, FakeTokenizer +from llaisys.chat_server import create_app + + +def test_non_stream_response() -> None: + app = create_app(tokenizer=FakeTokenizer(), model=FakeModel(), model_name="fake-qwen") + client = TestClient(app) + response = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "messages": [{"role": "user", "content": "hello"}], + 
"max_tokens": 8, + "top_k": 10, + "top_p": 0.9, + "temperature": 0.8, + }, + ) + response.raise_for_status() + payload = response.json() + assert payload["object"] == "chat.completion" + assert payload["model"] == "fake-qwen" + assert payload["session_id"].startswith("session-") + assert payload["cache_reused_tokens"] == 0 + assert payload["worker_id"] == 0 + assert payload["batch_id"] == 1 + assert payload["dispatch_count"] == 6 + assert payload["last_batch_id"] >= payload["batch_id"] + assert payload["choices"][0]["message"]["content"] == "Hello!" + assert payload["choices"][0]["finish_reason"] == "stop" + assert payload["usage"]["completion_tokens"] == 6 + + +def test_stream_response() -> None: + app = create_app(tokenizer=FakeTokenizer(), model=FakeModel(), model_name="fake-qwen") + client = TestClient(app) + chunks = [] + with client.stream( + "POST", + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + "top_k": 10, + "top_p": 0.9, + "temperature": 0.8, + "stream": True, + }, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + if line: + chunks.append(line) + + assert chunks[-1] == "data: [DONE]" + content_parts = [] + for line in chunks[:-1]: + if not line.startswith("data: "): + continue + event = line[6:] + if event == "[DONE]": + continue + payload = __import__("json").loads(event) + delta = payload["choices"][0]["delta"] + if "content" in delta: + content_parts.append(delta["content"]) + assert "".join(content_parts) == "Hello!" 
+ + +def test_service_stats_and_multi_user_batch_dispatch() -> None: + app = create_app( + tokenizer=FakeTokenizer(), + model_factory=lambda: FakeModel(delay_s=0.05), + model_name="fake-qwen", + num_workers=2, + max_batch_size=2, + batch_wait_ms=20, + ) + client = TestClient(app) + results = [None, None] + + def send(index: int, prompt: str) -> None: + response = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": f"session-batch-{index}", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 8, + }, + ) + response.raise_for_status() + results[index] = response.json() + + t0 = threading.Thread(target=send, args=(0, "alpha")) + t1 = threading.Thread(target=send, args=(1, "beta")) + t0.start() + t1.start() + t0.join() + t1.join() + + payload0 = results[0] + payload1 = results[1] + assert payload0 is not None and payload1 is not None + assert payload0["choices"][0]["message"]["content"] == "Hello!" + assert payload1["choices"][0]["message"]["content"] == "Hello!" 
+ assert payload0["batch_id"] == payload1["batch_id"] + assert {payload0["worker_id"], payload1["worker_id"]} == {0, 1} + assert payload0["dispatch_count"] == 6 + assert payload1["dispatch_count"] == 6 + + stats = client.get("/v1/service/stats") + stats.raise_for_status() + body = stats.json() + assert body["worker_count"] == 2 + assert body["scheduled_batches"] >= 6 + assert body["max_observed_batch_size"] >= 2 + assert body["completed_requests"] == 2 + assert body["queue_depth"] == 0 + assert body["requeued_requests"] >= 10 + assert body["total_generated_tokens"] == 12 + assert body["cached_session_count"] == 2 + + +def test_same_session_concurrent_request_is_rejected() -> None: + app = create_app( + tokenizer=FakeTokenizer(), + model_factory=lambda: FakeModel(delay_s=0.1), + model_name="fake-qwen", + num_workers=2, + max_batch_size=2, + batch_wait_ms=10, + ) + client = TestClient(app) + holder = {} + + def first_request() -> None: + response = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": "session-busy", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + holder["first"] = response + + thread = threading.Thread(target=first_request) + thread.start() + + deadline = time.time() + 3.0 + while time.time() < deadline: + stats = client.get("/v1/service/stats") + stats.raise_for_status() + if stats.json()["active_requests"] > 0: + break + time.sleep(0.01) + + second = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": "session-busy", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + assert second.status_code == 409 + assert "in-flight request" in second.json()["detail"] + + thread.join() + holder["first"].raise_for_status() + + +def test_session_reuse_response() -> None: + model = FakeModel() + app = create_app(tokenizer=FakeTokenizer(), model=model, model_name="fake-qwen") + client = TestClient(app) + session_id = 
"session-reuse" + + first = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": session_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + first.raise_for_status() + assistant_content = first.json()["choices"][0]["message"]["content"] + + second = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": session_id, + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": assistant_content}, + {"role": "user", "content": "again"}, + ], + "max_tokens": 8, + }, + ) + second.raise_for_status() + payload = second.json() + assert payload["cache_reused_tokens"] > 0 + generate_next_calls = [call for call in model.calls if call["mode"] == "generate_next"] + assert generate_next_calls[0]["reset_state"] is True + resumed_call = next(call for call in generate_next_calls[1:] if call["reset_state"] is False and len(call["inputs"]) > 1) + assert resumed_call["reset_state"] is False + assert len(resumed_call["inputs"]) < len(generate_next_calls[0]["inputs"]) + len(assistant_content) + + +def test_session_affinity_preserves_cache_slot_across_workers() -> None: + app = create_app( + tokenizer=FakeTokenizer(), + model_factory=FakeModel, + model_name="fake-qwen", + num_workers=2, + max_batch_size=2, + batch_wait_ms=10, + ) + client = TestClient(app) + + first = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": "session-affinity-a", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + first.raise_for_status() + first_payload = first.json() + + second = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": "session-affinity-b", + "messages": [{"role": "user", "content": "beta"}], + "max_tokens": 8, + }, + ) + second.raise_for_status() + second_payload = second.json() + assert first_payload["worker_id"] != second_payload["worker_id"] + + 
assistant_content = first_payload["choices"][0]["message"]["content"] + third = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": "session-affinity-a", + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": assistant_content}, + {"role": "user", "content": "again"}, + ], + "max_tokens": 8, + }, + ) + third.raise_for_status() + third_payload = third.json() + assert third_payload["worker_id"] == first_payload["worker_id"] + assert third_payload["cache_reused_tokens"] > 0 + + +def test_session_management_crud() -> None: + app = create_app(tokenizer=FakeTokenizer(), model=FakeModel(), model_name="fake-qwen") + client = TestClient(app) + + created = client.post( + "/v1/sessions", + json={ + "session_id": "session-crud", + "messages": [{"role": "system", "content": "You are terse."}], + }, + ) + created.raise_for_status() + assert created.json()["id"] == "session-crud" + + listed = client.get("/v1/sessions") + listed.raise_for_status() + session_ids = [item["id"] for item in listed.json()["data"]] + assert "session-crud" in session_ids + + fetched = client.get("/v1/sessions/session-crud") + fetched.raise_for_status() + assert fetched.json()["messages"][0]["content"] == "You are terse." 
+ + updated = client.put( + "/v1/sessions/session-crud", + json={"messages": [{"role": "user", "content": "edited"}]}, + ) + updated.raise_for_status() + assert updated.json()["messages"][0]["content"] == "edited" + + deleted = client.delete("/v1/sessions/session-crud") + deleted.raise_for_status() + assert deleted.json()["deleted"] is True + + +def test_regenerate_uses_truncate_for_active_session() -> None: + model = FakeModel() + app = create_app(tokenizer=FakeTokenizer(), model=model, model_name="fake-qwen") + client = TestClient(app) + session_id = "session-regen" + + first = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": session_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + first.raise_for_status() + + second = client.post( + "/v1/chat/completions", + json={ + "model": "fake-qwen", + "session_id": session_id, + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + }, + ) + second.raise_for_status() + + truncate_calls = [call for call in model.calls if call["mode"] == "truncate"] + assert truncate_calls, model.calls + assert truncate_calls[-1]["position"] > 0 + generate_next_calls = [call for call in model.calls if call["mode"] == "generate_next"] + assert generate_next_calls[0]["reset_state"] is True + truncate_index = max(idx for idx, call in enumerate(model.calls) if call["mode"] == "truncate") + resumed_call = next(call for call in model.calls[truncate_index + 1:] if call["mode"] == "generate_next") + assert resumed_call["reset_state"] is False + + +if __name__ == "__main__": + test_non_stream_response() + test_stream_response() + test_service_stats_and_multi_user_batch_dispatch() + test_same_session_concurrent_request_is_rejected() + test_session_reuse_response() + test_session_affinity_preserves_cache_slot_across_workers() + test_session_management_crud() + test_regenerate_uses_truncate_for_active_session() + print("\033[92mTest passed!\033[0m\n") diff 
--git a/test/test_chat_cli.py b/test/test_chat_cli.py new file mode 100644 index 00000000..b6452075 --- /dev/null +++ b/test/test_chat_cli.py @@ -0,0 +1,111 @@ +import multiprocessing +import os +import socket +import subprocess +import sys + + +def _has_optional_deps() -> bool: + try: + import fastapi # noqa: F401 + import httpx # noqa: F401 + import uvicorn # noqa: F401 + except ImportError: + return False + return True + + +if not _has_optional_deps(): + print("Skipped: optional server dependencies are not installed.") + sys.exit(0) + +from chat_test_utils import serve_fake_app, wait_for_server + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def test_cli_smoke() -> None: + port = _free_port() + process = multiprocessing.Process(target=serve_fake_app, args=(port,), daemon=True) + process.start() + try: + wait_for_server(port) + env = os.environ.copy() + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + python_paths = [ + os.path.join(repo_root, "python"), + os.path.join(repo_root, "test"), + env.get("PYTHONPATH", ""), + ] + env["PYTHONPATH"] = os.pathsep.join(path for path in python_paths if path) + + result = subprocess.run( + [ + sys.executable, + "-m", + "llaisys.chat_cli", + "--url", + f"http://127.0.0.1:{port}", + "--model", + "fake-qwen", + "--prompt", + "hello", + ], + env=env, + capture_output=True, + text=True, + timeout=20, + check=False, + ) + assert result.returncode == 0, result.stderr + assert "Hello!" 
in result.stdout + + session_result = subprocess.run( + [ + sys.executable, + "-m", + "llaisys.chat_cli", + "--url", + f"http://127.0.0.1:{port}", + "--session-id", + "cli-session", + "--create-session", + ], + env=env, + capture_output=True, + text=True, + timeout=20, + check=False, + ) + assert session_result.returncode == 0, session_result.stderr + assert "cli-session" in session_result.stdout + + list_result = subprocess.run( + [ + sys.executable, + "-m", + "llaisys.chat_cli", + "--url", + f"http://127.0.0.1:{port}", + "--list-sessions", + ], + env=env, + capture_output=True, + text=True, + timeout=20, + check=False, + ) + assert list_result.returncode == 0, list_result.stderr + assert "cli-session" in list_result.stdout + finally: + process.terminate() + process.join(timeout=5) + + +if __name__ == "__main__": + test_cli_smoke() + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_generate_sampling.py b/test/test_generate_sampling.py new file mode 100644 index 00000000..119ccfce --- /dev/null +++ b/test/test_generate_sampling.py @@ -0,0 +1,43 @@ +import argparse +import os + +from transformers import AutoTokenizer + +import llaisys +from test_utils import llaisys_device + + +def test_sampling_smoke(model_path: str, device_name: str = "cpu") -> None: + if not model_path or not os.path.isdir(model_path): + print("Skipped: provide --model with a local model directory.") + return + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = llaisys.models.Qwen2(model_path, llaisys_device(device_name)) + prompt = tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": "Say hello in one short sentence."}], + add_generation_prompt=True, + tokenize=False, + ) + input_ids = tokenizer.encode(prompt) + output_ids = model.generate( + input_ids, + max_new_tokens=16, + top_k=20, + top_p=0.9, + temperature=0.8, + ) + generated = output_ids[len(input_ids):] + assert generated + assert len(generated) <= 16 + assert 
all(isinstance(token, int) for token in generated) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--model", default=None, type=str) + args = parser.parse_args() + + test_sampling_smoke(args.model, args.device) + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_model_support.py b/test/test_model_support.py new file mode 100644 index 00000000..7c0c30d0 --- /dev/null +++ b/test/test_model_support.py @@ -0,0 +1,149 @@ +import json +import os +import struct +import sys +import tempfile +from pathlib import Path + +import numpy as np + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import llaisys + + +def _write_safetensors(path: Path, tensors: dict[str, np.ndarray]) -> None: + header = {} + offset = 0 + raw_chunks = [] + for name in sorted(tensors): + array = np.ascontiguousarray(tensors[name], dtype=np.float32) + raw = array.tobytes(order="C") + header[name] = { + "dtype": "F32", + "shape": list(array.shape), + "data_offsets": [offset, offset + len(raw)], + } + raw_chunks.append(raw) + offset += len(raw) + + header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") + with open(path, "wb") as handle: + handle.write(struct.pack("<Q", len(header_bytes))) + handle.write(header_bytes) + for raw in raw_chunks: + handle.write(raw) + + +def _build_tiny_llama_dir(root: Path) -> Path: + model_dir = root / "tiny-llama" + model_dir.mkdir() + + config = { + "model_type": "llama", + "torch_dtype": "float32", + "num_hidden_layers": 1, + "hidden_size": 8, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "intermediate_size": 16, + "max_position_embeddings": 32, + "vocab_size": 16, + "rms_norm_eps": 1e-6, + "rope_theta": 10000.0, + "eos_token_id": 2, + "tie_word_embeddings": False, + } + with open(model_dir / "config.json", "w") as handle: + json.dump(config, handle) + + hs = config["hidden_size"] + dh = hs // config["num_attention_heads"] + nkvh = config["num_key_value_heads"] + di =
config["intermediate_size"] + voc = config["vocab_size"] + embed = np.arange(voc * hs, dtype=np.float32).reshape(voc, hs) * 0.01 + + tensors = { + "model.embed_tokens.weight": embed, + "lm_head.weight": embed.copy(), + "model.norm.weight": np.ones((hs,), dtype=np.float32), + "model.layers.0.input_layernorm.weight": np.ones((hs,), dtype=np.float32), + "model.layers.0.self_attn.q_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.self_attn.k_proj.weight": np.eye(nkvh * dh, hs, dtype=np.float32), + "model.layers.0.self_attn.v_proj.weight": np.eye(nkvh * dh, hs, dtype=np.float32), + "model.layers.0.self_attn.o_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.post_attention_layernorm.weight": np.ones((hs,), dtype=np.float32), + "model.layers.0.mlp.gate_proj.weight": np.ones((di, hs), dtype=np.float32) * 0.01, + "model.layers.0.mlp.up_proj.weight": np.ones((di, hs), dtype=np.float32) * 0.01, + "model.layers.0.mlp.down_proj.weight": np.ones((hs, di), dtype=np.float32) * 0.01, + } + _write_safetensors(model_dir / "model.safetensors", tensors) + return model_dir + + +def _build_tiny_tied_llama_dir(root: Path) -> Path: + model_dir = root / "tiny-llama-tied" + model_dir.mkdir() + + config = { + "model_type": "llama", + "torch_dtype": "float32", + "num_hidden_layers": 1, + "hidden_size": 8, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "intermediate_size": 16, + "max_position_embeddings": 32, + "vocab_size": 16, + "rms_norm_eps": 1e-6, + "rope_theta": 10000.0, + "eos_token_id": 2, + "tie_word_embeddings": True, + } + with open(model_dir / "config.json", "w") as handle: + json.dump(config, handle) + + hs = config["hidden_size"] + di = config["intermediate_size"] + voc = config["vocab_size"] + embed = np.arange(voc * hs, dtype=np.float32).reshape(voc, hs) * 0.01 + tensors = { + "model.embed_tokens.weight": embed, + "model.norm.weight": np.ones((hs,), dtype=np.float32), + "model.layers.0.input_layernorm.weight": np.ones((hs,), 
dtype=np.float32), + "model.layers.0.self_attn.q_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.self_attn.k_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.self_attn.v_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.self_attn.o_proj.weight": np.eye(hs, dtype=np.float32), + "model.layers.0.post_attention_layernorm.weight": np.ones((hs,), dtype=np.float32), + "model.layers.0.mlp.gate_proj.weight": np.ones((di, hs), dtype=np.float32) * 0.01, + "model.layers.0.mlp.up_proj.weight": np.ones((di, hs), dtype=np.float32) * 0.01, + "model.layers.0.mlp.down_proj.weight": np.ones((hs, di), dtype=np.float32) * 0.01, + } + _write_safetensors(model_dir / "model.safetensors", tensors) + return model_dir + + +def test_llama_model_support() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = _build_tiny_llama_dir(Path(tmpdir)) + assert llaisys.models.detect_model_type(str(model_dir)) == "llama" + assert llaisys.models.default_model_name(str(model_dir)) == "llaisys-llama" + + model = llaisys.models.create_model(str(model_dir), llaisys.DeviceType.CPU, 0) + assert isinstance(model, llaisys.models.Llama) + + outputs = model.generate([1, 3, 5], max_new_tokens=1, top_k=1, top_p=1.0, temperature=1.0) + assert outputs[:3] == [1, 3, 5] + assert len(outputs) == 4 + + tied_model_dir = _build_tiny_tied_llama_dir(Path(tmpdir)) + tied_model = llaisys.models.create_model(str(tied_model_dir), llaisys.DeviceType.CPU, 0) + assert bool(tied_model.weights_ptr.contents.out_embed) + + +if __name__ == "__main__": + test_llama_model_support() + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_tensor_parallel.py b/test/test_tensor_parallel.py new file mode 100644 index 00000000..2f422bc1 --- /dev/null +++ b/test/test_tensor_parallel.py @@ -0,0 +1,63 @@ +import argparse +import os + +import torch + +import llaisys + + +def test_tensor_parallel_qwen2( + model_path: str, + *, + tp_size: int = 2, + max_steps: int = 8, +) -> None: + if 
not model_path or not os.path.isdir(model_path): + print("Skipped: provide --model with a local model directory.") + return + if not torch.cuda.is_available(): + print("Skipped: CUDA is not available.") + return + if torch.cuda.device_count() < tp_size: + print(f"Skipped: need at least {tp_size} CUDA devices.") + return + + prompt = [151646, 151644, 15191, 525, 498, 30, 151645, 151648, 198] + expected_generated = [91786, 0, 358, 2776, 18183, 39350, 10911, 16] + tp_device_ids = list(range(tp_size)) + parallel = llaisys.models.create_model( + model_path, + llaisys.DeviceType.NVIDIA, + 0, + tp_size=tp_size, + tp_device_ids=tp_device_ids, + ) + + parallel_out = parallel.generate( + prompt, + max_new_tokens=max_steps, + top_k=1, + top_p=1.0, + temperature=1.0, + ) + assert parallel_out[: len(prompt)] == prompt + assert parallel_out[len(prompt):] == expected_generated[:max_steps] + + close = getattr(parallel, "close", None) + if callable(close): + close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default=None, type=str) + parser.add_argument("--tp-size", default=2, type=int) + parser.add_argument("--max-steps", default=8, type=int) + args = parser.parse_args() + + test_tensor_parallel_qwen2( + args.model, + tp_size=args.tp_size, + max_steps=args.max_steps, + ) + print("\033[92mTest passed!\033[0m\n") diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..30c00a65 100644 --- a/xmake.lua +++ b/xmake.lua @@ -95,6 +95,23 @@ target("llaisys-ops") on_install(function (target) end) target_end() +target("llaisys-models") + set_kind("static") + add_deps("llaisys-ops") + add_deps("llaisys-tensor") + add_deps("llaisys-core") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/models/*/*.cpp") + + on_install(function (target) end) +target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,9 
+119,19 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") - - set_languages("cxx17") + add_deps("llaisys-models") + if has_config("nv-gpu") then + add_nvidia_build_settings() + add_nvidia_source_files() + else + set_languages("cxx17") + end set_warnings("all", "error") + if not is_plat("windows") then + add_ldflags("-Wl,--no-as-needed") + add_syslinks("gomp") + end + add_cpu_blas_settings() add_files("src/llaisys/*.cc") set_installdir(".") @@ -119,4 +146,4 @@ target("llaisys") os.cp("lib/*.so", "python/llaisys/libllaisys/") end end) -target_end() \ No newline at end of file +target_end() diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 101d894e..d49ba2ea 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -1,3 +1,119 @@ +option("cpu-blas") + set_default(true) + set_showmenu(true) + set_description("Whether to enable OpenBLAS acceleration for large CPU linear kernels when available") +option_end() + +option("openblas-prefix") + set_default("") + set_showmenu(true) + set_description("Prefix directory of OpenBLAS installation (expects include/cblas.h and lib/libopenblas.so)") +option_end() + +local _openblas_config = false +local _openblas_checked = false +local _openblas_warned = false + +local function _detect_openblas_from_prefix(prefix) + if not prefix or #prefix == 0 then + return nil + end + + local include_dirs = { + path.join(prefix, "include"), + path.join(prefix, "include", "x86_64-linux-gnu"), + } + local lib_dirs = { + path.join(prefix, "lib"), + path.join(prefix, "lib64"), + path.join(prefix, "lib", "x86_64-linux-gnu"), + } + + for _, include_dir in ipairs(include_dirs) do + if not os.isfile(path.join(include_dir, "cblas.h")) then + goto continue + end + for _, lib_dir in ipairs(lib_dirs) do + if os.isfile(path.join(lib_dir, "libopenblas.so")) then + return { + include_dir = include_dir, + lib_dir = lib_dir, + } + end + end + ::continue:: + end + + return nil +end + +local function _get_openblas_config() + if 
_openblas_checked then + return _openblas_config + end + _openblas_checked = true + + if not has_config("cpu-blas") then + _openblas_config = nil + return _openblas_config + end + + local prefixes = {} + local configured_prefix = get_config("openblas-prefix") + if configured_prefix and #configured_prefix > 0 then + table.insert(prefixes, configured_prefix) + end + + local conda_prefix = os.getenv("CONDA_PREFIX") + if conda_prefix and #conda_prefix > 0 then + table.insert(prefixes, conda_prefix) + end + + table.insert(prefixes, "/usr") + table.insert(prefixes, "/usr/local") + + local seen = {} + for _, prefix in ipairs(prefixes) do + if seen[prefix] then + goto continue + end + seen[prefix] = true + + local config = _detect_openblas_from_prefix(prefix) + if config ~= nil then + _openblas_config = config + return _openblas_config + end + ::continue:: + end + + _openblas_config = nil + return _openblas_config +end + +function add_cpu_blas_settings() + if not has_config("cpu-blas") then + return + end + + local config = _get_openblas_config() + if config == nil then + if not _openblas_warned then + print("warning: OpenBLAS not found, falling back to internal CPU linear kernels") + _openblas_warned = true + end + return + end + + add_defines("LLAISYS_USE_OPENBLAS") + add_includedirs(config.include_dir) + add_linkdirs(config.lib_dir) + add_links("openblas") + if not is_plat("windows") then + add_rpathdirs(config.lib_dir) + end +end + target("llaisys-device-cpu") set_kind("static") set_languages("cxx17") @@ -17,11 +133,13 @@ target("llaisys-ops-cpu") set_languages("cxx17") set_warnings("all", "error") if not is_plat("windows") then - add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxflags("-fPIC", "-Wno-unknown-pragmas", "-fopenmp") + else + add_cxflags("/openmp") end + add_cpu_blas_settings() add_files("../src/ops/*/cpu/*.cpp") on_install(function (target) end) target_end() - diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 
00000000..0bb56273 --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,12 @@ +function add_nvidia_build_settings() + set_languages("cxx17") + set_policy("build.cuda.devlink", true) + add_cugencodes("native") + add_cuflags("-Xcompiler=-fPIC") + add_links("cudart", "cublas", "cudadevrt") +end + +function add_nvidia_source_files() + add_files("src/device/nvidia/*.cu") + add_files("src/ops/*/nvidia/*.cu") +end