diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..3d31c23bb --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,60 @@ +name: Build and test +on: + pull_request: + push: + paths-ignore: + - '**.md' + - 'LICENSE' + +jobs: + build: + name: Build + strategy: + fail-fast: false + matrix: + os: [windows-latest, ubuntu-latest] + type: [release] + runs-on: ${{ matrix.os }} + steps: + + - name: checkout code + uses: actions/checkout@v4 + + - name: install xmake + uses: xmake-io/github-action-setup-xmake@v1 + with: + xmake-version: latest + + - name: Xmake Build & Install + run: | + xmake + xmake install + + - name: Install Python + run: | + cd python + pip install . + cd .. + + - name: Assignment-0 + run: | + python test/test_runtime.py --device cpu + + - name: Assignment-1 + run: | + python test/test_tensor.py + + - name: Assignment-2 + run: | + python test/ops/add.py + python test/ops/argmax.py + python test/ops/embedding.py + python test/ops/linear.py + python test/ops/rms_norm.py + python test/ops/rope.py + python test/ops/self_attention.py + python test/ops/swiglu.py + + - name: Assignment-3 + run: | + python test/test_infer.py --test diff --git a/assets/AI.ico b/assets/AI.ico new file mode 100644 index 000000000..358e94af5 Binary files /dev/null and b/assets/AI.ico differ diff --git a/chat_cli.py b/chat_cli.py new file mode 100644 index 000000000..6ab68e30e --- /dev/null +++ b/chat_cli.py @@ -0,0 +1,436 @@ +""" +LLAISYS 交互式命令行聊天客户端 + +用法: + python chat_cli.py [--server http://localhost:8000] [--session default] + +内置命令: + /quit 退出 + /new 新建对话(清空历史) + /history 显示当前对话历史 + /sessions 列出所有本地会话 + /switch 切换到指定会话 + /edit 编辑第 N 条用户消息并重新生成(N 从 1 开始) + /temp 设置温度(0–2) + /topk 设置 Top-K + /topp 设置 Top-P + /maxtok 设置最大新 Token 数 +""" + +import argparse +import json +import sys +import os +import uuid +from typing import List, Dict, Optional + +try: + import requests +except ImportError: + print("错误:请先安装 requests:pip 
install requests") + sys.exit(1) + +# prompt_toolkit 能正确处理 CJK 双宽字符,退格不会错位 +try: + from prompt_toolkit import prompt as _pt_prompt + from prompt_toolkit.formatted_text import ANSI as _PT_ANSI + _USE_PT = True +except ImportError: + _USE_PT = False + +import re as _re +_STRIP_ANSI = _re.compile(r'\x1b\[[0-9;]*[mK]') + +# 清洗特殊 token 的正则 +# 使用 [^\n||<>] 代替 [\w._-],以匹配 ▁(U+2581,SentencePiece 词边界符)等非 ASCII 字符 +_SPECIAL_TOKEN_RE = _re.compile( + r"<\s*[||]\s*[^\n||<>]{2,}\s*[||]\s*>", + _re.UNICODE, +) +# ▁ = U+2581(LOWER ONE EIGHTH BLOCK),Qwen EOS token 中使用,不是普通下划线 +_EOS_TOKEN_RE = _re.compile( + r"<\s*[||]\s*" + r"(?:end[▁_\-\s]*of[▁_\-\s]*(?:sentence|text|turn)" + r"|endoftext" + r"|im[▁_\-\s]*end" + r"|eot[▁_\-\s]*id" + r")\s*[||]\s*>", + _re.IGNORECASE, +) +# [^\n>]* 可匹配含 ▁ 的任意字符 +_PARTIAL_TAIL_RE = _re.compile(r"<(?:\s{0,3}[||][^\n>]*)?$") + +def _clean_reply(text: str) -> str: + """清洗模型输出中的特殊 token 和 EOS 标记。""" + text = _EOS_TOKEN_RE.sub("", text) + text = _SPECIAL_TOKEN_RE.sub("", text) + text = _PARTIAL_TAIL_RE.sub("", text) + text = text.replace("\ufffd", "") + return text.rstrip() + +def _input(prompt_ansi: str) -> str: + """ + 支持 CJK 宽字符退格的 input 封装。 + 优先使用 prompt_toolkit(最佳)。 + 未安装时降级为无颜色的标准 input。 + 建议:pip install prompt-toolkit + """ + if _USE_PT: + return _pt_prompt(_PT_ANSI(prompt_ansi)) + # 降级:去掉颜色码,输出纯文本提示符 + return input(_STRIP_ANSI.sub('', prompt_ansi)) + + +# ───────────────────────────────────────────────────────────────────────────── +# 会话数据结构 +# ───────────────────────────────────────────────────────────────────────────── + +class Session: + def __init__(self, session_id: str, title: str = "新对话"): + self.id = session_id + self.title = title + self.messages: List[Dict] = [] + + def add_user(self, content: str): + self.messages.append({"role": "user", "content": content}) + if len(self.messages) == 1: + self.title = content[:30] + + def add_assistant(self, content: str): + self.messages.append({"role": "assistant", "content": content}) + + +# 
───────────────────────────────────────────────────────────────────────────── +# 核心聊天逻辑 +# ───────────────────────────────────────────────────────────────────────────── + +def stream_chat( + server: str, + session: Session, + temperature: float, + top_k: int, + top_p: float, + max_tokens: int, +) -> str: + """ + 向服务器发送当前会话消息,流式打印响应,返回完整回复文本。 + """ + try: + resp = requests.post( + f"{server}/v1/chat/completions", + json={ + "model": "llaisys", + "messages": session.messages, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "max_tokens": max_tokens, + "stream": True, + "session_id": session.id, + }, + stream=True, + timeout=300, + ) + resp.raise_for_status() + except requests.exceptions.ConnectionError: + print(f"\n[错误] 无法连接服务器 {server},请确认服务器已启动。") + return "" + except requests.exceptions.HTTPError as e: + print(f"\n[错误] HTTP {e.response.status_code}: {e.response.text[:200]}") + return "" + + print("\033[32mAssistant\033[0m: ", end="", flush=True) + reply = "" + buf = "" + # hold-back 缓冲:缓存可能是特殊 token 不完整前缀的尾部 + _hold = "" + # 思考标签状态跟踪 + _in_think = False + + for raw in resp.iter_content(chunk_size=None): + if not raw: + continue + buf += raw.decode("utf-8", errors="replace") + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + delta = chunk["choices"][0]["delta"].get("content", "") + if not delta: + continue + # 拼入 hold-back 缓冲 + _hold += delta + + # 跟踪 think 标签状态 + if "" in _hold and not _in_think: + _in_think = True + if "" in _hold and _in_think: + _in_think = False + + # 检测是否包含完整 EOS 标记 + eos_m = _EOS_TOKEN_RE.search(_hold) + if eos_m: + _hold = _hold[:eos_m.start()] + # 检测尾部是否有未闭合的特殊 token 前缀 + partial_m = _PARTIAL_TAIL_RE.search(_hold) + if partial_m: + safe = _hold[:partial_m.start()] + _hold = _hold[partial_m.start():] + else: + safe = _hold + _hold = "" + if safe: + # 
清洗完整特殊 token(但保留 think 标签) + # 先移除 EOS 标记,保留 ... + safe = _EOS_TOKEN_RE.sub("", safe) + # 移除其他特殊 token,但保留 + safe = _SPECIAL_TOKEN_RE.sub("", safe) + print(safe, end="", flush=True) + reply += safe + except (json.JSONDecodeError, KeyError, IndexError): + pass + + # 输出 hold-back 中残留的非特殊 token 文本 + if _hold: + _hold = _EOS_TOKEN_RE.sub("", _hold) + _hold = _SPECIAL_TOKEN_RE.sub("", _hold) + _hold = _PARTIAL_TAIL_RE.sub("", _hold) + if _hold: + print(_hold, end="", flush=True) + reply += _hold + + # 确保 think 标签闭合:如果流式结束时仍在 think 块内,自动闭合 + if _in_think: + print("\n", end="", flush=True) + reply += "\n" + + print() # 换行 + return _clean_reply(reply) + + +def clear_server_cache(server: str, session_id: str): + """通知服务器清空指定会话的 KV-Cache。""" + try: + requests.post(f"{server}/v1/sessions/{session_id}/clear", timeout=5) + except Exception: + pass + + +# ───────────────────────────────────────────────────────────────────────────── +# 主循环 +# ───────────────────────────────────────────────────────────────────────────── + +def print_help(temp=0.8, topk=50, topp=0.9, maxtok=512): + print(f""" +可用命令: + /quit 退出程序 + /new 新建对话 + /clone [N] 新建对话并保留最近 N 轮上下文(默认 2 轮) + /history 显示对话历史 + /sessions 列出所有会话 + /switch 切换会话ネid 可由 /sessions 查看) + /edit 编辑第 N 条用户消息并重新生成 + /temp <0.0–2.0> 设置 temperature(当前:{temp}) + /topk <1–500> 设置 Top-K(当前:{topk}) + /topp <0.0–1.0> 设置 Top-P(当前:{topp}) + /maxtok 设置最大新 Token 数(当前:{maxtok}) +""") + + +def chat_loop(server: str, default_session_id: str): + sessions: Dict[str, Session] = {} + current_id = default_session_id + + def get_session(sid: str) -> Session: + if sid not in sessions: + sessions[sid] = Session(sid) + return sessions[sid] + + current = get_session(current_id) + + # 采样参数(运行时可调) + temperature = 0.8 + top_k = 50 + top_p = 0.9 + max_tokens = 512 + + print(f"LLAISYS Chat CLI — 服务器: {server}") + print("输入 /help 查看命令列表,Ctrl-C 或 /quit 退出。") + print("提示:Enter 发送,Shift+Enter 换行(需要 prompt_toolkit 支持)\n") + + while True: + # 提示符 + try: + prompt_str = 
f"\033[34mYou\033[0m [{current.title[:18]}]: " + user_input = _input(prompt_str).strip() + except (EOFError, KeyboardInterrupt): + print("\n再见!") + break + + if not user_input: + continue + + # ── 内置命令 ────────────────────────────────────────────────────────── + + if user_input == "/quit": + print("再见!") + break + + elif user_input == "/help": + print_help(temperature, top_k, top_p, max_tokens) + + elif user_input == "/new": + new_id = f"s_{uuid.uuid4().hex[:8]}" + current = get_session(new_id) + current_id = new_id + print(f"[新建对话 {new_id}]") + + elif user_input.startswith("/clone"): + parts_cmd = user_input.split() + n_keep = 2 + if len(parts_cmd) > 1: + try: + n_keep = int(parts_cmd[1]) + except ValueError: + pass + # 保留最近 n_keep 轮对话作为示例上下文 + kept: List[Dict] = [] + rounds = 0 + for m in reversed(current.messages): + if m["role"] == "user" and rounds >= n_keep: + break + kept.insert(0, dict(m)) + if m["role"] == "user": + rounds += 1 + new_id = f"s_{uuid.uuid4().hex[:8]}" + new_sess = get_session(new_id) + new_sess.messages = kept + new_sess.title = f"[续] {current.title[:20]}" + current = new_sess + current_id = new_id + print(f"[新建对话 {new_id},保留 {len(kept)} 条示例消息]") # noqa: E501 + + elif user_input == "/history": + if not current.messages: + print("[当前对话为空]") + else: + for i, m in enumerate(current.messages): + role_label = "You" if m["role"] == "user" else "AI " + preview = m["content"].replace("\n", " ")[:80] + print(f" [{i+1}] {role_label}: {preview}") + + elif user_input == "/sessions": + if not sessions: + print("[暂无会话]") + else: + for sid, s in sessions.items(): + marker = " ◀" if sid == current_id else "" + print(f" {sid} {s.title}{marker}") + + elif user_input.startswith("/switch "): + sid = user_input[8:].strip() + if sid in sessions: + current_id = sid + current = sessions[sid] + # 切换会话时通知服务器清空 KV-Cache + clear_server_cache(server, current_id) + print(f"[切换到会话: {current.title or sid}]") + else: + print(f"[未找到会话 {sid},可用: {list(sessions.keys())}]") + + 
elif user_input.startswith("/edit "): + try: + n = int(user_input[6:].strip()) + except ValueError: + print("[用法: /edit <消息序号,从1开始>]") + continue + + user_msgs = [(i, m) for i, m in enumerate(current.messages) if m["role"] == "user"] + if n < 1 or n > len(user_msgs): + print(f"[序号超范围,当前共 {len(user_msgs)} 条用户消息]") + continue + + orig_idx, orig_msg = user_msgs[n - 1] + print(f" 原内容: {orig_msg['content']}") + try: + new_content = input(" 新内容: ").strip() + except (EOFError, KeyboardInterrupt): + print() + continue + if not new_content: + continue + + # 修改消息并截断后续历史 + current.messages[orig_idx]["content"] = new_content + current.messages = current.messages[: orig_idx + 1] + if n == 1: + current.title = new_content[:30] + + # 重新生成 + reply = stream_chat(server, current, temperature, top_k, top_p, max_tokens) + if reply: + current.add_assistant(reply) + + elif user_input.startswith("/temp "): + try: + temperature = float(user_input[6:]) + print(f"[temperature = {temperature}]") + except ValueError: + print("[用法: /temp <0.0–2.0>]") + + elif user_input.startswith("/topk "): + try: + top_k = int(user_input[6:]) + print(f"[top_k = {top_k}]") + except ValueError: + print("[用法: /topk <整数>]") + + elif user_input.startswith("/topp "): + try: + top_p = float(user_input[6:]) + print(f"[top_p = {top_p}]") + except ValueError: + print("[用法: /topp <0.0–1.0>]") + + elif user_input.startswith("/maxtok "): + try: + max_tokens = int(user_input[8:]) + print(f"[max_tokens = {max_tokens}]") + except ValueError: + print("[用法: /maxtok <整数>]") + + elif user_input.startswith("/"): + print(f"[未知命令: {user_input},输入 /help 查看帮助]") + + # ── 正常聊天 ────────────────────────────────────────────────────────── + + else: + current.add_user(user_input) + reply = stream_chat(server, current, temperature, top_k, top_p, max_tokens) + if reply: + current.add_assistant(reply) + else: + # 发送失败,撤销用户消息 + current.messages.pop() + + +def main(): + parser = argparse.ArgumentParser(description="LLAISYS 交互式聊天 CLI") + 
parser.add_argument( + "--server", default="http://localhost:8000", help="服务器地址" + ) + parser.add_argument( + "--session", default=f"s_{uuid.uuid4().hex[:8]}", help="初始会话 ID" + ) + args = parser.parse_args() + chat_loop(args.server, args.session) + + +if __name__ == "__main__": + main() diff --git a/chat_server.py b/chat_server.py new file mode 100644 index 000000000..4cdcad069 --- /dev/null +++ b/chat_server.py @@ -0,0 +1,1501 @@ +""" +LLAISYS Chat Server — OpenAI-compatible /v1/chat/completions 接口 + +用法: + python chat_server.py --model /path/to/model [--device cpu|nvidia] [--port 8000] + +流式调用示例: + curl http://localhost:8000/v1/chat/completions \\ + -H "Content-Type: application/json" \\ + -d '{"messages":[{"role":"user","content":"你好"}],"stream":true}' +""" + +import argparse +import asyncio +import concurrent.futures +import hashlib +import json +import queue +import re +import threading +import time +import uuid +import sys +import os +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import AsyncIterator, Dict, List, Optional, Iterator, Set, Tuple + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse +from pydantic import BaseModel, Field +import uvicorn + +# ───────────────────────────────────────────────────────────────────────────── +# Pydantic 模型 (OpenAI schema 子集) +# ───────────────────────────────────────────────────────────────────────────── + +class Message(BaseModel): + role: str + content: str + + +class ChatCompletionRequest(BaseModel): + model: str = "llaisys" + messages: List[Message] + temperature: float = Field(default=0.5, ge=0.0, le=2.0) + top_p: float = Field(default=0.85, ge=0.0, le=1.0) + top_k: int = Field(default=30, ge=1) + max_tokens: int = Field(default=2048, ge=1) + stream: bool = False + # 扩展字段:会话 ID,用于 KV-Cache 前缀复用 + session_id: Optional[str] = "default" + # 思考预算: 块最多可生成的字符数(0 = 不限制) + thinking_budget: int = 
Field(default=800, ge=0) + + + +# ───────────────────────────────────────────────────────────────────────────── +# 用户会话:独立 KV-Cache + 前缀复用逻辑 +# ───────────────────────────────────────────────────────────────────────────── + +class UserSession: + """ + 轻量级会话标识,不持有 KV-Cache(由 KVCachePool 统一管理)。 + 职责:分词 + asyncio.Lock 确保同一 session_id 的请求串行提交给调度器。 + """ + + def __init__(self, session_id: str, tokenizer): + self.session_id = session_id + self.tokenizer = tokenizer + self.lock = asyncio.Lock() + + # 默认系统提示:鼓励简洁、避免无限罗列 + _DEFAULT_SYSTEM = ( + "你是一个有帮助的AI助手。请简洁清晰地回答问题," + "避免过度展开、无限罗列或重复相似内容。" + "如需思考,请控制思考长度,直接给出核心答案。" + ) + + def _tokenize(self, messages: List[Message]) -> List[int]: + msgs = [{"role": m.role, "content": m.content} for m in messages] + # 若没有 system 消息,自动插入默认提示词 + if not msgs or msgs[0]["role"] != "system": + msgs.insert(0, {"role": "system", "content": self._DEFAULT_SYSTEM}) + try: + prompt = self.tokenizer.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False, + enable_thinking=True + ) + except TypeError: + prompt = self.tokenizer.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False + ) + return self.tokenizer.encode(prompt) + + +# ───────────────────────────────────────────────────────────────────────────── +# KV-Cache Pool:块哈希前缀匹配 + LRU 淘汰 +# ───────────────────────────────────────────────────────────────────────────── + +BLOCK_SIZE: int = 16 # 每块 token 数(只缓存/匹配「完整 block」) +_ROOT_HASH: bytes = b'\x00' * 8 # 哈希链起始节点 + + +def _block_hash(parent_hash: bytes, block_tokens: List[int], extra: bytes = b"") -> bytes: + """ + 计算单个 KV block 的哈希 key: + key = H(parent_hash || len(block_tokens) || block_tokens || extra) + + 带入 parent_hash 保证不同前缀下相同 block token 不会误命中: + Prompt A: [X Y][P Q] vs Prompt B: [M N][P Q] + 第二块 token 相同,但 parent_hash 不同 → key 不同 → 不会错复用 + """ + h = hashlib.sha256() + h.update(parent_hash) + h.update(len(block_tokens).to_bytes(4, "little")) + for t in block_tokens: + h.update(t.to_bytes(8, "little")) + 
h.update(extra) + return h.digest()[:8] # 64-bit,碰撞概率 < 2^{-64} + + +@dataclass +class KVCacheEntry: + """ + 池中的一个物理 KV 状态条目,封装一个 Qwen2Session。 + + Fields: + entry_id : 池内唯一 ID + model_session : Qwen2Session(持有 GPU/CPU KV tensor,shape {maxseq, nkvh, dh}) + cached_tokens : model_session KV buffer 中已正确计算的完整 token 序列 + block_hashes : 对应 cached_tokens 的完整 block 哈希链 + block_hashes[i] = H(block_hashes[i-1], cached_tokens[i*BS:(i+1)*BS]) + ref_cnt : 借出引用计数(> 0 时不可淘汰,不可被他人借用) + last_access : 最近使用时刻(LRU 依据,单调时钟) + owner_sid : 最近持有该 entry 的 session_id(同 session 优先复用) + """ + entry_id: int + model_session: object + cached_tokens: List[int] = field(default_factory=list) + block_hashes: List[bytes] = field(default_factory=list) + ref_cnt: int = 0 + last_access: float = field(default_factory=time.monotonic) + owner_sid: str = "" + + +class KVCachePool: + """ + 跨会话 KV-Cache 前缀匹配池(仿 vLLM block cache 思路)。 + + 核心数据结构: + _entries : Dict[entry_id, KVCacheEntry] — 全部 entry + _cache_index : Dict[block_hash, entry_id] — 只索引完整 block 的末尾 hash + key = H(parent_hash, block_tokens),覆盖整条前缀链 + _free_lru : OrderedDict[entry_id, None] — 空闲 entry 的 LRU 链 + 末尾 = 最近使用,头部 = 最旧(淘汰头部) + + 前缀匹配流程(borrow): + 1. 把 prompt_tokens 按 BLOCK_SIZE 切成完整块 + 2. 从 _ROOT_HASH 开始逐块计算 block_hash,查询 _cache_index + 3. 找到最长连续命中前缀 → 对应 entry 的 KV data [0, matched_pos) 完全有效 + 4. 借出:ref_cnt += 1,从 _free_lru 移除 + 5. 调用方 set model_session.cache_pos = matched_pos,只对 suffix 做 prefill + + 写回流程(release): + 1. ref_cnt -= 1 + 2. 更新 cached_tokens = prompt_tokens + generated_tokens + 3. 为新完整 block 计算 hash,若 _cache_index 无此 hash 或旧指针已失效则写入 + (只用 token 数更多的 entry 覆盖,保证 index 始终指向最深缓存) + 4. 
entry 加入 _free_lru 末尾 + """ + + def __init__(self, model, max_entries: int = 32): + self._model = model + self._max_entries = max_entries + self._lock = threading.Lock() + self._entries: Dict[int, KVCacheEntry] = {} + self._cache_index: Dict[bytes, int] = {} # block_hash → entry_id + self._free_lru: OrderedDict = OrderedDict() # entry_id → None (LRU) + self._next_id: int = 0 + + # ── 公共接口 ────────────────────────────────────────────────────────────── + + def borrow( + self, + prompt_tokens: List[int], + owner_sid: str = "", + extra: bytes = b"", + ) -> Tuple[KVCacheEntry, int]: + """ + 查找并借出 prompt_tokens 的最长前缀命中 entry,返回 (entry, matched_pos)。 + + matched_pos : 命中的 token 数(对齐到 BLOCK_SIZE,0 = 无命中) + entry : 已借出(ref_cnt += 1);model_session.cache_pos 已设为 matched_pos + + 若无命中:创建新 entry(超出上限时 LRU 淘汰最旧空闲 entry)。 + """ + with self._lock: + best_entry, best_pos = self._find_best_prefix( + prompt_tokens, owner_sid, extra + ) + if best_entry is not None: + # 命中:rewind session 到前缀边界 + best_entry.model_session.cache_pos = best_pos + best_entry.ref_cnt += 1 + best_entry.last_access = time.monotonic() + best_entry.owner_sid = owner_sid + self._free_lru.pop(best_entry.entry_id, None) + return best_entry, best_pos + else: + # 未命中:分配新 entry(空 session,cache_pos = 0) + entry = self._alloc_entry() + entry.ref_cnt = 1 + entry.last_access = time.monotonic() + entry.owner_sid = owner_sid + return entry, 0 + + def release( + self, + entry: KVCacheEntry, + prompt_tokens: List[int], + generated_tokens: List[int], + extra: bytes = b"", + ) -> None: + """ + 归还借出的 entry,把完整 block 发布到 _cache_index,然后加入 _free_lru。 + cached_tokens 更新为 prompt_tokens + generated_tokens。 + 只有「完整 block」会被索引(末尾不足 BLOCK_SIZE 的 token 不缓存)。 + """ + with self._lock: + # 关键修复:Qwen2Session 的 cache_pos 只统计“作为输入喂给模型”的 token。 + # 生成流程里最后一个已输出 token 通常尚未作为下一步输入写入 KV, + # 因此不能盲目把 prompt + generated 全部当成可复用 KV。 + full_tokens = list(prompt_tokens) + list(generated_tokens) + real_cached = int(getattr(entry.model_session, 
"cache_pos", len(full_tokens))) + if real_cached < 0: + real_cached = 0 + if real_cached > len(full_tokens): + real_cached = len(full_tokens) + entry.cached_tokens = full_tokens[:real_cached] + self._publish_blocks(entry, extra) + entry.ref_cnt = max(0, entry.ref_cnt - 1) + entry.last_access = time.monotonic() + if entry.ref_cnt == 0: + self._free_lru[entry.entry_id] = None + self._free_lru.move_to_end(entry.entry_id) + + def stats(self) -> dict: + with self._lock: + return { + "total_entries": len(self._entries), + "free_entries": len(self._free_lru), + "indexed_blocks": len(self._cache_index), + } + + def _compute_block_hashes( + self, tokens: List[int], extra: bytes + ) -> List[bytes]: + """返回 tokens 所有完整 block 的哈希链,长度 = len(tokens) // BLOCK_SIZE。""" + hashes: List[bytes] = [] + parent = _ROOT_HASH + n_full = len(tokens) // BLOCK_SIZE + for i in range(n_full): + bt = tokens[i * BLOCK_SIZE : (i + 1) * BLOCK_SIZE] + bh = _block_hash(parent, bt, extra) + hashes.append(bh) + parent = bh + return hashes + + def _find_best_prefix( + self, + prompt_tokens: List[int], + owner_sid: str, + extra: bytes, + ) -> Tuple[Optional[KVCacheEntry], int]: + """ + 查找最长前缀命中(持锁内调用)。 + + 逐块计算 block_hash,在 _cache_index 中查找: + - 命中且 entry 空闲(ref_cnt == 0)→ 更新 best_entry / best_pos + - 未命中 / entry 被占用 → 停止(后续 hash 也不会命中) + 同等匹配长度时,优先 owner_sid 相同的 entry(多轮对话偏好)。 + """ + n_full = len(prompt_tokens) // BLOCK_SIZE + if n_full == 0: + return None, 0 + + parent = _ROOT_HASH + best_entry: Optional[KVCacheEntry] = None + best_pos: int = 0 + + for i in range(n_full): + bt = prompt_tokens[i * BLOCK_SIZE : (i + 1) * BLOCK_SIZE] + bh = _block_hash(parent, bt, extra) + + eid = self._cache_index.get(bh) + if eid is None: + break # 链断裂,后续不会命中 + + entry = self._entries.get(eid) + if entry is None or entry.ref_cnt > 0: + break # 无效或被占用 + + pos = (i + 1) * BLOCK_SIZE + # 防御性校验(防止极低概率哈希碰撞) + if (len(entry.cached_tokens) < pos or + entry.cached_tokens[i * BLOCK_SIZE : pos] != bt): + break + # 关键修复:必须确保底层 
Qwen2Session 的真实 KV 至少覆盖到 pos。 + # 否则会把“逻辑上记录了 token,但 KV 实际未写入”的状态误判为命中, + # 导致后续解码异常(如提前 EOS、输出中断)。 + real_cache_pos = int(getattr(entry.model_session, "cache_pos", 0)) + if real_cache_pos < pos: + break + + if (pos > best_pos or + (pos == best_pos and entry.owner_sid == owner_sid + and (best_entry is None or best_entry.owner_sid != owner_sid))): + best_entry = entry + best_pos = pos + + parent = bh + + return best_entry, best_pos + + def _publish_blocks(self, entry: KVCacheEntry, extra: bytes) -> None: + """ + 将 entry.cached_tokens 的完整 block 写入 _cache_index(持锁内调用)。 + 若某 block_hash 已被其他 entry 占用,仅当新 entry 的 cached_tokens + 更长(更深缓存)时才覆盖,保证 index 始终指向最长可复用的 entry。 + """ + hashes = self._compute_block_hashes(entry.cached_tokens, extra) + entry.block_hashes = hashes + for bh in hashes: + existing_eid = self._cache_index.get(bh) + if existing_eid is None or existing_eid not in self._entries: + self._cache_index[bh] = entry.entry_id + else: + existing = self._entries[existing_eid] + if len(entry.cached_tokens) > len(existing.cached_tokens): + self._cache_index[bh] = entry.entry_id # 更深缓存优先 + + def _alloc_entry(self) -> KVCacheEntry: + """分配新 entry。超出上限时,LRU 淘汰最旧空闲 entry(持锁内调用)。""" + if len(self._entries) >= self._max_entries and self._free_lru: + self._evict_lru() + eid = self._next_id + self._next_id += 1 + model_session = self._model.create_session() + entry = KVCacheEntry(entry_id=eid, model_session=model_session) + self._entries[eid] = entry + return entry + + def _evict_lru(self) -> None: + """淘汰 _free_lru 头部(最旧)的空闲 entry(持锁内调用)。""" + if not self._free_lru: + return + eid, _ = self._free_lru.popitem(last=False) # 弹出最旧(头部) + entry = self._entries.pop(eid, None) + if entry is None: + return + # 从 _cache_index 中撤销该 entry 的全部 block hash + for bh in entry.block_hashes: + if self._cache_index.get(bh) == eid: + del self._cache_index[bh] + # Qwen2Session 由 __del__ 自动调用 llaisysQwen2SessionDestroy 释放 + + +# 
───────────────────────────────────────────────────────────────────────────── +# 连续批处理调度器:请求队列 + 独立循环线程 +# ───────────────────────────────────────────────────────────────────────────── + +# 终止生成:session_id → bool,由 /v1/sessions/{sid}/abort 端点设置,_process_token 消费后清除 +_session_abort: Dict[str, bool] = {} +_session_abort_mu = threading.Lock() + +# 匹配 special token 变体:<|xxx|> / <|xxx|> / < | xxx | >(允许空格) +# 使用 [^\n||<>] 代替 [\w._-],以匹配 ▁(U+2581,SentencePiece 词边界符)等非 ASCII 字符 +_SPECIAL_FILTER = re.compile( + r"<\s*[||]\s*[^\n||<>]{2,}\s*[||]\s*>", + re.UNICODE, +) + +# 文本级 EOS 标记检测:模型有时把 EOS 当文本输出而非 token ID +# ▁ = U+2581(LOWER ONE EIGHTH BLOCK),Qwen EOS token 中使用,不是普通下划线 +_EOS_TEXT_RE = re.compile( + r"<\s*[||]\s*" + r"(?:end[▁_\-\s]*of[▁_\-\s]*(?:sentence|text|turn)" + r"|endoftext" + r"|im[▁_\-\s]*end" + r"|eot[▁_\-\s]*id" + r")\s*[||]\s*>", + re.IGNORECASE, +) + +# 流式 hold-back:检测文本尾部可能是特殊 token 的不完整前缀 +# [^\n>]* 可匹配含 ▁ 的任意字符,避免 hold-back 失效 +_PARTIAL_SPECIAL_TAIL = re.compile(r"<(?:\s{0,3}[||][^\n>]*)?$") + + +def _sanitize_generated_text(text: str) -> str: + """ + 服务端统一文本清洗: + 1) 去掉 <|...|> / <|...|> / < | ... 
| > 特殊 token + 2) 去掉 U+FFFD(常见乱码替换符) + 3) 去掉不可见控制字符(保留换行/制表) + """ + s = _SPECIAL_FILTER.sub("", text) + s = s.replace("\ufffd", "") + s = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", s) + return s + + +@dataclass +class PendingRequest: + """HTTP 层提交给调度器的等待请求。""" + req_id: str + request: "ChatCompletionRequest" + session_id: str + prompt_tokens: List[int] + result_queue: asyncio.Queue # ("delta", str) | ("done", None) | ("error", str) + loop: asyncio.AbstractEventLoop + tokenizer: object + + +class ActiveRequest: + """ + 调度器正在迭代解码的单个请求。 + + 持有借出的 KVCacheEntry,每次 decode_step() 推进一个 token, + 生成结束后通过调度器归还 entry 到 KVCachePool。 + """ + + def __init__( + self, + pending: PendingRequest, + entry: KVCacheEntry, + prompt_tokens: List[int], + first_token: int, + end_token: int, + ): + self.pending = pending + self.entry = entry + self.prompt_tokens = prompt_tokens + self.end_token = end_token + self.generated: List[int] = [] + self.accumulated_ids: List[int] = [] + self.text_so_far: str = "" + self._text_emitted: str = "" # 已发送给客户端的文本(思考截断后差量基线) + self.done: bool = False + self._step: int = 0 + # 思考预算控制 + self._thinking_budget: int = getattr(pending.request, 'thinking_budget', 800) + self._think_suppressed: bool = False # 是否正在抑制过长的思考内容 + self._process_token(first_token) + + # ── 文本级多尺度重复检测 ────────────────────────────────────────────────── + + def _check_repetition(self) -> bool: + """ + 文本级多尺度重复检测。 + + 第一层:短周期连续重复检测(15-200 字符的模式连续出现 ≥3 次)。 + 适用于模型反复输出"列表的索引和切片"等短循环。 + 第二层:大块子串检测(120-500 字符在前文中出现过)。 + 适用于模型在更长尺度上的回环。 + """ + text = self.text_so_far + n = len(text) + if n < 60: + return False + + # ── 第一层:短周期连续重复(3 连击) ────────────────────────────── + tail = text[-600:] if n > 600 else text + ct = len(tail) + for plen in range(15, min(ct // 3 + 1, 201)): + if ct >= 3 * plen: + pat = tail[-plen:] + if (tail[-2 * plen : -plen] == pat + and tail[-3 * plen : -2 * plen] == pat): + return True + + # ── 第二层:大块子串检测 ────────────────────────────────────────── + if n < 240: + return 
False + upper = min(n // 2, 500) + 1 + for sz in range(120, upper, 40): + if text[-sz:] in text[:-sz]: + return True + return False + + # ── token 处理 ──────────────────────────────────────────────────────────── + + def _process_token(self, tok: int) -> None: + req = self.pending.request + self.generated.append(tok) + self.accumulated_ids.append(tok) + self._step += 1 + + # ── 中止检测(非阻塞,每 token 检查一次)──────────────────────────── + if _session_abort.get(self.pending.session_id, False): + with _session_abort_mu: + _session_abort.pop(self.pending.session_id, None) + self.done = True + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("done", None)), + self.pending.loop, + ) + return + + # 全量 decode(含特殊 token 文本),用于 EOS 文本检测 + raw_text = self.pending.tokenizer.decode( + self.accumulated_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + # ── 文本级 EOS 检测(在 sanitize 之前) ────────────────────────────── + eos_match = _EOS_TEXT_RE.search(raw_text) + if eos_match: + final_text = _sanitize_generated_text(raw_text[:eos_match.start()]) + delta = final_text[len(self._text_emitted):] + self._text_emitted = final_text + self.text_so_far = final_text + if delta and not self._think_suppressed: + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("delta", delta)), + self.pending.loop, + ) + self.done = True + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("done", None)), + self.pending.loop, + ) + return + + new_text = _sanitize_generated_text(raw_text) + self.text_so_far = new_text + + # ── 重复循环检测(在发送 delta 之前)──────────────────────────────── + if self._check_repetition(): + self.done = True + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("done", None)), + self.pending.loop, + ) + return + + # ── 思考预算控制 ────────────────────────────────────────────────────── + if not self._think_suppressed: + t_open = new_text.find('') + if (self._thinking_budget > 0 and t_open >= 0 + and 
new_text.find('', t_open) < 0): + # 正处于未关闭的 think 块内 + think_len = len(new_text) - t_open - 7 + if think_len > self._thinking_budget: + # 思考超出预算:注入 截断,进入抑制模式 + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("delta", "\n\n\n")), + self.pending.loop, + ) + self._think_suppressed = True + self._text_emitted = new_text + if tok == self.end_token or self._step >= req.max_tokens: + self.done = True + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("done", None)), + self.pending.loop, + ) + return + + # 正常路径:计算并发送差量(含 hold-back 机制) + is_finishing = (tok == self.end_token or self._step >= req.max_tokens) + # 如果不是最后一步,检查尾部是否有未闭合的 <|... 部分特殊 token + if is_finishing: + safe_end = len(new_text) + else: + m = _PARTIAL_SPECIAL_TAIL.search(new_text) + safe_end = m.start() if m else len(new_text) + delta = new_text[len(self._text_emitted):safe_end] + self._text_emitted = new_text[:safe_end] + if delta: + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("delta", delta)), + self.pending.loop, + ) + else: + # 思考抑制中:等待模型自然关闭 + t_open = new_text.find('') + if t_open >= 0 and new_text.find('', t_open) >= 0: + # 模型已关闭思考块:恢复发送 + t_close = new_text.find('', t_open) + self._think_suppressed = False + after_close = new_text[t_close + 8:] # 8 = len('') + self._text_emitted = new_text + if after_close: + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("delta", after_close)), + self.pending.loop, + ) + else: + # 仍在等待,推进基线指针但不发送 + self._text_emitted = new_text + + if tok == self.end_token or self._step >= req.max_tokens: + self.done = True + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("done", None)), + self.pending.loop, + ) + + def decode_step(self) -> None: + """推进一步解码(线程安全:各 entry 的 model_session 独立,可并发)。""" + if self.done: + return + req = self.pending.request + try: + next_tok = self.entry.model_session._infer_sample( + [self.generated[-1]], + req.temperature, req.top_k, req.top_p, + ) + 
self._process_token(next_tok) + except Exception as exc: + asyncio.run_coroutine_threadsafe( + self.pending.result_queue.put(("error", str(exc))), + self.pending.loop, + ) + self.done = True + + +class ContinuousBatchScheduler: + """ + 请求池 + 单一循环线程实现的迭代级连续批处理调度器。 + + 数据流: + HTTP → pending_queue (queue.Queue, 线程安全) + ↓ + 调度主循环(独立 daemon 线程,永续运行) + ┌────────────────────────────────────────────────────┐ + │ Phase 1 – PREFILL(串行,避免多路同时争 GPU) │ + │ 新请求 → KVCachePool.borrow(prompt) │ + │ → 找最长前缀命中 → 只对 suffix 做 prefill │ + │ → 首 token 采样 → 构造 ActiveRequest │ + │ │ + │ Phase 2 – DECODE(线程池并发) │ + │ 每个 ActiveRequest.decode_step() 在独立线程执行 │ + │ 各 entry 持有独立 KV-Cache,互不干扰 │ + │ │ + │ Phase 3 – CLEANUP │ + │ 完成请求 → KVCachePool.release() │ + │ → 更新 cached_tokens → 发布新 block hash │ + │ → entry 加入 free_lru,可被后续请求复用 │ + └────────────────────────────────────────────────────┘ + """ + + def __init__( + self, + kv_pool: KVCachePool, + max_batch_size: int = 8, + max_workers: int = 16, + ): + self._pool = kv_pool + self._max_batch = max_batch_size + self._pending: queue.Queue = queue.Queue() + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, + thread_name_prefix="llaisys-decode", + ) + self._thread = threading.Thread( + target=self._loop, daemon=True, name="llaisys-scheduler" + ) + self._thread.start() + + def submit(self, req: PendingRequest) -> None: + """将请求加入等待队列(线程安全,可在任意线程调用)。""" + self._pending.put(req) + + # ── 内部 ────────────────────────────────────────────────────────────────── + + def _prefill(self, pending: PendingRequest) -> Optional[ActiveRequest]: + """ + 在调度线程内串行执行 prefill: + 1. KVCachePool.borrow(prompt) → (entry, matched_pos) + 2. 只对 prompt[matched_pos:] 做 prefill(KV 前缀复用) + 3. 
首 token 采样 → 构造 ActiveRequest + """ + req = pending.request + prompt = pending.prompt_tokens + # 新请求开始前,清除该 session 的任何残留 abort 标志 + with _session_abort_mu: + _session_abort.pop(pending.session_id, None) + try: + entry, matched_pos = self._pool.borrow( + prompt, owner_sid=pending.session_id + ) + # 只 feed prefix 之后的 suffix token + suffix = prompt[matched_pos:] + if not suffix: + # 完整命中:回退一步,重新 feed 最后一个 token 以获取首生成 token + entry.model_session.cache_pos = max(0, matched_pos - 1) + suffix = [prompt[-1]] if prompt else [] + if not suffix: + self._pool.release(entry, prompt, []) + asyncio.run_coroutine_threadsafe( + pending.result_queue.put(("done", None)), pending.loop + ) + return None + first_tok = entry.model_session._infer_sample( + suffix, req.temperature, req.top_k, req.top_p + ) + end_tok = entry.model_session._meta.end_token + ar = ActiveRequest(pending, entry, prompt, first_tok, end_tok) + if ar.done: + self._release(ar) + return None + return ar + except Exception as exc: + asyncio.run_coroutine_threadsafe( + pending.result_queue.put(("error", str(exc))), pending.loop + ) + return None + + def _release(self, ar: ActiveRequest) -> None: + """请求完成,归还 entry 到 KVCachePool(更新 cached_tokens + 发布新 block)。""" + self._pool.release(ar.entry, ar.prompt_tokens, ar.generated) + + def _loop(self) -> None: + """ + 调度主循环(独立 daemon 线程,永续运行)。 + + active_sids : 当前活跃 session ID 集合,保证同一 session 同时至多 + 一个活跃请求(防止同一 KV 序列被并发写入)。 + requeue_buf : 因 session 冲突暂缓的请求,当前轮结束后放回队列。 + """ + active: List[ActiveRequest] = [] + active_sids: Set[str] = set() + requeue_buf: List[PendingRequest] = [] + + while True: + # ── Phase 1: 接受新请求(prefill 串行)──────────────────────────── + slots = self._max_batch - len(active) + while slots > 0: + try: + pending = self._pending.get_nowait() + sid = pending.session_id + if sid in active_sids: + requeue_buf.append(pending) # 同 session 已有活跃请求,暂缓 + else: + active_sids.add(sid) + ar = self._prefill(pending) + if ar is not None: + active.append(ar) + else: + 
active_sids.discard(sid) + slots -= 1 + except queue.Empty: + break + + # 暂缓请求放回队列(下一轮调度) + for p in requeue_buf: + self._pending.put(p) + requeue_buf.clear() + + # ── 无活跃请求时阻塞等待新请求 ──────────────────────────────────── + if not active: + try: + pending = self._pending.get(timeout=0.005) + sid = pending.session_id + active_sids.add(sid) + ar = self._prefill(pending) + if ar is not None: + active.append(ar) + else: + active_sids.discard(sid) + except queue.Empty: + continue + + # ── Phase 2: 并发执行 decode step ───────────────────────────────── + if len(active) == 1: + active[0].decode_step() # 单请求直接在调度线程执行,省线程切换 + else: + # 多请求:提交线程池并行执行(各 entry 持有独立 KV-Cache,互不干扰) + futs = { + self._executor.submit(ar.decode_step): ar + for ar in active + } + for f in concurrent.futures.as_completed(futs): + try: + f.result() + except Exception: + pass # 错误已在 decode_step 内部回传给 result_queue + + # ── Phase 3: 清理已完成请求,归还 KVCachePool ───────────────────── + still: List[ActiveRequest] = [] + for ar in active: + if ar.done: + self._release(ar) + active_sids.discard(ar.pending.session_id) + else: + still.append(ar) + active = still + + +# ───────────────────────────────────────────────────────────────────────────── +# ModelServer:会话管理 + KVCachePool + 调度器 +# ───────────────────────────────────────────────────────────────────────────── + +class ModelServer: + """ + 多用户推理服务核心。 + + - KVCachePool : 跨 session 的 KV-Cache 前缀匹配池(独立物理 session 对象池) + - ContinuousBatchScheduler : 请求队列 + 循环线程 + 迭代级批处理 + - _sessions : session_id → UserSession(轻量标识 + tokenizer) + """ + + def __init__(self, model, tokenizer, pool_size: int = 32, max_batch: int = 8): + self.model = model + self.tokenizer = tokenizer + self._sessions: Dict[str, UserSession] = {} + self._sessions_mu = threading.Lock() + self.kv_pool = KVCachePool(model, max_entries=pool_size) + self.scheduler = ContinuousBatchScheduler( + self.kv_pool, + max_batch_size=max_batch, + max_workers=max(16, max_batch * 2), + ) + + def get_or_create_session(self, 
session_id: str) -> UserSession: + with self._sessions_mu: + if session_id not in self._sessions: + self._sessions[session_id] = UserSession(session_id, self.tokenizer) + return self._sessions[session_id] + + def delete_session(self, session_id: str): + with self._sessions_mu: + self._sessions.pop(session_id, None) + + def clear_session(self, session_id: str): + """将 session_id 的缓存条目在 KVCachePool 中标记为失效(清零其 cached_tokens)。""" + with self.kv_pool._lock: + for entry in list(self.kv_pool._entries.values()): + if entry.owner_sid == session_id and entry.ref_cnt == 0: + # 撤销该 entry 的 block hash + for bh in entry.block_hashes: + if self.kv_pool._cache_index.get(bh) == entry.entry_id: + del self.kv_pool._cache_index[bh] + entry.cached_tokens = [] + entry.block_hashes = [] + entry.model_session.reset_cache() + + def session_count(self) -> int: + with self._sessions_mu: + return len(self._sessions) + + +# ───────────────────────────────────────────────────────────────────────────── +# Web UI(内嵌 HTML) +# ───────────────────────────────────────────────────────────────────────────── + +WEB_UI_HTML = """ + + + + +LLAISYS Chat + + + + + + +
+ + +
+ + + + + +
+ +
+ +
+ + + +
+
+ + + + +""" + + +# ───────────────────────────────────────────────────────────────────────────── +# FastAPI 应用 +# ───────────────────────────────────────────────────────────────────────────── + +app = FastAPI(title="LLAISYS Chat Server", version="0.2.0") +_server: Optional[ModelServer] = None + + +def _make_sse_chunk(content: str, chat_id: str, model: str) -> str: + data = { + "id": chat_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}], + } + return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _make_sse_done(chat_id: str, model: str) -> str: + data = { + "id": chat_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + return f"data: {json.dumps(data)}\n\ndata: [DONE]\n\n" + + +@app.get("/", response_class=HTMLResponse) +def index(): + return WEB_UI_HTML + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + """ + OpenAI-compatible /v1/chat/completions。 + + 所有请求统一提交到 ContinuousBatchScheduler: + - 分词在 HTTP handler 中完成(纯 CPU,不阻塞事件循环) + - PendingRequest 入队 → 调度器 prefill(KVCachePool 前缀复用) + - 调度器 decode 循环推进 → token 通过 result_queue 回传事件循环 + - 流式:SSE 实时推送;非流式:等全部 token 后一次性返回 + """ + if _server is None: + raise HTTPException(status_code=503, detail="Model not loaded") + + chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" + sid = request.session_id or "default" + user_session = _server.get_or_create_session(sid) + + loop = asyncio.get_event_loop() + result_queue: asyncio.Queue = asyncio.Queue(maxsize=512) + prompt_tokens = user_session._tokenize(request.messages) + + pending = PendingRequest( + req_id=chat_id, + request=request, + session_id=sid, + prompt_tokens=prompt_tokens, + result_queue=result_queue, + loop=loop, + tokenizer=user_session.tokenizer, + ) + 
_server.scheduler.submit(pending) + + if request.stream: + async def sse_gen(): + while True: + kind, data = await result_queue.get() + if kind == "done": + yield _make_sse_done(chat_id, request.model) + break + elif kind == "error": + yield _make_sse_chunk(f"[Error: {data}]", chat_id, request.model) + yield _make_sse_done(chat_id, request.model) + break + elif kind == "delta" and data: + yield _make_sse_chunk(data, chat_id, request.model) + + return StreamingResponse( + sse_gen(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + }, + ) + else: + parts: List[str] = [] + while True: + kind, data = await result_queue.get() + if kind == "done": + break + elif kind == "error": + raise HTTPException(status_code=500, detail=data) + elif kind == "delta" and data: + parts.append(data) + reply = "".join(parts) + return JSONResponse({ + "id": chat_id, + "object": "chat.completion", + "created": int(time.time()), + "model": request.model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": reply}, + "finish_reason": "stop", + }], + "usage": {"prompt_tokens": len(prompt_tokens), + "completion_tokens": -1, "total_tokens": -1}, + }) + + +@app.get("/v1/models") +def list_models(): + return {"object": "list", "data": [ + {"id": "llaisys", "object": "model", + "active_sessions": _server.session_count() if _server else 0} + ]} + + +@app.post("/v1/sessions/{session_id}/clear") +def clear_session(session_id: str): + """清除指定会话的 KV-Cache(会话切换时调用)。""" + if _server is not None: + _server.clear_session(session_id) + return {"status": "ok", "session_id": session_id} + + +@app.post("/v1/sessions/{session_id}/abort") +def abort_session(session_id: str): + """中止指定会话的当前生成(最多延迟 1-2 个 token 后停止)。""" + with _session_abort_mu: + _session_abort[session_id] = True + return {"status": "ok", "session_id": session_id} + + +@app.delete("/v1/sessions/{session_id}") +def 
delete_session(session_id: str): + """彻底删除会话并释放其 KV-Cache 内存。""" + if _server is not None: + _server.delete_session(session_id) + return {"status": "deleted", "session_id": session_id} + + +@app.get("/v1/sessions") +def list_sessions(): + """查看当前活跃会话数量及 KVCachePool 统计。""" + count = _server.session_count() if _server else 0 + pool_stats = _server.kv_pool.stats() if _server else {} + return {"active_sessions": count, "kv_pool": pool_stats} + + +# ───────────────────────────────────────────────────────────────────────────── +# 启动入口 +# ───────────────────────────────────────────────────────────────────────────── + +def main(): + global _server + + parser = argparse.ArgumentParser( + description="LLAISYS Chat Server (OpenAI-compatible, continuous batching + KV prefix cache)" + ) + parser.add_argument("--model", required=True, help="模型目录路径") + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"]) + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--pool-size", type=int, default=32, + help="KVCachePool 最大 entry 数(每个 entry 占一个 Qwen2Session 的 KV 显存)") + parser.add_argument("--max-batch", type=int, default=8, + help="调度器每轮最大并发请求数") + args = parser.parse_args() + + import llaisys + from llaisys import DeviceType + from transformers import AutoTokenizer + + device = DeviceType.NVIDIA if args.device == "nvidia" else DeviceType.CPU + + print(f"[server] 加载 tokenizer: {args.model}") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + + print(f"[server] 加载模型 (device={args.device}): {args.model}") + model = llaisys.models.Qwen2(args.model, device) + print("[server] 模型加载完成。") + + _server = ModelServer(model, tokenizer, + pool_size=args.pool_size, + max_batch=args.max_batch) + + print(f"[server] KVCachePool 已启动: max_entries={args.pool_size}, block_size={BLOCK_SIZE} tokens") + print(f"[server] 调度器已启动: max_batch={args.max_batch},前缀命中时只 prefill suffix") + 
print(f"[server] 监听 http://{args.host}:{args.port}") + print(f"[server] Web UI: http://localhost:{args.port}/") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() diff --git a/chat_ui.py b/chat_ui.py new file mode 100644 index 000000000..5728517bc --- /dev/null +++ b/chat_ui.py @@ -0,0 +1,901 @@ +#!/usr/bin/env python3 +""" +LLAISYS Chat UI — Gradio 6.x 豆包风格(全屏 + 历史侧边栏) + +用法: + python chat_ui.py --server http://localhost:8000 [--port 7860] +""" + +import html as _html +import json +import os +import re +import uuid +import argparse + +import gradio as gr +import requests + +# ───────────────────────────────────────────────────────────────────────────── +# 本地持久化路径(JSON 文件存放对话历史) +# ───────────────────────────────────────────────────────────────────────────── +_PERSIST_DIR = os.path.join(os.path.expanduser("~"), ".llaisys") +_PERSIST_FILE = os.path.join(_PERSIST_DIR, "chat_history.json") + + +def _save_conversations(conversations: list, session_id: str) -> None: + """将 conversations 和 active session_id 写入磁盘 JSON。""" + try: + os.makedirs(_PERSIST_DIR, exist_ok=True) + data = {"session_id": session_id, "conversations": conversations} + with open(_PERSIST_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=1) + except Exception: + pass + + +def _load_conversations(): + """从磁盘加载 conversations 和 session_id。返回 (conversations, session_id)。""" + try: + if os.path.exists(_PERSIST_FILE): + with open(_PERSIST_FILE, "r", encoding="utf-8") as f: + data = json.load(f) + convs = data.get("conversations", []) + sid = data.get("session_id", str(uuid.uuid4())) + if convs: + return convs, sid + except Exception: + pass + return [], str(uuid.uuid4()) + +# ───────────────────────────────────────────────────────────────────────────── +# 工具函数 +# ───────────────────────────────────────────────────────────────────────────── + +# 匹配 special token 变体:<|xxx|> / <|xxx|> / < | xxx | >(允许空格) +# 使用 [^\n||<>] 
代替 [\w._-],以匹配 ▁(U+2581,SentencePiece 词边界符)等非 ASCII 字符 +_SPECIAL_RE = re.compile( + r"<\s*[||]\s*[^\n||<>]{2,}\s*[||]\s*>", + re.UNICODE, +) + +# 显式匹配常见 EOS 变体(宽松空格、中英文竖线、▁ 词边界符),作为兜底清洗 +# ▁ = U+2581(LOWER ONE EIGHTH BLOCK),Qwen EOS token 中使用,不是普通下划线 +_EOS_RE = re.compile( + r"<\s*[||]\s*" + r"(?:end[▁_\-\s]*of[▁_\-\s]*(?:sentence|text|turn)" + r"|endoftext" + r"|im[▁_\-\s]*end" + r"|eot[▁_\-\s]*id" + r")\s*[||]\s*>", + re.IGNORECASE, +) + +# 尾部不完整的特殊 token 前缀(流式残留) +# [^\n>]* 可匹配含 ▁ 的任意字符 +_PARTIAL_TAIL_RE = re.compile(r"<(?:\s{0,3}[||][^\n>]*)?$") + +def _to_text(content) -> str: + """ + 兼容 Gradio Chatbot 的多种 content 形态: + - str + - list[dict|str|...] + - dict(如多模态消息片段) + 统一抽取为字符串,避免 re.sub 收到 list/dict 报错。 + """ + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + # 常见字段:text / content / value + txt = item.get("text") or item.get("content") or item.get("value") + if isinstance(txt, str): + parts.append(txt) + return "\n".join(p for p in parts if p) + if isinstance(content, dict): + txt = content.get("text") or content.get("content") or content.get("value") + return txt if isinstance(txt, str) else "" + return "" + +def _clean(text) -> str: + s = _to_text(text) + # 1) 显式移除 EOS 标记(最优先,避免被其他正则截断后残留) + s = _EOS_RE.sub("", s) + # 2) 去除特殊 token 形态 + s = _SPECIAL_RE.sub("", s) + # 3) 去除尾部不完整的特殊 token 前缀(流式残留如 "< | end_of") + s = _PARTIAL_TAIL_RE.sub("", s) + # 4) 常见乱码替换符 U+FFFD + s = s.replace("\ufffd", "") + # 5) 过滤大多数不可见控制字符(保留换行/制表) + s = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", s) + return s.strip() + +def _trim_repetition(text: str, min_chunk: int = 80) -> str: + """ + 截掉文本末尾与前文重复的大段内容。 + + 如果末尾 sz 个字符与紧邻前方 sz 个字符完全一致,说明模型陷入了 + 重复循环,移除末尾的重复副本。作为服务端检测的第二道防线。 + """ + n = len(text) + if n < min_chunk * 2: + return text + for sz in range(min_chunk, min(n // 2 + 1, 501), 30): + if text[-sz:] == text[-2 * sz : -sz]: + 
return text[:-sz].rstrip()
+    return text
+
+def _normalize_think(text: str) -> str:
+    """
+    确保 <think>...</think> 标签完整,让 Gradio reasoning_tags 正确渲染。
+    - 有 </think> 但没有 <think>:补头
+    - 有 <think> 但没有 </think>(流式中途):补尾
+    - 连续多个 think 块不影响
+    """
+    has_open = "<think>" in text
+    has_close = "</think>" in text
+    if has_close and not has_open:
+        text = "<think>" + text
+    elif has_open and not has_close:
+        text = text + "\n</think>"
+    return text
+
+def _strip_think(text: str) -> str:
+    t = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
+    t = re.sub(r"<think>.*$", "", t, flags=re.DOTALL)
+    return _clean(t)
+
+def _build_api_messages(history: list) -> list:
+    msgs = []
+    for item in history:
+        role = item["role"]
+        content = item["content"] if isinstance(item["content"], str) else ""
+        if role == "assistant":
+            content = _strip_think(content)
+        if content:
+            msgs.append({"role": role, "content": content})
+    return msgs
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 历史对话管理
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _conv_title(history: list) -> str:
+    for m in history:
+        if m.get("role") == "user":
+            text = re.sub(r"<[^>]+>", "", _clean(m.get("content", ""))).strip()
+            return (text[:22] + "…") if len(text) > 22 else text
+    return "新对话"
+
+def _get_title_by_id(conversations: list, sid: str) -> str:
+    for c in conversations:
+        if c["id"] == sid:
+            return c.get("title") or "新对话"
+    return "新对话"
+
+def _update_conversations(conversations: list, sid: str, history: list) -> list:
+    convs = [dict(c) for c in conversations]
+    for c in convs:
+        if c["id"] == sid:
+            c["messages"] = history
+            if not c.get("title"):
+                c["title"] = _conv_title(history)
+            return convs
+    convs.append({"id": sid, "title": _conv_title(history), "messages": history})
+    return convs
+
+def render_sidebar(conversations: list, current_id: str) -> str:
+    items = ""
+    for conv in reversed(conversations):
+        cls = "hi active" if conv["id"] == current_id else "hi"
+        title = _html.escape((conv.get("title") or "新对话")[:26])
+ cid = conv["id"] + # 使用 data 属性存储 session id + items += ( + f'
' + f'' + f'{title}' + f'
\n' + ) + if not items: + items = '
暂无历史对话
' + return f'
{items}
' + + +# ───────────────────────────────────────────────────────────────────────────── +# 历史对话点击处理 JS +# ───────────────────────────────────────────────────────────────────────────── + +_HEAD_JS = """ +(function() { + console.log('[LLA] JS loaded'); + + // 历史对话点击处理函数 + window.__llaHandleClick = function(event, sid) { + event.preventDefault(); + event.stopPropagation(); + console.log('[LLA] History click:', sid); + var input = document.querySelector('#hcb input'); + if (!input) input = document.querySelector('#hcb textarea'); + if (input) { + input.value = sid; + input.dispatchEvent(new Event('input', {bubbles: true})); + input.dispatchEvent(new Event('change', {bubbles: true})); + } + }; + + // 只处理 Shift+Enter 换行,Enter 交给 Gradio 原生处理 + function initKeyHandler() { + var ta = document.querySelector('#msg-input textarea'); + if (!ta) { + setTimeout(initKeyHandler, 100); + return; + } + console.log('[LLA] Found textarea:', ta); + + ta.addEventListener('keydown', function(e) { + console.log('[LLA] Keydown:', e.key, 'Shift:', e.shiftKey); + if (e.key === 'Enter' && e.shiftKey) { + // Shift+Enter:允许默认换行行为,不做任何处理 + console.log('[LLA] Shift+Enter - allow newline'); + } + // Enter (without Shift): 不处理,完全交给 Gradio 原生的 msg_box.submit() + }); + } + + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', initKeyHandler); + } else { + initKeyHandler(); + } +})(); +""" + +CSS = """ +/* ══ RESET & FULL PAGE ═══════════════════════════════════════════════════════ */ +*, *::before, *::after { box-sizing: border-box; } +html, body { + height: 100% !important; margin: 0 !important; padding: 0 !important; + overflow: hidden !important; + background: #f7f8fa !important; + font-family: -apple-system,'PingFang SC','Microsoft YaHei','Segoe UI',sans-serif !important; +} +.gradio-container { + height: 100vh !important; max-width: 100% !important; width: 100% !important; + padding: 0 !important; margin: 0 !important; + background: transparent !important; 
overflow: hidden !important; +} +.gradio-container .main, +.gradio-container > .main > .contain, +.gradio-container > .main > .contain > .wrap { + height: 100vh !important; max-height: 100vh !important; + padding: 0 !important; margin: 0 !important; + overflow: hidden !important; max-width: 100% !important; +} +footer { display: none !important; } + +/* ══ OUTER ROW ═══════════════════════════════════════════════════════════════ */ +#outer-row { + display: flex !important; flex-direction: row !important; + flex-wrap: nowrap !important; align-items: stretch !important; + height: 100vh !important; width: 100% !important; + gap: 0 !important; overflow: hidden !important; +} +#outer-row > * { min-width: 0 !important; } + +/* ══ SIDEBAR ═════════════════════════════════════════════════════════════════ */ +#sidebar { + flex: 0 0 220px !important; width: 220px !important; + min-width: 220px !important; max-width: 220px !important; + height: 100vh !important; + background: #f7f8fa !important; + border-right: 1px solid #e8e8e8 !important; + display: flex !important; flex-direction: column !important; + overflow: hidden !important; padding: 0 !important; gap: 0 !important; +} +/* Sidebar header */ +#sb-hdr { padding: 18px 16px 12px !important; border-bottom: 1px solid #efefef !important; + flex-shrink: 0 !important; gap: 0 !important; } +#sb-hdr .gr-markdown p, +#sb-hdr p { font-size: 16px !important; font-weight: 700 !important; + color: #111 !important; margin: 0 !important; } + +/* Sidebar new-chat button */ +#sb-new { padding: 8px 10px !important; flex-shrink: 0 !important; gap: 0 !important; } +#sb-new button { + width: 100% !important; background: #e8faf3 !important; color: #059669 !important; + border: 1px solid #bbf7d0 !important; border-radius: 8px !important; + font-size: 13px !important; font-weight: 600 !important; height: 36px !important; + padding: 0 14px !important; text-align: left !important; + transition: all 0.15s !important; +} +#sb-new button:hover { 
background: #d1fae5 !important; border-color: #6ee7b7 !important; } + +/* History label */ +#sb-lbl { padding: 10px 16px 2px !important; flex-shrink: 0 !important; gap: 0 !important; } +#sb-lbl .gr-markdown p, +#sb-lbl p { font-size: 11px !important; color: #9ca3af !important; + font-weight: 600 !important; letter-spacing: 0.7px !important; + text-transform: uppercase !important; margin: 0 !important; } + +/* History HTML scroll area */ +#sb-hist { + flex: 1 !important; min-height: 0 !important; + overflow-y: auto !important; overflow-x: hidden !important; + padding: 0 8px !important; +} +#sb-hist::-webkit-scrollbar { width: 4px; } +#sb-hist::-webkit-scrollbar-thumb { background: #d1d5db; border-radius: 4px; } +#sb-hist > div { display: block !important; } +#hsc { padding: 4px 0 8px !important; } + +/* History items */ +.hi { + display: flex !important; align-items: center !important; gap: 7px !important; + padding: 7px 9px !important; border-radius: 7px !important; + cursor: pointer !important; color: #374151 !important; + font-size: 13px !important; user-select: none !important; + margin: 1px 0 !important; transition: background 0.12s !important; +} +.hi:hover { background: #eef0f3 !important; } +.hi.active { background: #e8faf3 !important; color: #065f46 !important; } +.hico { width: 14px !important; height: 14px !important; flex-shrink: 0 !important; + color: #9ca3af !important; } +.hi.active .hico { color: #09b37b !important; } +.hti { flex: 1 !important; overflow: hidden !important; + text-overflow: ellipsis !important; white-space: nowrap !important; } +.hempty { color: #9ca3af !important; font-size: 12px !important; + text-align: center !important; padding: 24px 0 !important; } + +/* Hidden textbox for JS click events - keep in DOM but invisible */ +#hcb { opacity: 0 !important; position: absolute !important; } + +/* 侧边栏 & 右侧参数面板:不被全局 overflow:hidden 截断 */ +#sidebar .wrap, #sidebar .contain, +#rp-col .wrap, #rp-col .contain { + height: auto !important; 
max-height: none !important; + overflow: visible !important; +} +#rp-col { + overflow-y: auto !important; +} + +/* Sidebar bottom */ +#sb-bot { border-top: 1px solid #efefef !important; padding: 12px 16px !important; + flex-shrink: 0 !important; gap: 0 !important; } +#sb-bot .gr-markdown p, +#sb-bot p { font-size: 12px !important; color: #6b7280 !important; margin: 0 !important; } + +/* ══ MAIN CHAT COLUMN ════════════════════════════════════════════════════════ */ +#main-col { + flex: 1 1 0 !important; height: 100vh !important; min-width: 0 !important; + display: flex !important; flex-direction: column !important; + background: #ffffff !important; padding: 0 !important; gap: 0 !important; + overflow: hidden !important; +} +/* Chat topbar */ +#ct-bar { height: 52px !important; min-height: 52px !important; flex-shrink: 0 !important; + border-bottom: 1px solid #f0f0f0 !important; padding: 0 24px !important; + display: flex !important; align-items: center !important; + background: #fff !important; gap: 0 !important; } +#ct-bar .gr-markdown p, +#ct-bar p { font-size: 15px !important; font-weight: 600 !important; + color: #111 !important; margin: 0 !important; } + +/* Chatbot fills remaining height */ +#chatbot-box { + flex: 1 !important; min-height: 0 !important; + border: none !important; box-shadow: none !important; + background: #fafbfc !important; +} +/* Override Gradio's inline height */ +#chatbot-box > div { height: 100% !important; } +#chatbot-box .bubble-wrap { max-height: none !important; } + +/* Message bubbles */ +#chatbot-box .message-wrap { padding: 6px 24px !important; } +#chatbot-box .message.user .bubble-wrap, +#chatbot-box .user .bubble-wrap { + background: #f0fdf7 !important; border: 1px solid #c6f0dc !important; + border-radius: 14px 14px 4px 14px !important; + box-shadow: 0 1px 3px rgba(0,0,0,0.04) !important; +} +#chatbot-box .message.bot .bubble-wrap, +#chatbot-box .bot .bubble-wrap, +#chatbot-box .message.assistant .bubble-wrap { + background: 
#ffffff !important; border: 1px solid #eaecef !important; + border-radius: 14px 14px 14px 4px !important; + box-shadow: 0 1px 3px rgba(0,0,0,0.04) !important; +} + +/* ══ INPUT ZONE ══════════════════════════════════════════════════════════════ */ +#inp-zone { + border-top: 1px solid #f0f0f0 !important; flex-shrink: 0 !important; + padding: 12px 24px 8px !important; background: #fff !important; gap: 10px !important; +} +#msg-input { margin: 0 !important; } +#msg-input textarea { + border: 1.5px solid #dde1e7 !important; border-radius: 12px !important; + background: #f9fafb !important; font-size: 14px !important; + line-height: 1.6 !important; padding: 10px 14px !important; color: #111 !important; + transition: border-color .18s, box-shadow .18s !important; resize: none !important; + font-family: inherit !important; +} +#msg-input textarea:focus { + border-color: #09b37b !important; background: #fff !important; + box-shadow: 0 0 0 3px rgba(9,179,123,.10) !important; outline: none !important; +} +#msg-input textarea::placeholder { color: #b0b7c0 !important; } +#send-btn { margin: 0 !important; } +#send-btn button { + background: #09b37b !important; border: none !important; + border-radius: 11px !important; color: #fff !important; + font-size: 14px !important; font-weight: 600 !important; min-height: 50px !important; + box-shadow: 0 2px 8px rgba(9,179,123,.28) !important; + transition: background .15s, box-shadow .15s, transform .12s !important; + font-family: inherit !important; +} +#send-btn button:hover { + background: #07a06e !important; box-shadow: 0 5px 18px rgba(9,179,123,.38) !important; + transform: translateY(-1px) !important; +} +#send-btn button:active { transform: translateY(0) !important; } + +/* Action row */ +#act-row { padding: 2px 24px 12px !important; background: #fff !important; + gap: 10px !important; flex-shrink: 0 !important; } +#act-row button { + background: transparent !important; border: 1px solid #e5e7eb !important; + border-radius: 8px 
!important; color: #6b7280 !important; + font-size: 12.5px !important; height: 32px !important; transition: all .15s !important; +} +#clear-btn button:hover { + border-color: #fca5a5 !important; color: #dc2626 !important; + background: #fff8f8 !important; +} +#stop-btn button { + border-color: #fca5a5 !important; color: #dc2626 !important; +} +#stop-btn button:hover { + background: #fff1f2 !important; border-color: #f87171 !important; +} + +/* ══ RIGHT PANEL ═════════════════════════════════════════════════════════════ */ +#rp-col { + flex: 0 0 260px !important; width: 260px !important; + min-width: 260px !important; max-width: 260px !important; + height: 100vh !important; + border-left: 1px solid #f0f0f0 !important; background: #fafbfc !important; + overflow-y: auto !important; padding: 20px 16px !important; + flex-shrink: 0 !important; gap: 10px !important; +} +#rp-col::-webkit-scrollbar { width: 4px; } +#rp-col::-webkit-scrollbar-thumb { background: #d1d5db; border-radius: 4px; } +#rp-col .gr-markdown h3 { font-size: 13px !important; font-weight: 600 !important; + color: #374151 !important; margin: 0 0 12px !important; } +#rp-col .gr-markdown p { font-size: 13px !important; color: #4b5563 !important; + line-height: 1.6 !important; } +#rp-col label span { font-size: 13px !important; } + +/* ══ REASONING (think) BLOCK ═════════════════════════════════════════════════ */ +.thinking { + background: linear-gradient(135deg,#f0fdf8,#ecfdf5) !important; + border: 1px solid #a7f3d0 !important; border-left: 3px solid #09b37b !important; + border-radius: 10px !important; margin: 6px 0 10px !important; + font-size: 13px !important; overflow: hidden !important; +} +.thinking > summary { + padding: 9px 13px !important; color: #065f46 !important; font-weight: 500 !important; + cursor: pointer !important; list-style: none !important; user-select: none !important; + background: rgba(9,179,123,.05) !important; + display: flex !important; align-items: center !important; gap: 6px 
!important; +} +.thinking > summary::marker, +.thinking > summary::-webkit-details-marker { display: none !important; } +.thinking > summary::before { + content: "▶" !important; font-size: 10px !important; color: #09b37b !important; + transition: transform .2s !important; display: inline-block !important; +} +.thinking[open] > summary::before { transform: rotate(90deg) !important; } +.thinking > summary:hover { background: rgba(9,179,123,.10) !important; } +.thinking > div, .thinking > p { + padding: 8px 14px 11px !important; color: #1a3a2e !important; + line-height: 1.7 !important; border-top: 1px solid #bbf7d0 !important; +} + +/* ══ CODE BLOCKS ═════════════════════════════════════════════════════════════ */ +code, pre { font-family: 'JetBrains Mono','Fira Code',Consolas,monospace !important; } +pre { background: #1e1e2e !important; border-radius: 8px !important; + padding: 14px 16px !important; overflow-x: auto !important; border: none !important; } +pre code { color: #cdd6f4 !important; font-size: 13px !important; line-height: 1.6 !important; } +:not(pre) > code { background: #f1f4f8 !important; color: #d63384 !important; + border-radius: 4px !important; padding: 2px 5px !important; font-size: 13px !important; } + +/* ── Catppuccin Mocha 暗色语法高亮(覆盖 Highlight.js 默认配色) ───────── */ +pre code .hljs-keyword, +pre code .token.keyword { color: #cba6f7 !important; } /* 紫色: if/for/def/class/import */ +pre code .hljs-built_in, +pre code .token.builtin { color: #fab387 !important; } /* 橘色: print/len/range */ +pre code .hljs-string, +pre code .token.string { color: #a6e3a1 !important; } /* 绿色: "hello" */ +pre code .hljs-number, +pre code .token.number { color: #fab387 !important; } /* 橘色: 42, 3.14 */ +pre code .hljs-title, +pre code .hljs-title\\.function_, +pre code .token.function { color: #89b4fa !important; } /* 蓝色: 函数名 */ +pre code .hljs-comment, +pre code .token.comment { color: #7f849c !important; font-style: italic !important; } +pre code .hljs-variable, +pre code 
.token.variable { color: #f38ba8 !important; } /* 粉红: 变量 */ +pre code .hljs-operator, +pre code .token.operator { color: #89dceb !important; } /* 青色: = + - * / % */ +pre code .hljs-punctuation, +pre code .token.punctuation { color: #bac2de !important; } /* 浅灰: () [] {} , ; */ +pre code .hljs-params { color: #cdd6f4 !important; } /* 白色: 函数参数 */ +pre code .hljs-meta, +pre code .token.decorator { color: #f38ba8 !important; } /* 粉红: @decorator */ +pre code .hljs-literal, +pre code .token.boolean { color: #fab387 !important; } /* 橘色: True/False/None */ +pre code .hljs-type, +pre code .hljs-name { color: #f9e2af !important; } /* 黄色: 类型名/标签名 */ +pre code .hljs-attr, +pre code .token.attr-name { color: #f9e2af !important; } /* 黄色: 属性名 */ +pre code .hljs-symbol { color: #f2cdcd !important; } +pre code .hljs-selector-class { color: #f9e2af !important; } +pre code .hljs-selector-tag { color: #cba6f7 !important; } +""" + + +# ───────────────────────────────────────────────────────────────────────────── +# 流式生成 +# ───────────────────────────────────────────────────────────────────────────── + +def respond(user_msg, history, session_id, server_url, + temperature, top_k, top_p, max_tokens, thinking_budget, conversations): + if not user_msg.strip(): + yield history, "", conversations, render_sidebar(conversations, session_id), \ + _get_title_by_id(conversations, session_id) + return + + api_msgs = _build_api_messages(history) + api_msgs.append({"role": "user", "content": user_msg}) + + history = history + [ + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": "▌"}, + ] + # Pre-compute sidebar (won't change during streaming) + pre_sidebar = render_sidebar(conversations, session_id) + pre_title = _get_title_by_id(conversations, session_id) or "新对话" + yield history, "", conversations, pre_sidebar, pre_title + + full_text = "" + try: + with requests.post( + f"{server_url}/v1/chat/completions", + json={ + "messages": api_msgs, + "temperature": float(temperature), 
+ "top_k": int(top_k), + "top_p": float(top_p), + "max_tokens": int(max_tokens), + "thinking_budget": int(thinking_budget), + "stream": True, + "session_id": session_id, + }, + stream=True, timeout=300, + ) as resp: + resp.raise_for_status() + for line in resp.iter_lines(): + if not line: + continue + s = line.decode() if isinstance(line, bytes) else line + if not s.startswith("data: "): + continue + payload = s[6:].strip() + if payload == "[DONE]": + break + try: + delta = json.loads(payload)["choices"][0]["delta"].get("content", "") + if delta: + full_text += delta + # 实时更新:确保 think 标签完整,流式显示 + normalized_text = _normalize_think(_clean(full_text)) + history[-1]["content"] = normalized_text + yield history, "", conversations, pre_sidebar, pre_title + except Exception: + pass + except Exception as exc: + history[-1]["content"] = f"❌ 连接错误:{exc}" + yield history, "", conversations, pre_sidebar, pre_title + return + + # 生成完毕:确保 think 标签闭合,去除重复内容 + final_text = _normalize_think(_trim_repetition(_clean(full_text))) + history[-1]["content"] = final_text + # 更新历史记录和侧边栏 + updated_convs = _update_conversations(conversations, session_id, history) + new_title = _conv_title(history) + new_sidebar = render_sidebar(updated_convs, session_id) + # 持久化到磁盘 + _save_conversations(updated_convs, session_id) + yield history, "", updated_convs, new_sidebar, new_title + + +def do_clear(session_id, server_url, conversations): + try: + requests.post(f"{server_url}/v1/sessions/{session_id}/clear", timeout=5) + except Exception: + pass + # 清空消息,保留此会话 id + # 同时更新 conversations 中该 session 的消息 + updated_convs = _update_conversations(conversations, session_id, []) + _save_conversations(updated_convs, session_id) + sidebar = render_sidebar(updated_convs, session_id) + title = _get_title_by_id(updated_convs, session_id) or "新对话" + return [], "", updated_convs, sidebar, title + + +def stop_generation(session_id: str, server_url: str): + """通知服务器立即停止当前会话的生成。""" + try: + 
requests.post(f"{server_url}/v1/sessions/{session_id}/abort", timeout=2) + except Exception: + pass + + +def do_new_session(conversations): + new_id = str(uuid.uuid4()) + _save_conversations(conversations, new_id) + sidebar = render_sidebar(conversations, new_id) + return [], "", new_id, sidebar, "新对话" + + +def on_history_click(conv_id, conversations, server_url): + """处理历史对话点击切换事件。""" + if not conv_id: + return [], gr.update(), gr.update(), gr.update(), "" + + # 查找对应的会话 + for conv in conversations: + if conv["id"] == conv_id: + msgs = conv.get("messages", []) + title = conv.get("title") or "新对话" + sidebar = render_sidebar(conversations, conv_id) + # 返回:chatbot 消息,session_id, sidebar HTML, 标题,清空 hist_click + return msgs, conv_id, sidebar, title, "" + + # 未找到会话,返回空 + return [], conv_id, render_sidebar(conversations, conv_id), "新对话", "" + + +def _restore_chat_history(conversations: list, session_id: str) -> list: + """从 conversations 列表中恢复指定 session 的 chatbot 消息列表。""" + for conv in conversations: + if conv["id"] == session_id: + return conv.get("messages", []) + return [] + + +# ───────────────────────────────────────────────────────────────────────────── +# UI 构建 +# ───────────────────────────────────────────────────────────────────────────── + +def build_ui(server_url: str) -> gr.Blocks: + # 启动时从磁盘加载历史对话 + saved_convs, saved_sid = _load_conversations() + + with gr.Blocks(title="LLAISYS模型聊天机器人") as demo: + + # ── State ───────────────────────────────────────────────────────── + session_id_st = gr.State(saved_sid) + server_url_st = gr.State(server_url) + conversations_st = gr.State(saved_convs) + + with gr.Row(elem_id="outer-row"): + + # ══ LEFT SIDEBAR ═══════════════════════════════════════════════ + with gr.Column(scale=0, min_width=220, elem_id="sidebar"): + + with gr.Row(elem_id="sb-hdr"): + gr.Markdown("🤖 **LLAISYS Chat**") + + sb_new_btn = gr.Button("✦ 新对话", elem_id="sb-new") + + with gr.Row(elem_id="sb-lbl"): + gr.Markdown("历史对话") + + history_html = gr.HTML( 
+ render_sidebar(saved_convs, saved_sid), + elem_id="sb-hist", + ) + + # 隐藏的 textbox,接收侧边栏 JS 点击事件(用 CSS 隐藏而不是 visible=False) + hist_click = gr.Textbox(value="", show_label=False, elem_id="hcb") + + with gr.Row(elem_id="sb-bot"): + gr.Markdown("👤  LLAISYS User") + + # ══ MAIN CHAT COLUMN ═══════════════════════════════════════════ + with gr.Column(scale=5, elem_id="main-col"): + + # 顶部标题栏 + with gr.Row(elem_id="ct-bar"): + conv_title_md = gr.Markdown( + _get_title_by_id(saved_convs, saved_sid) + ) + + # 聊天区 + chatbot = gr.Chatbot( + elem_id="chatbot-box", + height=600, + show_label=False, + render_markdown=True, + sanitize_html=False, + reasoning_tags=[("", "")], + value=_restore_chat_history(saved_convs, saved_sid), + placeholder=( + "
" + "
" + "

" + "有什么可以帮助你?

" + "

" + "推理过程将以可折叠的思考块展示

" + "
" + ), + ) + + # 输入行 + with gr.Row(elem_id="inp-zone"): + msg_box = gr.Textbox( + placeholder="发消息给 LLAISYS(Enter 换行,shift + Enter 发送)", + show_label=False, lines=2, max_lines=8, + scale=9, container=False, autofocus=True, + elem_id="msg-input", + ) + send_btn = gr.Button( + "发 送", variant="primary", scale=1, min_width=90, + elem_id="send-btn", + ) + + # 操作行 + with gr.Row(elem_id="act-row"): + clear_btn = gr.Button("🗑 清空对话", size="sm", elem_id="clear-btn") + stop_btn = gr.Button("⏹ 停止生成", size="sm", elem_id="stop-btn") + + # ══ RIGHT SETTINGS PANEL ═══════════════════════════════════════ + with gr.Column(scale=0, min_width=260, elem_id="rp-col"): + gr.Markdown("### ⚙️ 生成参数") + temperature = gr.Slider( + minimum=0.0, maximum=2.0, value=1, step=0.05, + label="温度 Temperature", info="↑ 更随机 · ↓ 更保守", + ) + top_k = gr.Slider( + minimum=1, maximum=200, value=30, step=1, label="Top-K", + ) + top_p = gr.Slider( + minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P 核采样", + ) + max_tokens = gr.Slider( + minimum=64, maximum=4096, value=2048, step=64, label="最大 Token 数", + ) + thinking_budget = gr.Slider( + minimum=0, maximum=3000, value=2000, step=100, + label="思考限额 (think 块最多字符数)", + info="0 = 不限制", + ) + gr.Markdown( + "
\n\n" + "### 💭 思考块\n\n" + "点击 **Reasoning** 折叠块\n" + "展开 / 收起推理过程。\n\n" + "适用于 **DeepSeek-R1**、\n" + "**Qwen3** 等推理模型。" + ) + gr.Markdown("
") + gr.Markdown("### 🔄 模型") + model_dropdown = gr.Dropdown( + choices=["Qwen2"], + value="Qwen2", + label="当前模型", + interactive=False, + info="服务端加载的模型", + ) + + # ── 事件绑定 ──────────────────────────────────────────────────────── + # gen_outputs: [chatbot, msg_box, conversations_st, history_html, conv_title_md] + gen_inputs = [msg_box, chatbot, session_id_st, server_url_st, + temperature, top_k, top_p, max_tokens, thinking_budget, + conversations_st] + gen_outputs = [chatbot, msg_box, conversations_st, history_html, conv_title_md] + + gen_event = send_btn.click(respond, inputs=gen_inputs, outputs=gen_outputs) + + # 绑定 Enter 键发送(Gradio 原生支持) + msg_box.submit(respond, inputs=gen_inputs, outputs=gen_outputs) + + stop_btn.click( + stop_generation, + inputs=[session_id_st, server_url_st], + outputs=[], + cancels=[gen_event], + ) + + clear_btn.click( + do_clear, + inputs=[session_id_st, server_url_st, conversations_st], + outputs=[chatbot, msg_box, conversations_st, history_html, conv_title_md], + ) + + sb_new_btn.click( + do_new_session, + inputs=[conversations_st], + outputs=[chatbot, msg_box, session_id_st, history_html, conv_title_md], + ) + + # 历史对话点击切换:使用 .change() 事件(比 .input 更可靠) + hist_click.change( + on_history_click, + inputs=[hist_click, conversations_st, server_url_st], + outputs=[chatbot, session_id_st, history_html, conv_title_md, msg_box], + ) + + # 每次浏览器连接/刷新时从磁盘重新加载历史,避免刷新后丢失会话 + def _on_page_load(): + convs, sid = _load_conversations() + history = _restore_chat_history(convs, sid) + sidebar = render_sidebar(convs, sid) + title = _get_title_by_id(convs, sid) or "新对话" + return convs, sid, history, sidebar, title + + demo.load( + _on_page_load, + inputs=[], + outputs=[conversations_st, session_id_st, chatbot, history_html, conv_title_md], + ) + + return demo + + +# ───────────────────────────────────────────────────────────────────────────── +# 入口 +# ───────────────────────────────────────────────────────────────────────────── + +def main(): + parser = 
argparse.ArgumentParser(description="LLAISYS Chat UI (Gradio 6)") + parser.add_argument("--server", default="http://localhost:8000") + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=7860) + parser.add_argument("--share", action="store_true") + args = parser.parse_args() + + print(f"[ui] 后端: {args.server}") + print(f"[ui] 界面: http://localhost:{args.port}") + demo = build_ui(args.server) + demo.queue() + demo.launch( + server_name=args.host, + server_port=args.port, + share=args.share, + favicon_path="assets/AI.ico", # 新增这一行 + theme=gr.themes.Soft( + primary_hue=gr.themes.colors.emerald, + secondary_hue=gr.themes.colors.teal, + neutral_hue=gr.themes.colors.gray, + ), + css=CSS, + head=f'', + ) + + +if __name__ == "__main__": + main() diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d4..0eb36a3ed 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -30,6 +30,7 @@ __C { }; struct LlaisysQwen2Model; + struct LlaisysQwen2Session; __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); @@ -37,6 +38,30 @@ __C { __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); - __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + // ── 向后兼容:操作模型内置 default_session + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t *token_ids, size_t ntoken); + + __export int64_t llaisysQwen2ModelInferSample(struct LlaisysQwen2Model * model, int64_t *token_ids, size_t ntoken, float temperature, int top_k, float top_p); + + __export void llaisysQwen2ModelSetCachePos(struct LlaisysQwen2Model * model, size_t pos); + + __export size_t llaisysQwen2ModelGetCachePos(struct LlaisysQwen2Model * model); + + __export void llaisysQwen2ModelResetCache(struct 
LlaisysQwen2Model * model); + + // ── 多用户 Session API:每用户独占 KV-Cache + __export struct LlaisysQwen2Session *llaisysQwen2SessionCreate(struct LlaisysQwen2Model * model); + + __export void llaisysQwen2SessionDestroy(struct LlaisysQwen2Session * session); + + __export int64_t llaisysQwen2SessionInfer(struct LlaisysQwen2Model * model, struct LlaisysQwen2Session * session, int64_t *token_ids, size_t ntoken); + + __export int64_t llaisysQwen2SessionInferSample(struct LlaisysQwen2Model * model, struct LlaisysQwen2Session * session, int64_t *token_ids, size_t ntoken, float temperature, int top_k, float top_p); + + __export void llaisysQwen2SessionSetCachePos(struct LlaisysQwen2Session * session, size_t pos); + + __export size_t llaisysQwen2SessionGetCachePos(struct LlaisysQwen2Session * session); + + __export void llaisysQwen2SessionResetCache(struct LlaisysQwen2Session * session); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/include/llaisys/runtime.h b/include/llaisys/runtime.h index d8e6f66f1..f4c88b8d8 100644 --- a/include/llaisys/runtime.h +++ b/include/llaisys/runtime.h @@ -42,6 +42,10 @@ __C { // Llaisys API for switching device context __export void llaisysSetContextRuntime(llaisysDeviceType_t, int); + + // Returns non-zero if the library was compiled with support for the given device type. + // A return value of 0 means the library must be recompiled with the matching backend option. 
+ __export int llaisysIsDeviceSupported(llaisysDeviceType_t); } #endif // LLAISYS_RUNTIME_H diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb527..fa243642b 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,6 +12,7 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops +from .qwen2 import load_qwen2, LlaisysQwen2Meta, LlaisysQwen2Weights, llaisysQwen2Model_t, llaisysQwen2Session_t def load_shared_library(): @@ -38,6 +39,7 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_qwen2(LIB_LLAISYS) __all__ = [ @@ -47,9 +49,12 @@ def load_shared_library(): "llaisysTensor_t", "llaisysDataType_t", "DataType", - "llaisysDeviceType_t", "DeviceType", "llaisysMemcpyKind_t", "MemcpyKind", - "llaisysStream_t", + "LlaisysQwen2Meta", + "LlaisysQwen2Weights", + "llaisysQwen2Model_t", + "llaisysQwen2Session_t", + "llaisysDeviceType_t" ] diff --git a/python/llaisys/libllaisys/qwen2.py b/python/llaisys/libllaisys/qwen2.py new file mode 100644 index 000000000..878a1e5dc --- /dev/null +++ b/python/llaisys/libllaisys/qwen2.py @@ -0,0 +1,105 @@ +from .llaisys_types import llaisysDataType_t, llaisysDeviceType_t +from .tensor import llaisysTensor_t +from ctypes import Structure, POINTER, c_size_t, c_float, c_int64, c_int, c_void_p + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + + +class LlaisysQwen2Weights(Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", 
POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + + +llaisysQwen2Model_t = c_void_p +llaisysQwen2Session_t = c_void_p + + +def load_qwen2(lib): + # llaisysQwen2ModelCreate + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + llaisysDeviceType_t, + POINTER(c_int), + c_int, + ] + lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t + + # llaisysQwen2ModelDestroy + lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelDestroy.restype = None + + # llaisysQwen2ModelWeights + lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + # ── 向后兼容:默认 session ──────────────────────────────────────────────── + lib.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2Model_t, POINTER(c_int64), c_size_t] + lib.llaisysQwen2ModelInfer.restype = c_int64 + + lib.llaisysQwen2ModelInferSample.argtypes = [ + llaisysQwen2Model_t, POINTER(c_int64), c_size_t, c_float, c_int, c_float] + lib.llaisysQwen2ModelInferSample.restype = c_int64 + + lib.llaisysQwen2ModelSetCachePos.argtypes = [llaisysQwen2Model_t, c_size_t] + lib.llaisysQwen2ModelSetCachePos.restype = None + + lib.llaisysQwen2ModelGetCachePos.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelGetCachePos.restype = c_size_t + + lib.llaisysQwen2ModelResetCache.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelResetCache.restype = None + + # ── 多用户 Session API ──────────────────────────────────────────────────── + lib.llaisysQwen2SessionCreate.argtypes = [llaisysQwen2Model_t] + 
lib.llaisysQwen2SessionCreate.restype = llaisysQwen2Session_t + + lib.llaisysQwen2SessionDestroy.argtypes = [llaisysQwen2Session_t] + lib.llaisysQwen2SessionDestroy.restype = None + + lib.llaisysQwen2SessionInfer.argtypes = [ + llaisysQwen2Model_t, llaisysQwen2Session_t, POINTER(c_int64), c_size_t] + lib.llaisysQwen2SessionInfer.restype = c_int64 + + lib.llaisysQwen2SessionInferSample.argtypes = [ + llaisysQwen2Model_t, llaisysQwen2Session_t, + POINTER(c_int64), c_size_t, c_float, c_int, c_float] + lib.llaisysQwen2SessionInferSample.restype = c_int64 + + lib.llaisysQwen2SessionSetCachePos.argtypes = [llaisysQwen2Session_t, c_size_t] + lib.llaisysQwen2SessionSetCachePos.restype = None + + lib.llaisysQwen2SessionGetCachePos.argtypes = [llaisysQwen2Session_t] + lib.llaisysQwen2SessionGetCachePos.restype = c_size_t + + lib.llaisysQwen2SessionResetCache.argtypes = [llaisysQwen2Session_t] + lib.llaisysQwen2SessionResetCache.restype = None diff --git a/python/llaisys/libllaisys/runtime.py b/python/llaisys/libllaisys/runtime.py index 3e5b8be5b..02ad6ceb5 100644 --- a/python/llaisys/libllaisys/runtime.py +++ b/python/llaisys/libllaisys/runtime.py @@ -46,3 +46,6 @@ def load_runtime(lib): lib.llaisysSetContextRuntime.argtypes = [llaisysDeviceType_t, c_int] lib.llaisysSetContextRuntime.restype = None + + lib.llaisysIsDeviceSupported.argtypes = [llaisysDeviceType_t] + lib.llaisysIsDeviceSupported.restype = c_int diff --git a/python/llaisys/models/__init__.py b/python/llaisys/models/__init__.py index af9918b0d..8600f9b8e 100644 --- a/python/llaisys/models/__init__.py +++ b/python/llaisys/models/__init__.py @@ -1 +1 @@ -from .qwen2 import Qwen2 +from .qwen2 import Qwen2, Qwen2Session diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..88502eebb 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,324 @@ from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import 
DeviceType +from ..libllaisys import ( + LIB_LLAISYS, + DeviceType, + DataType, + LlaisysQwen2Meta, + LlaisysQwen2Weights, + llaisysQwen2Model_t, + llaisysQwen2Session_t, + llaisysDeviceType_t, +) +from ..tensor import Tensor +from ctypes import c_int64, c_size_t, c_int, c_float, c_char, c_void_p, POINTER, addressof from pathlib import Path import safetensors +from safetensors import safe_open +import json +import numpy as np +import torch + + +# ───────────────────────────────────────────────────────────────────────────── +# Qwen2Session: 封装每用户独立的 KV-Cache 状态 +# ───────────────────────────────────────────────────────────────────────────── + +class Qwen2Session: + """ + 每个用户对话独占一个 Qwen2Session,持有独立的 KV-Cache。 + 多个 Session 可并发绑定到同一个 Qwen2 模型(权重只读共享)。 + """ + + def __init__(self, model: "Qwen2"): + self._model = model + self._sess = LIB_LLAISYS.llaisysQwen2SessionCreate(model._model) + if self._sess is None: + raise RuntimeError("llaisysQwen2SessionCreate returned null") + self._meta = model._meta + self._device = model._device + + # ── 基础属性 ──────────────────────────────────────────────────────────── + + @property + def cache_pos(self) -> int: + return LIB_LLAISYS.llaisysQwen2SessionGetCachePos(self._sess) + + @cache_pos.setter + def cache_pos(self, pos: int): + LIB_LLAISYS.llaisysQwen2SessionSetCachePos(self._sess, c_size_t(pos)) + + def reset_cache(self): + LIB_LLAISYS.llaisysQwen2SessionResetCache(self._sess) + + # ── 推理 ───────────────────────────────────────────────────────────────── + + def _infer_sample(self, token_ids: list, temperature: float, top_k: int, top_p: float) -> int: + LIB_LLAISYS.llaisysSetContextRuntime(llaisysDeviceType_t(self._device.value), c_int(0)) + arr = (c_int64 * len(token_ids))(*token_ids) + return LIB_LLAISYS.llaisysQwen2SessionInferSample( + self._model._model, self._sess, arr, len(token_ids), + c_float(temperature), c_int(top_k), c_float(top_p) + ) + + def stream_generate( + self, + inputs: Sequence[int], + max_new_tokens: int = 
512, + top_k: int = 50, + top_p: float = 0.9, + temperature: float = 0.8, + ): + """Generator: yield 新生成的 token IDs(不含 prompt 部分)。""" + if not inputs: + return + + next_token = self._infer_sample(list(inputs), temperature, top_k, top_p) + yield next_token + + max_new = max_new_tokens if max_new_tokens is not None else 512 + for _ in range(max_new - 1): + if next_token == self._meta.end_token: + break + next_token = self._infer_sample([next_token], temperature, top_k, top_p) + yield next_token + + def __del__(self): + if hasattr(self, "_sess") and self._sess is not None: + LIB_LLAISYS.llaisysQwen2SessionDestroy(self._sess) + self._sess = None class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor model_path = Path(model_path) + # Load config + with open(model_path / "config.json", "r") as f: + config = json.load(f) + + # Check device availability before allocating anything + api = LIB_LLAISYS.llaisysGetRuntimeAPI(llaisysDeviceType_t(device.value)) + ndev = api.contents.get_device_count() + if ndev == 0: + raise RuntimeError( + f"No devices available for device type '{device.name}'. " + "Make sure the library is compiled with the correct backend " + "and the hardware is accessible." 
+ ) + + self._device = device + + # Create meta + meta = LlaisysQwen2Meta() + meta.dtype = DataType.BF16.value + meta.nlayer = config["num_hidden_layers"] + meta.hs = config["hidden_size"] + meta.nh = config["num_attention_heads"] + meta.nkvh = config["num_key_value_heads"] + meta.dh = config["hidden_size"] // config["num_attention_heads"] + meta.di = config["intermediate_size"] + meta.maxseq = config.get("max_position_embeddings", 4096) + meta.voc = config["vocab_size"] + meta.epsilon = config["rms_norm_eps"] + meta.theta = config.get("rope_theta", 10000.0) + meta.end_token = config.get("eos_token_id", 151643) + + # Create model + device_id = c_int(0) + self._model = LIB_LLAISYS.llaisysQwen2ModelCreate( + POINTER(LlaisysQwen2Meta)(meta), + device.value, + POINTER(c_int)(device_id), + 1 + ) + + # Get weights pointer + self._weights = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model) + + # Load weights from safetensors for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + print(f"Loading weights from {file.name}...") + with safe_open(file, framework="pt", device="cpu") as f: + for name in f.keys(): + print(f" Loading {name}... 
", end="", flush=True) + tensor_data = f.get_tensor(name) + self._load_weight(name, tensor_data) + print("OK") + print("All weights loaded successfully!") + + self._meta = meta + + def _load_weight(self, name: str, data): + """Load a single weight tensor""" + # Convert to numpy array and keep alive during load + if isinstance(data, torch.Tensor): + print(f"shape={data.shape}, dtype={data.dtype}") + if data.dtype == torch.bfloat16: + # For bfloat16, view as uint16 first, then convert to numpy + data_np = data.cpu().view(torch.uint16).numpy() + else: + # For other types, convert to numpy + data_np = data.cpu().numpy() + elif hasattr(data, 'ctypes'): + # Already numpy array + data_np = data + else: + raise TypeError(f"Unsupported data type: {type(data)}") + + # Ensure contiguous memory layout + if not data_np.flags['C_CONTIGUOUS']: + data_np = np.ascontiguousarray(data_np) + + print(f"numpy shape={data_np.shape}, dtype={data_np.dtype}, contiguous={data_np.flags['C_CONTIGUOUS']}") + data_ptr = c_void_p(data_np.ctypes.data) + print(f"data_ptr={data_ptr}") + + weights = self._weights.contents + print(f"weights={weights}") + + # Parse weight name + if name == "model.embed_tokens.weight": + tensor = Tensor(tensor=weights.in_embed) + print(f"tensor object created, calling load...") + tensor.load(data_ptr) + elif name == "lm_head.weight": + print(f"Accessing weights.out_embed...") + tensor = Tensor(tensor=weights.out_embed) + print(f"tensor object created, calling load...") + tensor.load(data_ptr) + elif name == "model.norm.weight": + tensor = Tensor(tensor=weights.out_norm_w) + tensor.load(data_ptr) + elif "model.layers." 
in name: + # Parse layer index + parts = name.split(".") + layer_idx = int(parts[2]) + + if "input_layernorm.weight" in name: + tensor = Tensor(tensor=weights.attn_norm_w[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.q_proj.weight" in name: + tensor = Tensor(tensor=weights.attn_q_w[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.q_proj.bias" in name: + tensor = Tensor(tensor=weights.attn_q_b[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.k_proj.weight" in name: + tensor = Tensor(tensor=weights.attn_k_w[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.k_proj.bias" in name: + tensor = Tensor(tensor=weights.attn_k_b[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.v_proj.weight" in name: + tensor = Tensor(tensor=weights.attn_v_w[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.v_proj.bias" in name: + tensor = Tensor(tensor=weights.attn_v_b[layer_idx]) + tensor.load(data_ptr) + elif "self_attn.o_proj.weight" in name: + tensor = Tensor(tensor=weights.attn_o_w[layer_idx]) + tensor.load(data_ptr) + elif "post_attention_layernorm.weight" in name: + tensor = Tensor(tensor=weights.mlp_norm_w[layer_idx]) + tensor.load(data_ptr) + elif "mlp.gate_proj.weight" in name: + tensor = Tensor(tensor=weights.mlp_gate_w[layer_idx]) + tensor.load(data_ptr) + elif "mlp.up_proj.weight" in name: + tensor = Tensor(tensor=weights.mlp_up_w[layer_idx]) + tensor.load(data_ptr) + elif "mlp.down_proj.weight" in name: + tensor = Tensor(tensor=weights.mlp_down_w[layer_idx]) + tensor.load(data_ptr) + else: + print(f"WARNING: Unmatched weight name: {name}") + + + def create_session(self) -> Qwen2Session: + """创建一个新的独立会话(每用户 KV-Cache 隔离)。""" + return Qwen2Session(self) def generate( self, inputs: Sequence[int], - max_new_tokens: int = None, - top_k: int = 1, - top_p: float = 0.8, + max_new_tokens: int = 512, + top_k: int = 50, + top_p: float = 0.9, + temperature: float = 0.8, + ): + """Generate tokens (blocking). 
Returns full list including prompt tokens.""" + LIB_LLAISYS.llaisysSetContextRuntime(llaisysDeviceType_t(self._device.value), c_int(0)) + + input_tokens = (c_int64 * len(inputs))(*inputs) + next_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._model, input_tokens, len(inputs), + c_float(temperature), c_int(top_k), c_float(top_p) + ) + + generated = list(inputs) + [next_token] + + max_new = max_new_tokens if max_new_tokens is not None else 512 + for _ in range(max_new - 1): + if next_token == self._meta.end_token: + break + token_array = (c_int64 * 1)(next_token) + next_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._model, token_array, 1, + c_float(temperature), c_int(top_k), c_float(top_p) + ) + generated.append(next_token) + + return generated + + def stream_generate( + self, + inputs: Sequence[int], + max_new_tokens: int = 512, + top_k: int = 50, + top_p: float = 0.9, temperature: float = 0.8, ): + """Generator: yields token IDs one by one as they are produced.""" + LIB_LLAISYS.llaisysSetContextRuntime(llaisysDeviceType_t(self._device.value), c_int(0)) + + if not inputs: + return + + input_tokens = (c_int64 * len(inputs))(*inputs) + next_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._model, input_tokens, len(inputs), + c_float(temperature), c_int(top_k), c_float(top_p) + ) + yield next_token + + max_new = max_new_tokens if max_new_tokens is not None else 512 + for _ in range(max_new - 1): + if next_token == self._meta.end_token: + break + token_array = (c_int64 * 1)(next_token) + next_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._model, token_array, 1, + c_float(temperature), c_int(top_k), c_float(top_p) + ) + yield next_token + + @property + def cache_pos(self) -> int: + """Current KV cache position (number of tokens already processed).""" + return LIB_LLAISYS.llaisysQwen2ModelGetCachePos(self._model) - # TODO: Implement generate function + @cache_pos.setter + def cache_pos(self, pos: int): + 
LIB_LLAISYS.llaisysQwen2ModelSetCachePos(self._model, c_size_t(pos)) - return [] + def reset_cache(self): + """Reset the KV cache to position 0.""" + LIB_LLAISYS.llaisysQwen2ModelResetCache(self._model) + + def __del__(self): + if hasattr(self, "_model") and self._model is not None: + LIB_LLAISYS.llaisysQwen2ModelDestroy(self._model) + self._model = None diff --git a/python/llaisys/tensor.py b/python/llaisys/tensor.py index 1466d851e..5b8d6355e 100644 --- a/python/llaisys/tensor.py +++ b/python/llaisys/tensor.py @@ -20,8 +20,11 @@ def __init__( device_id: int = 0, tensor: llaisysTensor_t = None, ): - if tensor: + if tensor is not None: + # Wrap an existing C tensor handle — we do NOT own it, + # so __del__ must not call tensorDestroy on it. self._tensor = tensor + self._owned = False else: _ndim = 0 if shape is None else len(shape) _shape = None if shape is None else (c_size_t * len(shape))(*shape) @@ -32,9 +35,10 @@ def __init__( llaisysDeviceType_t(device), c_int(device_id), ) + self._owned = True def __del__(self): - if hasattr(self, "_tensor") and self._tensor is not None: + if hasattr(self, "_tensor") and self._tensor is not None and getattr(self, "_owned", True): LIB_LLAISYS.tensorDestroy(self._tensor) self._tensor = None diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..9a1f535ce --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +# LLAISYS 项目依赖 +# 主要应用 +gradio>=6.0 +requests +fastapi +uvicorn +pydantic + +# 机器学习/推理 +torch +transformers + +# 可选依赖 +prompt_toolkit>=3.0.0 # CLI 客户端使用 +huggingface_hub # 模型下载 + +# 开发/测试 +# pytest # 如需添加测试可取消注释 diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp index 44894b9e7..7b1e0c0bb 100644 --- a/src/core/context/context.cpp +++ b/src/core/context/context.cpp @@ -52,8 +52,11 @@ Context::~Context() { void Context::setDevice(llaisysDeviceType_t device_type, int device_id) { // If doest not match the current runtime. 
if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) { - auto runtimes = _runtime_map[device_type]; - CHECK_ARGUMENT((size_t)device_id < runtimes.size() && device_id >= 0, "invalid device id"); + // Use find() to avoid inserting an empty vector via operator[] + auto it = _runtime_map.find(device_type); + CHECK_ARGUMENT(it != _runtime_map.end() && device_id >= 0 && (size_t)device_id < it->second.size(), + "invalid device id"); + auto &runtimes = it->second; // reference, not copy if (_current_runtime != nullptr) { _current_runtime->_deactivate(); } diff --git a/src/device/nvidia/nvidia_resource.cu b/src/device/nvidia/nvidia_resource.cu index 2e63647e5..310f7d27d 100644 --- a/src/device/nvidia/nvidia_resource.cu +++ b/src/device/nvidia/nvidia_resource.cu @@ -1,7 +1,48 @@ #include "nvidia_resource.cuh" +#include +#include +#include +#include + namespace llaisys::device::nvidia { -Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) {} +// 按设备 ID 存储 cuBLAS handles +static std::unordered_map g_cublas_handles; +static std::mutex g_cublas_mutex; + +cublasHandle_t getCublasHandle() { + int device_id = 0; + cudaGetDevice(&device_id); + + std::lock_guard lock(g_cublas_mutex); + auto it = g_cublas_handles.find(device_id); + if (it != g_cublas_handles.end()) { + return it->second; + } + + // 初始化该设备的 cuBLAS handle + cublasHandle_t handle; + cublasStatus_t status = cublasCreate(&handle); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("getCublasHandle: cublasCreate failed"); + } + // A100 SM80 开启 TF32 数学模式,对 f32 GEMM 自动使用 Tensor Core + cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH); + + g_cublas_handles[device_id] = handle; + return handle; +} + +Resource::Resource(int device_id) + : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) { + // 切换到对应设备并确保 cuBLAS handle 已初始化 + cudaSetDevice(device_id); + 
getCublasHandle(); +} + +Resource::~Resource() { + // handle 会在进程退出时由 CUDA driver 清理,这里不主动销毁以避免重入问题 +} } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_resource.cuh b/src/device/nvidia/nvidia_resource.cuh index a3002170b..24253bf36 100644 --- a/src/device/nvidia/nvidia_resource.cuh +++ b/src/device/nvidia/nvidia_resource.cuh @@ -2,10 +2,17 @@ #include "../device_resource.hpp" +#include + namespace llaisys::device::nvidia { + +// 获取当前设备的 cuBLAS handle(懒初始化,按设备 ID 缓存) +cublasHandle_t getCublasHandle(); + class Resource : public llaisys::device::DeviceResource { public: Resource(int device_id); ~Resource(); }; + } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab928261..911b7c305 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,103 @@ #include "../runtime_api.hpp" -#include -#include +#include +#include +#include + +// CUDA error check helper +#define CUDA_CHECK(call) \ + do { \ + cudaError_t _err = (call); \ + if (_err != cudaSuccess) { \ + std::ostringstream _oss; \ + _oss << "CUDA error at " << __FILE__ << ":" << __LINE__ \ + << " : " << cudaGetErrorString(_err); \ + throw std::runtime_error(_oss.str()); \ + } \ + } while (0) namespace llaisys::device::nvidia { namespace runtime_api { + int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count = 0; + // Force CUDA runtime initialization before querying device count. + // In some container environments (e.g. DSW, Docker without proper NVIDIA hooks) + // cudaGetDeviceCount() returns 0 until the runtime is explicitly initialized. + // cudaFree(nullptr) is a guaranteed no-op that triggers that initialization; + // its return value may be non-success if no device is present, so we ignore it. 
+ (void)cudaFree(nullptr); + CUDA_CHECK(cudaGetDeviceCount(&count)); + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + CUDA_CHECK(cudaSetDevice(device_id)); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaDeviceSynchronize()); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + return static_cast(stream); } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaStreamDestroy(static_cast(stream))); } + void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaStreamSynchronize(static_cast(stream))); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + CUDA_CHECK(cudaMalloc(&ptr, size)); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaFree(ptr)); } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + // cudaMallocHost provides page-locked memory for faster H2D/D2H transfers + CUDA_CHECK(cudaMallocHost(&ptr, size)); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaFreeHost(ptr)); +} + +static cudaMemcpyKind tocudaMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + throw std::invalid_argument("memcpy: unknown llaisysMemcpyKind_t"); + } } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaMemcpy(dst, src, size, tocudaMemcpyKind(kind))); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, 
llaisysMemcpyKind_t kind, + llaisysStream_t stream) { + CUDA_CHECK(cudaMemcpyAsync(dst, src, size, tocudaMemcpyKind(kind), + static_cast(stream))); } static const LlaisysRuntimeAPI RUNTIME_API = { @@ -72,4 +119,5 @@ static const LlaisysRuntimeAPI RUNTIME_API = { const LlaisysRuntimeAPI *getRuntimeAPI() { return &runtime_api::RUNTIME_API; } + } // namespace llaisys::device::nvidia diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 000000000..ef2cceb7b --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,491 @@ +#include "llaisys/models/qwen2.h" +#include "../ops/ops.hpp" +#include "../utils.hpp" +#include "llaisys_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace llaisys { + +// ─── BF16/F16 → F32 bit-cast helpers(logit 类型转换,仅 CPU)─────────────── +static float bf16_to_f32(uint16_t bf16) { + uint32_t bits = static_cast(bf16) << 16; + float result; + std::memcpy(&result, &bits, sizeof(float)); + return result; +} + +static float f16_to_f32(uint16_t h) { + uint32_t sign = (h >> 15) & 1u; + uint32_t exp = (h >> 10) & 0x1fu; + uint32_t mant = h & 0x3ffu; + uint32_t bits; + if (exp == 0u) { + if (mant == 0u) { + bits = sign << 31; + } else { + exp = 1u; + while (!(mant & 0x400u)) { + mant <<= 1u; + exp--; + } + mant &= 0x3ffu; + bits = (sign << 31) | ((exp + 112u) << 23) | (mant << 13u); + } + } else if (exp == 31u) { + bits = (sign << 31) | (0xffu << 23) | (mant << 13u); + } else { + bits = (sign << 31) | ((exp + 112u) << 23) | (mant << 13u); + } + float result; + std::memcpy(&result, &bits, sizeof(float)); + return result; +} + +// ─── 每用户 Session:独立 KV-Cache + 采样状态 ──────────────────────────────── +// 权重只读共享,多个 Session 可并发操作同一 Qwen2Model 而互不干扰。 +struct Qwen2Session { + LlaisysQwen2Meta meta; // 模型 metadata 副本(用于分配 tensor) + llaisysDeviceType_t device; + std::vector k_cache; + std::vector v_cache; + size_t cache_pos = 0; + std::mt19937 rng; + + Qwen2Session(const 
LlaisysQwen2Meta &m, llaisysDeviceType_t dev) + : meta(m), device(dev), cache_pos(0), rng(std::random_device{}()) { + k_cache.resize(meta.nlayer); + v_cache.resize(meta.nlayer); + for (size_t i = 0; i < meta.nlayer; i++) { + k_cache[i] = Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device); + v_cache[i] = Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device); + } + } + + void reset() { + cache_pos = 0; + } + void set_pos(size_t pos) { cache_pos = pos; } + size_t get_pos() const { return cache_pos; } +}; + +struct Qwen2Model { + LlaisysQwen2Meta meta; + llaisysDeviceType_t device; + + // Weights + tensor_t in_embed; + tensor_t out_embed; + tensor_t out_norm_w; + std::vector attn_norm_w; + std::vector attn_q_w; + std::vector attn_q_b; + std::vector attn_k_w; + std::vector attn_k_b; + std::vector attn_v_w; + std::vector attn_v_b; + std::vector attn_o_w; + std::vector mlp_norm_w; + std::vector mlp_gate_w; + std::vector mlp_up_w; + std::vector mlp_down_w; + + // 向后兼容的默认会话(单用户接口通过此 session 操作) + std::unique_ptr default_session; + + Qwen2Model(const LlaisysQwen2Meta &m, llaisysDeviceType_t dev) + : meta(m), device(dev) { + // Create embedding and output tensors + in_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device); + out_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device); + out_norm_w = Tensor::create({meta.hs}, meta.dtype, device); + + // Initialize weight vectors + attn_norm_w.resize(meta.nlayer); + attn_q_w.resize(meta.nlayer); + attn_q_b.resize(meta.nlayer); + attn_k_w.resize(meta.nlayer); + attn_k_b.resize(meta.nlayer); + attn_v_w.resize(meta.nlayer); + attn_v_b.resize(meta.nlayer); + attn_o_w.resize(meta.nlayer); + mlp_norm_w.resize(meta.nlayer); + mlp_gate_w.resize(meta.nlayer); + mlp_up_w.resize(meta.nlayer); + mlp_down_w.resize(meta.nlayer); + + // Create per-layer weight tensors + for (size_t i = 0; i < meta.nlayer; i++) { + attn_norm_w[i] = Tensor::create({meta.hs}, meta.dtype, device); + attn_q_w[i] 
= Tensor::create({meta.nh * meta.dh, meta.hs}, meta.dtype, device); + attn_q_b[i] = Tensor::create({meta.nh * meta.dh}, meta.dtype, device); + attn_k_w[i] = Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device); + attn_k_b[i] = Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device); + attn_v_w[i] = Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device); + attn_v_b[i] = Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device); + attn_o_w[i] = Tensor::create({meta.hs, meta.hs}, meta.dtype, device); + mlp_norm_w[i] = Tensor::create({meta.hs}, meta.dtype, device); + mlp_gate_w[i] = Tensor::create({meta.di, meta.hs}, meta.dtype, device); + mlp_up_w[i] = Tensor::create({meta.di, meta.hs}, meta.dtype, device); + mlp_down_w[i] = Tensor::create({meta.hs, meta.di}, meta.dtype, device); + } + + // 创建默认会话(向后兼容单用户接口) + default_session = std::make_unique(m, dev); + } + + // 运行完整前向传播,更新 sess.cache_pos,返回设备端 logits [voc] + tensor_t run_forward(Qwen2Session &sess, int64_t *token_ids, size_t ntoken) { + size_t seq_len = ntoken; + + // Create position ids + auto pos_ids = Tensor::create({seq_len}, LLAISYS_DTYPE_I64, device); + std::vector pos_data(seq_len); + for (size_t i = 0; i < seq_len; i++) { + pos_data[i] = static_cast(sess.cache_pos) + static_cast(i); + } + pos_ids->load(pos_data.data()); + + // Embedding lookup + auto token_tensor = Tensor::create({seq_len}, LLAISYS_DTYPE_I64, device); + token_tensor->load(token_ids); + auto hidden = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::embedding(hidden, token_tensor, in_embed); + + // Process each layer + for (size_t layer = 0; layer < meta.nlayer; layer++) { + // Attention norm + auto normed = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::rms_norm(normed, hidden, attn_norm_w[layer], meta.epsilon); + + // Q, K, V projections + auto q = Tensor::create({seq_len, meta.nh * meta.dh}, meta.dtype, device); + auto k = Tensor::create({seq_len, meta.nkvh * meta.dh}, meta.dtype, 
device); + auto v = Tensor::create({seq_len, meta.nkvh * meta.dh}, meta.dtype, device); + + ops::linear(q, normed, attn_q_w[layer], attn_q_b[layer]); + ops::linear(k, normed, attn_k_w[layer], attn_k_b[layer]); + ops::linear(v, normed, attn_v_w[layer], attn_v_b[layer]); + + // Reshape to [seq_len, n_heads, head_dim] + auto q_shaped = q->view({seq_len, meta.nh, meta.dh}); + auto k_shaped = k->view({seq_len, meta.nkvh, meta.dh}); + auto v_shaped = v->view({seq_len, meta.nkvh, meta.dh}); + + // Apply RoPE + auto q_rope = Tensor::create({seq_len, meta.nh, meta.dh}, meta.dtype, device); + auto k_rope = Tensor::create({seq_len, meta.nkvh, meta.dh}, meta.dtype, device); + ops::rope(q_rope, q_shaped, pos_ids, meta.theta); + ops::rope(k_rope, k_shaped, pos_ids, meta.theta); + + // Update KV cache + for (size_t i = 0; i < seq_len; i++) { + auto k_slice = k_rope->slice(0, i, i + 1); + auto v_slice = v_shaped->slice(0, i, i + 1); + auto k_cache_slice = sess.k_cache[layer]->slice(0, sess.cache_pos + i, sess.cache_pos + i + 1); + auto v_cache_slice = sess.v_cache[layer]->slice(0, sess.cache_pos + i, sess.cache_pos + i + 1); + + // Flatten to 1D and perform D2D copy + auto k_flat = k_slice->view({meta.nkvh * meta.dh}); + auto v_flat = v_slice->view({meta.nkvh * meta.dh}); + auto kc_flat = k_cache_slice->view({meta.nkvh * meta.dh}); + auto vc_flat = v_cache_slice->view({meta.nkvh * meta.dh}); + + size_t copy_bytes = meta.nkvh * meta.dh * k_flat->elementSize(); + core::context().runtime().api()->memcpy_sync( + kc_flat->data(), k_flat->data(), copy_bytes, LLAISYS_MEMCPY_D2D); + core::context().runtime().api()->memcpy_sync( + vc_flat->data(), v_flat->data(), copy_bytes, LLAISYS_MEMCPY_D2D); + } + + // Get full KV from cache + size_t total_len = sess.cache_pos + seq_len; + auto k_full = sess.k_cache[layer]->slice(0, 0, total_len); + auto v_full = sess.v_cache[layer]->slice(0, 0, total_len); + + // Self attention + auto attn_out = Tensor::create({seq_len, meta.nh, meta.dh}, meta.dtype, 
device); + float scale = 1.0f / std::sqrt(static_cast(meta.dh)); + ops::self_attention(attn_out, q_rope, k_full, v_full, scale); + + // Output projection + auto attn_flat = attn_out->view({seq_len, meta.nh * meta.dh}); + auto attn_proj = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::linear(attn_proj, attn_flat, attn_o_w[layer], nullptr); + + // Residual connection + ops::add(hidden, hidden, attn_proj); + + // MLP norm + auto mlp_normed = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::rms_norm(mlp_normed, hidden, mlp_norm_w[layer], meta.epsilon); + + // MLP + auto gate = Tensor::create({seq_len, meta.di}, meta.dtype, device); + auto up = Tensor::create({seq_len, meta.di}, meta.dtype, device); + ops::linear(gate, mlp_normed, mlp_gate_w[layer], nullptr); + ops::linear(up, mlp_normed, mlp_up_w[layer], nullptr); + + auto mlp_out = Tensor::create({seq_len, meta.di}, meta.dtype, device); + ops::swiglu(mlp_out, gate, up); + + auto mlp_proj = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::linear(mlp_proj, mlp_out, mlp_down_w[layer], nullptr); + + // Residual connection + ops::add(hidden, hidden, mlp_proj); + } + + // Final norm + auto final_normed = Tensor::create({seq_len, meta.hs}, meta.dtype, device); + ops::rms_norm(final_normed, hidden, out_norm_w, meta.epsilon); + + // Get last token + auto last_hidden = final_normed->slice(0, seq_len - 1, seq_len); + auto last_flat = last_hidden->view({meta.hs}); + + // LM head + auto logits = Tensor::create({1, meta.voc}, meta.dtype, device); + ops::linear(logits, last_flat->view({1, meta.hs}), out_embed, nullptr); + + // 更新 cache 位置并返回设备端 logits + sess.cache_pos += seq_len; + + return logits->view({meta.voc}); + } + + // ── Argmax 贪心解码(原行为)────────────────────────────────────────────── + int64_t infer(Qwen2Session &sess, int64_t *token_ids, size_t ntoken) { + auto logits_flat = run_forward(sess, token_ids, ntoken); + auto max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, device); + 
auto max_val = Tensor::create({1}, LLAISYS_DTYPE_F32, device); + ops::argmax(max_idx, max_val, logits_flat); + std::vector result_vec(1); + core::context().runtime().api()->memcpy_sync( + reinterpret_cast(result_vec.data()), + max_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H); + return result_vec[0]; + } + + // ── 设备端 logits (BF16/F16/F32) → CPU float32 ─────────────────────────── + void logits_to_f32_cpu(const tensor_t &logits_dev, std::vector &out) { + out.resize(meta.voc); + if (meta.dtype == LLAISYS_DTYPE_F32) { + core::context().runtime().api()->memcpy_sync( + reinterpret_cast(out.data()), + logits_dev->data(), meta.voc * sizeof(float), LLAISYS_MEMCPY_D2H); + } else { + std::vector raw(meta.voc); + core::context().runtime().api()->memcpy_sync( + reinterpret_cast(raw.data()), + logits_dev->data(), meta.voc * sizeof(uint16_t), LLAISYS_MEMCPY_D2H); + if (meta.dtype == LLAISYS_DTYPE_BF16) { + for (size_t i = 0; i < meta.voc; i++) { + out[i] = bf16_to_f32(raw[i]); + } + } else { + for (size_t i = 0; i < meta.voc; i++) { + out[i] = f16_to_f32(raw[i]); + } + } + } + } + + // ── Temperature / Top-K / Top-P 采样(CPU 端)─────────────────────────── + int64_t sample_token(std::mt19937 &rng, const std::vector &logits, + float temperature, int top_k, float top_p) { + size_t voc = logits.size(); + // 贪心 argmax + if (temperature <= 0.0f || top_k == 1) { + return static_cast( + std::max_element(logits.begin(), logits.end()) - logits.begin()); + } + // 带温度的 Softmax + float max_l = *std::max_element(logits.begin(), logits.end()); + std::vector probs(voc); + float sum = 0.0f; + for (size_t i = 0; i < voc; i++) { + probs[i] = std::exp((logits[i] - max_l) / temperature); + sum += probs[i]; + } + for (auto &p : probs) { + p /= sum; + } + // Top-K 截断 + int k = (top_k > 0 && top_k < static_cast(voc)) ? 
top_k : static_cast(voc); + std::vector idx(voc); + std::iota(idx.begin(), idx.end(), 0u); + std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), + [&](size_t a, size_t b) { return probs[a] > probs[b]; }); + // Top-P Nucleus 截断 + size_t cutoff = static_cast(k); + if (top_p > 0.0f && top_p < 1.0f) { + float cum = 0.0f; + for (int i = 0; i < k; i++) { + cum += probs[idx[i]]; + if (cum >= top_p) { + cutoff = static_cast(i) + 1; + break; + } + } + } + // 重归一化后采样 + sum = 0.0f; + for (size_t i = 0; i < cutoff; i++) { + sum += probs[idx[i]]; + } + std::uniform_real_distribution uni(0.0f, sum); + float r = uni(rng); + float cum = 0.0f; + for (size_t i = 0; i < cutoff; i++) { + cum += probs[idx[i]]; + if (r < cum) { + return static_cast(idx[i]); + } + } + return static_cast(idx[cutoff - 1]); + } + + // ── 采样解码 ───────────────────────────────────────────────────────────── + int64_t infer_sample(Qwen2Session &sess, int64_t *token_ids, size_t ntoken, + float temperature, int top_k, float top_p) { + auto logits_flat = run_forward(sess, token_ids, ntoken); + std::vector logits_cpu; + logits_to_f32_cpu(logits_flat, logits_cpu); + return sample_token(sess.rng, logits_cpu, temperature, top_k, top_p); + } +}; + +} // namespace llaisys + +extern "C" { + +struct LlaisysQwen2Model *llaisysQwen2ModelCreate( + const LlaisysQwen2Meta *meta, + llaisysDeviceType_t device, + int *device_ids, + int ndevice) { + auto model = new llaisys::Qwen2Model(*meta, device); + return reinterpret_cast(model); +} + +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + delete reinterpret_cast(model); +} + +struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model_) { + auto model = reinterpret_cast(model_); + auto weights = new LlaisysQwen2Weights(); + + weights->in_embed = new LlaisysTensor{model->in_embed}; + weights->out_embed = new LlaisysTensor{model->out_embed}; + weights->out_norm_w = new LlaisysTensor{model->out_norm_w}; + + weights->attn_norm_w = new 
llaisysTensor_t[model->meta.nlayer]; + weights->attn_q_w = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_q_b = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_k_w = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_k_b = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_v_w = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_v_b = new llaisysTensor_t[model->meta.nlayer]; + weights->attn_o_w = new llaisysTensor_t[model->meta.nlayer]; + weights->mlp_norm_w = new llaisysTensor_t[model->meta.nlayer]; + weights->mlp_gate_w = new llaisysTensor_t[model->meta.nlayer]; + weights->mlp_up_w = new llaisysTensor_t[model->meta.nlayer]; + weights->mlp_down_w = new llaisysTensor_t[model->meta.nlayer]; + + for (size_t i = 0; i < model->meta.nlayer; i++) { + weights->attn_norm_w[i] = new LlaisysTensor{model->attn_norm_w[i]}; + weights->attn_q_w[i] = new LlaisysTensor{model->attn_q_w[i]}; + weights->attn_q_b[i] = new LlaisysTensor{model->attn_q_b[i]}; + weights->attn_k_w[i] = new LlaisysTensor{model->attn_k_w[i]}; + weights->attn_k_b[i] = new LlaisysTensor{model->attn_k_b[i]}; + weights->attn_v_w[i] = new LlaisysTensor{model->attn_v_w[i]}; + weights->attn_v_b[i] = new LlaisysTensor{model->attn_v_b[i]}; + weights->attn_o_w[i] = new LlaisysTensor{model->attn_o_w[i]}; + weights->mlp_norm_w[i] = new LlaisysTensor{model->mlp_norm_w[i]}; + weights->mlp_gate_w[i] = new LlaisysTensor{model->mlp_gate_w[i]}; + weights->mlp_up_w[i] = new LlaisysTensor{model->mlp_up_w[i]}; + weights->mlp_down_w[i] = new LlaisysTensor{model->mlp_down_w[i]}; + } + + return weights; +} + +// ─── 向后兼容接口:路由到模型内置的 default_session ────────────────────────── + +int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model_, int64_t *token_ids, size_t ntoken) { + auto model = reinterpret_cast(model_); + return model->infer(*model->default_session, token_ids, ntoken); +} + +int64_t llaisysQwen2ModelInferSample(struct LlaisysQwen2Model *model_, + int64_t *token_ids, size_t ntoken, 
+ float temperature, int top_k, float top_p) { + auto model = reinterpret_cast(model_); + return model->infer_sample(*model->default_session, token_ids, ntoken, + temperature, top_k, top_p); +} + +void llaisysQwen2ModelSetCachePos(struct LlaisysQwen2Model *model_, size_t pos) { + reinterpret_cast(model_)->default_session->set_pos(pos); +} + +size_t llaisysQwen2ModelGetCachePos(struct LlaisysQwen2Model *model_) { + return reinterpret_cast(model_)->default_session->get_pos(); +} + +void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model *model_) { + reinterpret_cast(model_)->default_session->reset(); +} + +// ─── 多用户 Session API ─────────────────────────────────────────────────────── + +struct LlaisysQwen2Session *llaisysQwen2SessionCreate(struct LlaisysQwen2Model *model_) { + auto model = reinterpret_cast(model_); + auto sess = new llaisys::Qwen2Session(model->meta, model->device); + return reinterpret_cast(sess); +} + +void llaisysQwen2SessionDestroy(struct LlaisysQwen2Session *sess_) { + delete reinterpret_cast(sess_); +} + +int64_t llaisysQwen2SessionInfer(struct LlaisysQwen2Model *model_, + struct LlaisysQwen2Session *sess_, + int64_t *token_ids, size_t ntoken) { + auto model = reinterpret_cast(model_); + auto sess = reinterpret_cast(sess_); + return model->infer(*sess, token_ids, ntoken); +} + +int64_t llaisysQwen2SessionInferSample(struct LlaisysQwen2Model *model_, + struct LlaisysQwen2Session *sess_, + int64_t *token_ids, size_t ntoken, + float temperature, int top_k, float top_p) { + auto model = reinterpret_cast(model_); + auto sess = reinterpret_cast(sess_); + return model->infer_sample(*sess, token_ids, ntoken, temperature, top_k, top_p); +} + +void llaisysQwen2SessionSetCachePos(struct LlaisysQwen2Session *sess_, size_t pos) { + reinterpret_cast(sess_)->set_pos(pos); +} + +size_t llaisysQwen2SessionGetCachePos(struct LlaisysQwen2Session *sess_) { + return reinterpret_cast(sess_)->get_pos(); +} + +void llaisysQwen2SessionResetCache(struct 
LlaisysQwen2Session *sess_) { + reinterpret_cast(sess_)->reset(); +} + +} // extern "C" diff --git a/src/llaisys/runtime.cc b/src/llaisys/runtime.cc index 7b00ff1bb..f441114fe 100644 --- a/src/llaisys/runtime.cc +++ b/src/llaisys/runtime.cc @@ -10,4 +10,20 @@ __C void llaisysSetContextRuntime(llaisysDeviceType_t device_type, int device_id // Llaisys API for getting the runtime APIs __C const LlaisysRuntimeAPI *llaisysGetRuntimeAPI(llaisysDeviceType_t device_type) { return llaisys::device::getRuntimeAPI(device_type); +} + +// Returns 1 if the library was compiled with support for the given device type, 0 otherwise. +__C int llaisysIsDeviceSupported(llaisysDeviceType_t device_type) { + switch (device_type) { + case LLAISYS_DEVICE_CPU: + return 1; // CPU is always supported + case LLAISYS_DEVICE_NVIDIA: +#ifdef ENABLE_NVIDIA_API + return 1; +#else + return 0; +#endif + default: + return 0; + } } \ No newline at end of file diff --git a/src/ops/add/nvidia/add_cuda.cu b/src/ops/add/nvidia/add_cuda.cu new file mode 100644 index 000000000..b60206fef --- /dev/null +++ b/src/ops/add/nvidia/add_cuda.cu @@ -0,0 +1,121 @@ +#include "add_cuda.cuh" + +#include +#include +#include + +// ───────────────────────────────────────────────────────────────────────────── +// F32:使用 float4 宽加载(A100 128-bit 访存对齐) +// ───────────────────────────────────────────────────────────────────────────── +__global__ void add_f32_kernel(float *__restrict__ c, + const float *__restrict__ a, + const float *__restrict__ b, + size_t n4, size_t n) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < n4) { + float4 va = reinterpret_cast(a)[idx]; + float4 vb = reinterpret_cast(b)[idx]; + float4 vc; + vc.x = va.x + vb.x; + vc.y = va.y + vb.y; + vc.z = va.z + vb.z; + vc.w = va.w + vb.w; + reinterpret_cast(c)[idx] = vc; + } + // 处理尾部未对齐元素 + size_t tail_start = n4 * 4; + if (idx == 0) { + for (size_t i = tail_start; i < n; i++) { + c[i] = a[i] + b[i]; + } + } +} + +// 
─────────────────────────────────────────────────────────────────────────────
+// F16: process two fp16 elements at a time using __half2
+// ─────────────────────────────────────────────────────────────────────────────
+__global__ void add_f16_kernel(__half *__restrict__ c,
+                               const __half *__restrict__ a,
+                               const __half *__restrict__ b,
+                               size_t n2, size_t n) {
+    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+    if (idx < n2) {
+        __half2 va = reinterpret_cast<const __half2 *>(a)[idx];
+        __half2 vb = reinterpret_cast<const __half2 *>(b)[idx];
+        reinterpret_cast<__half2 *>(c)[idx] = __hadd2(va, vb);
+    }
+    // Odd element count: thread 0 adds the final scalar element.
+    if (idx == 0 && n % 2 != 0) {
+        c[n - 1] = __hadd(a[n - 1], b[n - 1]);
+    }
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// BF16: process two bf16 elements at a time using __nv_bfloat162
+// ─────────────────────────────────────────────────────────────────────────────
+__global__ void add_bf16_kernel(__nv_bfloat16 *__restrict__ c,
+                                const __nv_bfloat16 *__restrict__ a,
+                                const __nv_bfloat16 *__restrict__ b,
+                                size_t n2, size_t n) {
+    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+    if (idx < n2) {
+        __nv_bfloat162 va = reinterpret_cast<const __nv_bfloat162 *>(a)[idx];
+        __nv_bfloat162 vb = reinterpret_cast<const __nv_bfloat162 *>(b)[idx];
+        reinterpret_cast<__nv_bfloat162 *>(c)[idx] = __hadd2(va, vb);
+    }
+    // Odd element count: thread 0 adds the final scalar element.
+    if (idx == 0 && n % 2 != 0) {
+        c[n - 1] = __hadd(a[n - 1], b[n - 1]);
+    }
+}
+
+namespace llaisys::ops::nvidia {
+
+// Element-wise c = a + b on the current CUDA device.
+// c/a/b point to contiguous device buffers of `numel` elements of `type`.
+// Vectorized loads (float4 / __half2 / __nv_bfloat162) improve bandwidth use.
+void add(std::byte *c, const std::byte *a, const std::byte *b,
+         llaisysDataType_t type, size_t numel) {
+    constexpr int BLOCK = 256;
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32: {
+        size_t n4 = numel / 4;
+        int grid = static_cast<int>((n4 + BLOCK - 1) / BLOCK);
+        if (grid == 0) {
+            grid = 1; // numel < 4: still need one block for the tail elements
+        }
+        add_f32_kernel<<<grid, BLOCK>>>(
+            reinterpret_cast<float *>(c),
+            reinterpret_cast<const float *>(a),
+            reinterpret_cast<const float *>(b),
+            n4, numel);
+        break;
+    }
+    case LLAISYS_DTYPE_F16: {
+        size_t n2 = numel / 2;
+        int grid = static_cast<int>((n2 + BLOCK - 1) / BLOCK);
+        if (grid == 0) {
+            grid = 1;
+        }
+        add_f16_kernel<<<grid, BLOCK>>>(
+            reinterpret_cast<__half *>(c),
+            reinterpret_cast<const __half *>(a),
+            reinterpret_cast<const __half *>(b),
+            n2, numel);
+        break;
+    }
+    case LLAISYS_DTYPE_BF16: {
+        size_t n2 = numel / 2;
+        int grid = static_cast<int>((n2 + BLOCK - 1) / BLOCK);
+        if (grid == 0) {
+            grid = 1;
+        }
+        add_bf16_kernel<<<grid, BLOCK>>>(
+            reinterpret_cast<__nv_bfloat16 *>(c),
+            reinterpret_cast<const __nv_bfloat16 *>(a),
+            reinterpret_cast<const __nv_bfloat16 *>(b),
+            n2, numel);
+        break;
+    }
+    default:
+        throw std::runtime_error("add CUDA: unsupported data type");
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/add/nvidia/add_cuda.cuh b/src/ops/add/nvidia/add_cuda.cuh
new file mode 100644
index 000000000..720062cd4
--- /dev/null
+++ b/src/ops/add/nvidia/add_cuda.cuh
@@ -0,0 +1,11 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+// Vectorized add operator supporting f32 / f16 / bf16.
+// Uses float4 / __half2 / __nv_bfloat162 wide loads for bandwidth on A100.
+void add(std::byte *c, const std::byte *a, const std::byte *b,
+         llaisysDataType_t type, size_t numel);
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp
index a057330d7..6e40b70bc 100644
--- a/src/ops/add/op.cpp
+++ b/src/ops/add/op.cpp
@@ -5,6 +5,10 @@
 
 #include "cpu/add_cpu.hpp"
 
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/add_cuda.cuh"
+#endif
+
 namespace llaisys::ops {
 void add(tensor_t c, tensor_t a, tensor_t b) {
     CHECK_SAME_DEVICE(c, a, b);
@@ -25,8 +29,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) {
         return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
 #ifdef ENABLE_NVIDIA_API
     case LLAISYS_DEVICE_NVIDIA:
-        TO_BE_IMPLEMENTED();
-        return;
+        return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
 #endif
     default:
         EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 000000000..e604b7b50
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,48 @@
+#include "argmax_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <limits>
+#include <type_traits>
+
+// Scan `vals` and record the index and value of the maximum element.
+// Comparison is done in float regardless of T; ties keep the first index.
+// NOTE(review): half-precision element types assumed to be
+// llaisys::bf16_t / llaisys::fp16_t (angle-bracket args lost in extraction) — confirm.
+template <typename T>
+void argmax_(int64_t *max_idx, float *max_val, const T *vals, size_t numel) {
+    size_t idx = 0;
+    float max_value = -std::numeric_limits<float>::infinity();
+
+    for (size_t i = 0; i < numel; i++) {
+        float val;
+        if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+            val = llaisys::utils::cast<float>(vals[i]);
+        } else {
+            val = static_cast<float>(vals[i]);
+        }
+
+        if (val > max_value) {
+            max_value = val;
+            idx = i;
+        }
+    }
+
+    // max_val is always F32
+    max_idx[0] = static_cast<int64_t>(idx);
+    max_val[0] = max_value;
+}
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+    int64_t *idx_ptr = reinterpret_cast<int64_t *>(max_idx);
+    float *val_ptr = reinterpret_cast<float *>(max_val); // always F32
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return argmax_(idx_ptr, val_ptr, reinterpret_cast<const float *>(vals), numel);
+    case LLAISYS_DTYPE_BF16:
+        return argmax_(idx_ptr, val_ptr, reinterpret_cast<const llaisys::bf16_t *>(vals), numel);
+    case LLAISYS_DTYPE_F16:
+        return argmax_(idx_ptr, val_ptr, reinterpret_cast<const llaisys::fp16_t *>(vals), numel);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 000000000..26ae3ef03
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/nvidia/argmax_cuda.cu b/src/ops/argmax/nvidia/argmax_cuda.cu
new file mode 100644
index 000000000..ed2d0a5f3
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_cuda.cu
@@ -0,0 +1,155 @@
+#include "argmax_cuda.cuh"
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <stdexcept>
+#include <type_traits>
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Warp-level (val, idx) max reduction
// ─────────────────────────────────────────────────────────────────────────────
+__device__ __forceinline__ void warpReduceMax(float &val, int &idx) {
+    constexpr unsigned FULL_MASK = 0xffffffff;
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        float other_val = __shfl_down_sync(FULL_MASK, val, offset);
+        int other_idx = __shfl_down_sync(FULL_MASK, idx, offset);
+        if (other_val > val) {
+            val = other_val;
+            idx = other_idx;
+        }
+    }
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Each block finds the max value/index of its stripe and writes it out; a
+// second single-thread kernel (launched with one block) reduces the per-block
+// partial results into the global argmax. For typical inference vocab sizes
+// (~32k–128k) a single pass is enough.
+// ─────────────────────────────────────────────────────────────────────────────
+template <typename T>
+__global__ void argmax_kernel(const T *__restrict__ vals,
+                              size_t n,
+                              float *__restrict__ blk_max_val,
+                              int *__restrict__ blk_max_idx) {
+    // Dynamic shared memory: num_warps floats followed by num_warps ints.
+    extern __shared__ char smem_raw[];
+    float *smem_val = reinterpret_cast<float *>(smem_raw);
+    int *smem_idx = reinterpret_cast<int *>(smem_val + (blockDim.x / 32));
+
+    float local_val = -3.402823466e+38f; // -FLT_MAX
+    int local_idx = 0;
+
+    // Each thread scans its strided range for a local maximum.
+    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+         i < n;
+         i += static_cast<size_t>(gridDim.x) * blockDim.x) {
+        float v;
+        if constexpr (std::is_same_v<T, __half>) {
+            v = __half2float(vals[i]);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            v = __bfloat162float(vals[i]);
+        } else {
+            v = static_cast<float>(vals[i]);
+        }
+        if (v > local_val) {
+            local_val = v;
+            local_idx = static_cast<int>(i);
+        }
+    }
+
+    // Warp reduce
+    warpReduceMax(local_val, local_idx);
+
+    int warp_id = threadIdx.x / 32;
+    int lane_id = threadIdx.x % 32;
+    int num_warps = (blockDim.x + 31) / 32;
+
+    if (lane_id == 0) {
+        smem_val[warp_id] = local_val;
+        smem_idx[warp_id] = local_idx;
+    }
+    __syncthreads();
+
+    // Block reduce (the first warp reduces across warps).
+    float bval = -3.402823466e+38f;
+    int bidx = 0;
+    if (threadIdx.x < static_cast<unsigned>(num_warps)) {
+        bval = smem_val[threadIdx.x];
+        bidx = smem_idx[threadIdx.x];
+    }
+    if (threadIdx.x < 32) {
+        warpReduceMax(bval, bidx);
+    }
+    if (threadIdx.x == 0) {
+        blk_max_val[blockIdx.x] = bval;
+        blk_max_idx[blockIdx.x] = bidx;
+    }
+}
+
+// Final global reduction over the per-block partial results (single thread).
+__global__ void argmax_final_kernel(const float *__restrict__ blk_val,
+                                    const int *__restrict__ blk_idx,
+                                    int num_blocks,
+                                    int64_t *__restrict__ out_idx,
+                                    float *__restrict__ out_val) {
+    float best_val = -3.402823466e+38f;
+    int best_idx = 0;
+    for (int i = 0; i < num_blocks; i++) {
+        if (blk_val[i] > best_val) {
+            best_val = blk_val[i];
+            best_idx = blk_idx[i];
+        }
+    }
+    out_idx[0] = static_cast<int64_t>(best_idx);
+    out_val[0] = best_val;
+}
+
+namespace llaisys::ops::nvidia {
+
+// Device argmax over `numel` elements; writes an int64 index and an f32 value.
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
+            llaisysDataType_t type, size_t numel) {
+    constexpr int BLOCK = 256;
+    int num_blocks = static_cast<int>((numel + BLOCK - 1) / BLOCK);
+    // Cap the grid size to bound the temporary-buffer allocation.
+    if (num_blocks > 512) {
+        num_blocks = 512;
+    }
+
+    int num_warps = (BLOCK + 31) / 32;
+    size_t smem = static_cast<size_t>(num_warps) * (sizeof(float) + sizeof(int));
+
+    // Temporary device buffers for per-block partial results.
+    float *d_blk_val = nullptr;
+    int *d_blk_idx = nullptr;
+    cudaMalloc(&d_blk_val, num_blocks * sizeof(float));
+    cudaMalloc(&d_blk_idx, num_blocks * sizeof(int));
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        argmax_kernel<float><<<num_blocks, BLOCK, smem>>>(
+            reinterpret_cast<const float *>(vals), numel, d_blk_val, d_blk_idx);
+        break;
+    case LLAISYS_DTYPE_F16:
+        argmax_kernel<__half><<<num_blocks, BLOCK, smem>>>(
+            reinterpret_cast<const __half *>(vals), numel, d_blk_val, d_blk_idx);
+        break;
+    case LLAISYS_DTYPE_BF16:
+        argmax_kernel<__nv_bfloat16><<<num_blocks, BLOCK, smem>>>(
+            reinterpret_cast<const __nv_bfloat16 *>(vals), numel, d_blk_val, d_blk_idx);
+        break;
+    default:
+        cudaFree(d_blk_val);
+        cudaFree(d_blk_idx);
+        throw std::runtime_error("argmax CUDA: unsupported data type");
+    }
+
+    // Single-thread final reduction (num_blocks is at most 512).
+    argmax_final_kernel<<<1, 1>>>(
+        d_blk_val, d_blk_idx, num_blocks,
+        reinterpret_cast<int64_t *>(max_idx),
+        reinterpret_cast<float *>(max_val));
+
+    cudaFree(d_blk_val);
+    cudaFree(d_blk_idx);
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/argmax/nvidia/argmax_cuda.cuh b/src/ops/argmax/nvidia/argmax_cuda.cuh
new file mode 100644
index 000000000..59856c030
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_cuda.cuh
@@ -0,0 +1,11 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+// Argmax CUDA implementation.
+// Two-stage parallel reduction: per-block warp reduce, then a single-thread
+// kernel reduces the per-block partial results.
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
+            llaisysDataType_t type, size_t numel);
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d426..25b70c320 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,48 @@
 #include "op.hpp"
 
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/argmax_cuda.cuh"
+#endif
+
 namespace llaisys::ops {
 void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
-    TO_BE_IMPLEMENTED();
+    // Validate devices, shapes and dtypes, then dispatch by device type.
+    CHECK_SAME_DEVICE(max_idx, max_val, vals);
+
+    // Both outputs are scalar tensors.
+    CHECK_ARGUMENT(max_idx->numel() == 1, "argmax: max_idx must be a single element tensor");
+    CHECK_ARGUMENT(max_val->numel() == 1, "argmax: max_val must be a single element tensor");
+
+    // max_val is always F32 (the CUDA/CPU kernel stores a float result)
+    CHECK_ARGUMENT(max_val->dtype() == LLAISYS_DTYPE_F32,
+                   "argmax: max_val must be of type F32");
+
+    // The index output is int64.
+    CHECK_ARGUMENT(max_idx->dtype() == LLAISYS_DTYPE_I64, "argmax: max_idx must be of type int64");
+
+    // All tensors must be contiguous.
+    ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(),
+           "argmax: all tensors must be contiguous.");
+
+    // Bind the device context before launching the kernel.
+    llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+
+    switch (vals->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(),
+                              vals->dtype(), vals->numel());
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 000000000..4776c4f09
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,44 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstring>
+
+// For each index, copy the corresponding row of `weight` into `out`.
+// out: [num_indices, embedding_dim], weight: [vocab, embedding_dim].
+template <typename T>
+void embedding_(T *out, const int64_t *index, const T *weight, size_t num_indices, size_t embedding_dim) {
+    for (size_t i = 0; i < num_indices; i++) {
+        int64_t idx = index[i];
+        const T *src = weight + idx * embedding_dim;
+        T *dst = out + i * embedding_dim;
+
+        // Rows are contiguous for every supported dtype, so one raw copy
+        // suffices (the original if-constexpr had two identical branches).
+        std::memcpy(dst, src, embedding_dim * sizeof(T));
+    }
+}
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+               llaisysDataType_t type, size_t num_indices, size_t embedding_dim) {
+    const int64_t *idx_ptr = reinterpret_cast<const int64_t *>(index);
+
+    // NOTE(review): half-precision element types assumed to be
+    // llaisys::bf16_t / llaisys::fp16_t (template args lost in extraction).
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return embedding_(reinterpret_cast<float *>(out), idx_ptr,
+                          reinterpret_cast<const float *>(weight), num_indices, embedding_dim);
+    case LLAISYS_DTYPE_BF16:
+        return embedding_(reinterpret_cast<llaisys::bf16_t *>(out), idx_ptr,
+                          reinterpret_cast<const llaisys::bf16_t *>(weight), num_indices, embedding_dim);
+    case LLAISYS_DTYPE_F16:
+        return embedding_(reinterpret_cast<llaisys::fp16_t *>(out), idx_ptr,
+                          reinterpret_cast<const llaisys::fp16_t *>(weight), num_indices, embedding_dim);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp
b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 000000000..8ee157565
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+               llaisysDataType_t type, size_t num_indices, size_t embedding_dim);
+}
diff --git a/src/ops/embedding/nvidia/embedding_cuda.cu b/src/ops/embedding/nvidia/embedding_cuda.cu
new file mode 100644
index 000000000..6a0776c5e
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_cuda.cu
@@ -0,0 +1,102 @@
+#include "embedding_cuda.cuh"
+
+#include <stdexcept>
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Embedding lookup kernel
+// gridDim.x = num_indices, blockDim.x = min(embedding_dim, 1024)
+// Each block copies weight[index[i], :] into out[i, :].
+// ─────────────────────────────────────────────────────────────────────────────
+template <typename T>
+__global__ void embedding_kernel(T *__restrict__ out,
+                                 const int64_t *__restrict__ index,
+                                 const T *__restrict__ weight,
+                                 size_t embedding_dim) {
+    size_t row = blockIdx.x;
+    int64_t idx = index[row];
+    const T *src = weight + idx * embedding_dim;
+    T *dst = out + row * embedding_dim;
+
+    for (size_t i = threadIdx.x; i < embedding_dim; i += blockDim.x) {
+        dst[i] = src[i];
+    }
+}
+
+// float4 specialization (used when embedding_dim is a multiple of 4):
+// 16-byte loads per thread improve DRAM bandwidth utilization.
+__global__ void embedding_f32_kernel(float *__restrict__ out,
+                                     const int64_t *__restrict__ index,
+                                     const float *__restrict__ weight,
+                                     size_t dim4, size_t embedding_dim) {
+    size_t row = blockIdx.x;
+    int64_t idx = index[row];
+    const float4 *src4 = reinterpret_cast<const float4 *>(weight + idx * embedding_dim);
+    float4 *dst4 = reinterpret_cast<float4 *>(out + row * embedding_dim);
+
+    for (size_t i = threadIdx.x; i < dim4; i += blockDim.x) {
+        dst4[i] = src4[i];
+    }
+    // Tail handling (no-op at the current call site, which requires dim % 4 == 0).
+    const float *src = weight + idx * embedding_dim;
+    float *dst = out + row * embedding_dim;
+    size_t tail_start = dim4 * 4;
+    for (size_t i = tail_start + threadIdx.x; i < embedding_dim; i += blockDim.x) {
+        dst[i] = src[i];
+    }
+}
+
+namespace llaisys::ops::nvidia {
+
+// Device embedding lookup: out[i, :] = weight[index[i], :].
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+               llaisysDataType_t type, size_t num_indices, size_t embedding_dim) {
+    if (num_indices == 0 || embedding_dim == 0) {
+        return;
+    }
+
+    const int64_t *idx_ptr = reinterpret_cast<const int64_t *>(index);
+    const int BLOCK = static_cast<int>(embedding_dim < 1024 ? embedding_dim : 1024);
+    const int GRID = static_cast<int>(num_indices);
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32: {
+        size_t dim4 = embedding_dim / 4;
+        if (dim4 > 0 && embedding_dim % 4 == 0) {
+            int blk = static_cast<int>(dim4 < 256 ? dim4 : 256);
+            embedding_f32_kernel<<<GRID, blk>>>(
+                reinterpret_cast<float *>(out),
+                idx_ptr,
+                reinterpret_cast<const float *>(weight),
+                dim4, embedding_dim);
+        } else {
+            embedding_kernel<float><<<GRID, BLOCK>>>(
+                reinterpret_cast<float *>(out),
+                idx_ptr,
+                reinterpret_cast<const float *>(weight),
+                embedding_dim);
+        }
+        break;
+    }
+    case LLAISYS_DTYPE_F16: {
+        // __half matches the stored 16-bit element width, safe to reinterpret.
+        embedding_kernel<__half><<<GRID, BLOCK>>>(
+            reinterpret_cast<__half *>(out),
+            idx_ptr,
+            reinterpret_cast<const __half *>(weight),
+            embedding_dim);
+        break;
+    }
+    case LLAISYS_DTYPE_BF16: {
+        embedding_kernel<__nv_bfloat16><<<GRID, BLOCK>>>(
+            reinterpret_cast<__nv_bfloat16 *>(out),
+            idx_ptr,
+            reinterpret_cast<const __nv_bfloat16 *>(weight),
+            embedding_dim);
+        break;
+    }
+    default:
+        throw std::runtime_error("embedding CUDA: unsupported data type");
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/embedding/nvidia/embedding_cuda.cuh b/src/ops/embedding/nvidia/embedding_cuda.cuh
new file mode 100644
index 000000000..0c14a62a1
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_cuda.cuh
@@ -0,0 +1,10 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+// GPU embedding lookup: one block per row, coalesced copies for bandwidth.
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+               llaisysDataType_t type, size_t num_indices, size_t embedding_dim);
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp
index 84b9a5d06..40813d637 100644
--- a/src/ops/embedding/op.cpp
+++ b/src/ops/embedding/op.cpp
@@ -1,7 +1,57 @@
 #include "op.hpp"
 
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/embedding_cpu.hpp"
+
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/embedding_cuda.cuh"
+#endif
+
 namespace llaisys::ops {
 void embedding(tensor_t out, tensor_t index, tensor_t weight) {
-    TO_BE_IMPLEMENTED();
+    // Validate devices, dtypes and shapes, then dispatch by device type.
+    CHECK_SAME_DEVICE(out, index, weight);
+
+    // Indices are int64.
+    CHECK_ARGUMENT(index->dtype() == LLAISYS_DTYPE_I64, "embedding: index must be of type int64");
+
+    // Output element type matches the weight table.
+    CHECK_SAME_DTYPE(out->dtype(), weight->dtype());
+
+    // Expected ranks: index [n], weight [vocab, dim], out [n, dim].
+    CHECK_ARGUMENT(index->ndim() == 1, "embedding: index must be 1D tensor");
+    CHECK_ARGUMENT(weight->ndim() == 2, "embedding: weight must be 2D tensor");
+    CHECK_ARGUMENT(out->ndim() == 2, "embedding: out must be 2D tensor");
+
+    size_t num_indices = index->shape()[0];
+    size_t embedding_dim = weight->shape()[1];
+
+    CHECK_ARGUMENT(out->shape()[0] == num_indices,
+                   "embedding: out's first dimension must match index length");
+    CHECK_ARGUMENT(out->shape()[1] == embedding_dim,
+                   "embedding: out's second dimension must match weight's embedding dimension");
+
+    // All tensors must be contiguous.
+    ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(),
+           "embedding: all tensors must be contiguous.");
+
+    // Bind the device context before launching.
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::embedding(out->data(), index->data(), weight->data(),
+                              out->dtype(), num_indices, embedding_dim);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::embedding(out->data(), index->data(), weight->data(),
+                                 out->dtype(), num_indices, embedding_dim);
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp
new file mode 100644
index 000000000..dc7d9754d
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.cpp
@@ -0,0 +1,78 @@
+#include "linear_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <type_traits>
+
+// Y = X * W^T + b
+// X: [batch_size, in_features]
+// W: [out_features, in_features] (transposed in the dot product)
+// b: [out_features] (optional)
+// Y: [batch_size, out_features]
+// Accumulation is done in float for every element type.
+template <typename T>
+void linear_(T *out, const T *in, const T *weight, const T *bias,
+             size_t batch_size, size_t in_features, size_t out_features) {
+    for (size_t b = 0; b < batch_size; b++) {
+        for (size_t o = 0; o < out_features; o++) {
+            float sum = 0.0f;
+
+            // Dot product: X[b, :] * W[o, :]^T
+            for (size_t i = 0; i < in_features; i++) {
+                float x_val, w_val;
+
+                if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                    x_val = llaisys::utils::cast<float>(in[b * in_features + i]);
+                    w_val = llaisys::utils::cast<float>(weight[o * in_features + i]);
+                } else {
+                    x_val = static_cast<float>(in[b * in_features + i]);
+                    w_val = static_cast<float>(weight[o * in_features + i]);
+                }
+
+                sum += x_val * w_val;
+            }
+
+            // Add the bias if provided.
+            if (bias != nullptr) {
+                if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                    sum += llaisys::utils::cast<float>(bias[o]);
+                } else {
+                    sum += static_cast<float>(bias[o]);
+                }
+            }
+
+            // Store the result, converting back to T.
+            if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                out[b * out_features + o] = llaisys::utils::cast<T>(sum);
+            } else {
+                out[b * out_features + o] = static_cast<T>(sum);
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+            llaisysDataType_t type, size_t batch_size, size_t in_features, size_t out_features) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return linear_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in),
+                       reinterpret_cast<const float *>(weight),
+                       bias ? reinterpret_cast<const float *>(bias) : nullptr,
+                       batch_size, in_features, out_features);
+    case LLAISYS_DTYPE_BF16:
+        return linear_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in),
+                       reinterpret_cast<const llaisys::bf16_t *>(weight),
+                       bias ? reinterpret_cast<const llaisys::bf16_t *>(bias) : nullptr,
+                       batch_size, in_features, out_features);
+    case LLAISYS_DTYPE_F16:
+        return linear_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in),
+                       reinterpret_cast<const llaisys::fp16_t *>(weight),
+                       bias ? reinterpret_cast<const llaisys::fp16_t *>(bias) : nullptr,
+                       batch_size, in_features, out_features);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 000000000..2d2a0ab34
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+            llaisysDataType_t type, size_t batch_size, size_t in_features, size_t out_features);
+}
diff --git a/src/ops/linear/nvidia/linear_cuda.cu b/src/ops/linear/nvidia/linear_cuda.cu
new file mode 100644
index 000000000..ef360b6cf
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_cuda.cu
@@ -0,0 +1,154 @@
+#include "linear_cuda.cuh"
+
+#include "../../../device/nvidia/nvidia_resource.cuh"
+
+#include <cublas_v2.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <sstream>
+#include <stdexcept>
+
+#define CUBLAS_CHECK(call)                                         \
+    do {                                                           \
+        cublasStatus_t _st = (call);                               \
+        if (_st != CUBLAS_STATUS_SUCCESS) {                        \
+            std::ostringstream _oss;                               \
+            _oss << "cuBLAS error " << static_cast<int>(_st)       \
+                 << " at " << __FILE__ << ":" << __LINE__;         \
+            throw std::runtime_error(_oss.str());                  \
+        }                                                          \
+    } while (0)
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Bias add kernel (cuBLAS GEMM has no fused bias; append it separately)
+// ─────────────────────────────────────────────────────────────────────────────
+template <typename T>
+__global__ void add_bias_kernel(T *__restrict__ out,
+                                const T *__restrict__ bias,
+                                size_t batch_size,
+                                size_t out_features) {
+    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+    if (idx >= batch_size * out_features) {
+        return;
+    }
+
+    size_t col = idx % out_features;
+    float o_val, b_val;
+    if constexpr (std::is_same_v<T, __half>) {
+        o_val = __half2float(out[idx]);
+        b_val = __half2float(bias[col]);
+    } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+        o_val = __bfloat162float(out[idx]);
+        b_val = __bfloat162float(bias[col]);
+    } else {
+        o_val = static_cast<float>(out[idx]);
+        b_val = static_cast<float>(bias[col]);
+    }
+    float result = o_val + b_val;
+    if constexpr (std::is_same_v<T, __half>) {
+        out[idx] = __float2half(result);
+    } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+        out[idx] = __float2bfloat16(result);
+    } else {
+        out[idx] = static_cast<T>(result);
+    }
+}
+
+namespace llaisys::ops::nvidia {
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Row-major Y = X * W^T is equivalent to column-major Y^T = W * X^T.
+// cuBLAS (col-major) call shape:
+//   cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K, &a, W, K, X, K, &b, Y, N)
+// where M = batch_size, K = in_features, N = out_features.
+// ─────────────────────────────────────────────────────────────────────────────
+void linear(std::byte *out, const std::byte *in, const std::byte *weight,
+            const std::byte *bias,
+            llaisysDataType_t type,
+            size_t batch_size, size_t in_features, size_t out_features) {
+    cublasHandle_t handle = llaisys::device::nvidia::getCublasHandle();
+
+    int M = static_cast<int>(batch_size);
+    int K = static_cast<int>(in_features);
+    int N = static_cast<int>(out_features);
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32: {
+        // TF32 Tensor Core path (enabled by CUBLAS_TF32_TENSOR_OP_MATH by default).
+        const float alpha = 1.0f, beta_val = 0.0f;
+        CUBLAS_CHECK(cublasSgemm(
+            handle,
+            CUBLAS_OP_T, CUBLAS_OP_N,
+            N, M, K,
+            &alpha,
+            reinterpret_cast<const float *>(weight), K,
+            reinterpret_cast<const float *>(in), K,
+            &beta_val,
+            reinterpret_cast<float *>(out), N));
+
+        if (bias) {
+            size_t numel = batch_size * out_features;
+            int grid = static_cast<int>((numel + 255) / 256);
+            add_bias_kernel<float><<<grid, 256>>>(
+                reinterpret_cast<float *>(out),
+                reinterpret_cast<const float *>(bias),
+                batch_size, out_features);
+        }
+        break;
+    }
+    case LLAISYS_DTYPE_F16: {
+        // FP16 Tensor Core path (A100 supports CUBLAS_COMPUTE_16F).
+        const __half alpha = __float2half(1.0f);
+        const __half beta_val = __float2half(0.0f);
+        CUBLAS_CHECK(cublasHgemm(
+            handle,
+            CUBLAS_OP_T, CUBLAS_OP_N,
+            N, M, K,
+            &alpha,
+            reinterpret_cast<const __half *>(weight), K,
+            reinterpret_cast<const __half *>(in), K,
+            &beta_val,
+            reinterpret_cast<__half *>(out), N));
+
+        if (bias) {
+            size_t numel = batch_size * out_features;
+            int grid = static_cast<int>((numel + 255) / 256);
+            add_bias_kernel<__half><<<grid, 256>>>(
+                reinterpret_cast<__half *>(out),
+                reinterpret_cast<const __half *>(bias),
+                batch_size, out_features);
+        }
+        break;
+    }
+    case LLAISYS_DTYPE_BF16: {
+        // BF16 Tensor Core path via cublasGemmEx with f32 accumulation
+        // (A100 SM_80 peaks at 312 TFLOPS in BF16).
+        const float alpha = 1.0f, beta_val = 0.0f;
+        CUBLAS_CHECK(cublasGemmEx(
+            handle,
+            CUBLAS_OP_T, CUBLAS_OP_N,
+            N, M, K,
+            &alpha,
+            weight, CUDA_R_16BF, K,
+            in, CUDA_R_16BF, K,
+            &beta_val,
+            out, CUDA_R_16BF, N,
+            CUBLAS_COMPUTE_32F, // f32 accumulator preserves precision
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        if (bias) {
+            size_t numel = batch_size * out_features;
+            int grid = static_cast<int>((numel + 255) / 256);
+            add_bias_kernel<__nv_bfloat16><<<grid, 256>>>(
+                reinterpret_cast<__nv_bfloat16 *>(out),
+                reinterpret_cast<const __nv_bfloat16 *>(bias),
+                batch_size, out_features);
+        }
+        break;
+    }
+    default:
+        throw std::runtime_error("linear CUDA: unsupported data type");
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/linear/nvidia/linear_cuda.cuh b/src/ops/linear/nvidia/linear_cuda.cuh
new file mode 100644
index 000000000..1a400691f
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_cuda.cuh
@@ -0,0 +1,14 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+// Linear CUDA implementation (Y = X * W^T + b).
+// Uses cuBLAS SGEMM (TF32 Tensor Cores for f32) and HGEMM / GemmEx
+// (f16 / bf16 Tensor Cores). A100 SM_80 peaks at 312 TFLOPS in BF16
+// and 156 TFLOPS in TF32.
+void linear(std::byte *out, const std::byte *in, const std::byte *weight,
+            const std::byte *bias,
+            llaisysDataType_t type,
+            size_t batch_size, size_t in_features, size_t out_features);
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp
index 97d1f8655..2e5b011a1 100644
--- a/src/ops/linear/op.cpp
+++ b/src/ops/linear/op.cpp
@@ -1,7 +1,73 @@
 #include "op.hpp"
 
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/linear_cpu.hpp"
+
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/linear_cuda.cuh"
+#endif
+
 namespace llaisys::ops {
 void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) {
-    TO_BE_IMPLEMENTED();
+    // Validate devices, dtypes and shapes, then dispatch by device type.
+    CHECK_SAME_DEVICE(out, in, weight);
+    if (bias) {
+        CHECK_SAME_DEVICE(out, bias);
+    }
+
+    CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype());
+    if (bias) {
+        CHECK_SAME_DTYPE(out->dtype(), bias->dtype());
+    }
+
+    // Expected ranks: in [M, K], weight [N, K], out [M, N], bias [N].
+    CHECK_ARGUMENT(in->ndim() == 2, "linear: input must be 2D tensor");
+    CHECK_ARGUMENT(weight->ndim() == 2, "linear: weight must be 2D tensor");
+    CHECK_ARGUMENT(out->ndim() == 2, "linear: output must be 2D tensor");
+    if (bias) {
+        CHECK_ARGUMENT(bias->ndim() == 1, "linear: bias must be 1D tensor");
+    }
+
+    size_t batch_size = in->shape()[0];
+    size_t in_features = in->shape()[1];
+    size_t out_features = weight->shape()[0];
+
+    CHECK_ARGUMENT(weight->shape()[1] == in_features,
+                   "linear: weight's second dimension must match input's second dimension");
+    CHECK_ARGUMENT(out->shape()[0] == batch_size,
+                   "linear: output's first dimension must match input's first dimension");
+    CHECK_ARGUMENT(out->shape()[1] == out_features,
+                   "linear: output's second dimension must match weight's first dimension");
+    if (bias) {
+        CHECK_ARGUMENT(bias->shape()[0] == out_features,
+                       "linear: bias dimension must match weight's first dimension");
+    }
+
+    ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(),
+           "linear: all tensors must be contiguous.");
+    if (bias) {
+        ASSERT(bias->isContiguous(), "linear: bias must be contiguous.");
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::linear(out->data(), in->data(), weight->data(),
+                           bias ? bias->data() : nullptr,
+                           out->dtype(), batch_size, in_features, out_features);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::linear(out->data(), in->data(), weight->data(),
+                              bias ? bias->data() : nullptr,
+                              out->dtype(), batch_size, in_features, out_features);
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/ops.hpp b/src/ops/ops.hpp
new file mode 100644
index 000000000..fbc8b98f8
--- /dev/null
+++ b/src/ops/ops.hpp
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "../tensor/tensor.hpp"
+
+namespace llaisys::ops {
+
+// Basic operations
+void add(tensor_t c, tensor_t a, tensor_t b);
+
+// Model operations
+void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals);
+
+void embedding(tensor_t out, tensor_t index, tensor_t weight);
+
+void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias);
+
+void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps);
+
+void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta);
+
+void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale);
+
+void swiglu(tensor_t out, tensor_t gate, tensor_t up);
+
+} // namespace llaisys::ops
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
new file mode 100644
index 000000000..c7f8e2ba6
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
@@ -0,0 +1,70 @@
+#include "rms_norm_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+
+// RMS Normalization: Y_i = (W_i × X_i) / sqrt((1/d) * sum(X_j^2) + eps)
+// Math is done in float for every element type.
+template <typename T>
+void rms_norm_(T *out, const T *in, const T *weight, size_t num_rows, size_t row_dim, float eps) {
+    // Normalize each row independently.
+    for (size_t r = 0; r < num_rows; r++) {
+        const T *in_row = in + r * row_dim;
+        T *out_row = out + r * row_dim;
+
+        // Accumulate sum of squares for the row.
+        float sum_squares = 0.0f;
+        for (size_t i = 0; i < row_dim; i++) {
+            float val;
+            if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                val = llaisys::utils::cast<float>(in_row[i]);
+            } else {
+                val = static_cast<float>(in_row[i]);
+            }
+            sum_squares += val * val;
+        }
+
+        // Inverse RMS: 1 / sqrt((1/d) * sum + eps)
+        float rms = 1.0f / std::sqrt(sum_squares / row_dim + eps);
+
+        // Apply normalization and weight: Y_i = (W_i × X_i) × rms
+        for (size_t i = 0; i < row_dim; i++) {
+            float x_val, w_val;
+
+            if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                x_val = llaisys::utils::cast<float>(in_row[i]);
+                w_val = llaisys::utils::cast<float>(weight[i]);
+            } else {
+                x_val = static_cast<float>(in_row[i]);
+                w_val = static_cast<float>(weight[i]);
+            }
+
+            float result = x_val * rms * w_val;
+
+            if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                out_row[i] = llaisys::utils::cast<T>(result);
+            } else {
+                out_row[i] = static_cast<T>(result);
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+              llaisysDataType_t type, size_t num_rows, size_t row_dim, float eps) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return rms_norm_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in),
+                         reinterpret_cast<const float *>(weight), num_rows, row_dim, eps);
+    case LLAISYS_DTYPE_BF16:
+        return rms_norm_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in),
+                         reinterpret_cast<const llaisys::bf16_t *>(weight), num_rows, row_dim, eps);
+    case LLAISYS_DTYPE_F16:
+        return rms_norm_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in),
+                         reinterpret_cast<const llaisys::fp16_t *>(weight), num_rows, row_dim, eps);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
new file mode 100644
index 000000000..a4a7aed3e
--- /dev/null
+++
b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+              llaisysDataType_t type, size_t num_rows, size_t row_dim, float eps);
+}
diff --git a/src/ops/rms_norm/nvidia/rms_norm_cuda.cu b/src/ops/rms_norm/nvidia/rms_norm_cuda.cu
new file mode 100644
index 000000000..b3e3dc4f2
--- /dev/null
+++ b/src/ops/rms_norm/nvidia/rms_norm_cuda.cu
@@ -0,0 +1,147 @@
+#include "rms_norm_cuda.cuh"
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <stdexcept>
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Warp-level sum reduction (warp size 32 on A100)
+// ─────────────────────────────────────────────────────────────────────────────
+__device__ __forceinline__ float warpReduceSum(float val) {
+    constexpr unsigned FULL_MASK = 0xffffffff;
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        val += __shfl_down_sync(FULL_MASK, val, offset);
+    }
+    return val;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// RMS Norm kernel
+// gridDim.x = num_rows; each block handles one row with up to 1024 threads.
+// Shared memory holds per-warp partial sums for the cross-warp reduction.
+// ─────────────────────────────────────────────────────────────────────────────
+template <typename T>
+__global__ void rms_norm_kernel(T *__restrict__ out,
+                                const T *__restrict__ in,
+                                const T *__restrict__ weight,
+                                size_t row_dim,
+                                float eps) {
+    extern __shared__ float smem[]; // num_warps floats
+
+    const size_t row = blockIdx.x;
+    const T *in_row = in + row * row_dim;
+    T *out_row = out + row * row_dim;
+
+    // ── Phase 1: each thread accumulates x^2 over its strided range ──────────
+    float local_sum = 0.0f;
+    for (size_t i = threadIdx.x; i < row_dim; i += blockDim.x) {
+        float v;
+        if constexpr (std::is_same_v<T, __half>) {
+            v = __half2float(in_row[i]);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            v = __bfloat162float(in_row[i]);
+        } else {
+            v = static_cast<float>(in_row[i]);
+        }
+        local_sum += v * v;
+    }
+
+    // ── Phase 2: reduce within each warp ─────────────────────────────────────
+    local_sum = warpReduceSum(local_sum);
+
+    int warp_id = threadIdx.x / 32;
+    int lane_id = threadIdx.x % 32;
+    int num_warps = (blockDim.x + 31) / 32;
+
+    if (lane_id == 0) {
+        smem[warp_id] = local_sum;
+    }
+    __syncthreads();
+
+    // ── Phase 3: cross-warp reduction done by the first warp ─────────────────
+    float block_sum = 0.0f;
+    if (threadIdx.x < static_cast<unsigned>(num_warps)) {
+        block_sum = smem[threadIdx.x];
+    }
+    if (threadIdx.x < 32) {
+        block_sum = warpReduceSum(block_sum);
+    }
+    if (threadIdx.x == 0) {
+        smem[0] = block_sum;
+    }
+    __syncthreads();
+
+    // ── Phase 4: compute the inverse normalization factor ────────────────────
+    float rms_inv = rsqrtf(smem[0] / static_cast<float>(row_dim) + eps);
+
+    // ── Phase 5: write the normalized, weighted output ───────────────────────
+    for (size_t i = threadIdx.x; i < row_dim; i += blockDim.x) {
+        float x_val, w_val;
+        if constexpr (std::is_same_v<T, __half>) {
+            x_val = __half2float(in_row[i]);
+            w_val = __half2float(weight[i]);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            x_val = __bfloat162float(in_row[i]);
+            w_val = __bfloat162float(weight[i]);
+        } else {
+            x_val = static_cast<float>(in_row[i]);
+            w_val = static_cast<float>(weight[i]);
+        }
+        float result = x_val * rms_inv * w_val;
+
+        if constexpr (std::is_same_v<T, __half>) {
+            out_row[i] = __float2half(result);
+        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+            out_row[i] = __float2bfloat16(result);
+        } else {
+            out_row[i] = static_cast<T>(result);
+        }
+    }
+}
+
+namespace llaisys::ops::nvidia {
+
+// Row-wise RMS normalization on device; one block per row.
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+              llaisysDataType_t type, size_t num_rows, size_t row_dim, float eps) {
+    if (num_rows == 0 || row_dim == 0) {
+        return;
+    }
+
+    // Block size: clamp to 1024, rounded up to a warp multiple.
+    int block = static_cast<int>(row_dim < 1024 ? row_dim : 1024);
+    block = ((block + 31) / 32) * 32; // round up to warp
+    int num_warps = (block + 31) / 32;
+    size_t smem_bytes = static_cast<size_t>(num_warps) * sizeof(float);
+
+    dim3 grid(static_cast<unsigned>(num_rows));
+    dim3 blk(static_cast<unsigned>(block));
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        rms_norm_kernel<float><<<grid, blk, smem_bytes>>>(
+            reinterpret_cast<float *>(out),
+            reinterpret_cast<const float *>(in),
+            reinterpret_cast<const float *>(weight),
+            row_dim, eps);
+        break;
+    case LLAISYS_DTYPE_F16:
+        rms_norm_kernel<__half><<<grid, blk, smem_bytes>>>(
+            reinterpret_cast<__half *>(out),
+            reinterpret_cast<const __half *>(in),
+            reinterpret_cast<const __half *>(weight),
+            row_dim, eps);
+        break;
+    case LLAISYS_DTYPE_BF16:
+        rms_norm_kernel<__nv_bfloat16><<<grid, blk, smem_bytes>>>(
+            reinterpret_cast<__nv_bfloat16 *>(out),
+            reinterpret_cast<const __nv_bfloat16 *>(in),
+            reinterpret_cast<const __nv_bfloat16 *>(weight),
+            row_dim, eps);
+        break;
+    default:
+        throw std::runtime_error("rms_norm CUDA: unsupported data type");
+    }
+}
+
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/rms_norm/nvidia/rms_norm_cuda.cuh b/src/ops/rms_norm/nvidia/rms_norm_cuda.cuh
new file mode 100644
index 000000000..dd2e30c0b
--- /dev/null
+++ b/src/ops/rms_norm/nvidia/rms_norm_cuda.cuh
@@ -0,0 +1,12 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+// RMS Norm CUDA implementation.
+// Two-stage reduction with warp-level __shfl_down_sync plus shared memory;
+// one block per row (fits comfortably in A100's 164KB shared memory).
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+              llaisysDataType_t type, size_t num_rows, size_t row_dim, float eps);
+} // namespace llaisys::ops::nvidia
diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp
index 529553d9d..c3890c137 100644
--- a/src/ops/rms_norm/op.cpp
+++ b/src/ops/rms_norm/op.cpp
@@ -1,7 +1,53 @@
 #include "op.hpp"
 
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rms_norm_cpu.hpp"
+
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/rms_norm_cuda.cuh"
+#endif
+
 namespace llaisys::ops {
 void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) {
-    TO_BE_IMPLEMENTED();
+    // Validate devices, dtypes and shapes, then dispatch by device type.
+    CHECK_SAME_DEVICE(out, in, weight);
+
+    CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype());
+
+    // Expected ranks: in/out [rows, dim], weight [dim].
+    CHECK_ARGUMENT(in->ndim() == 2, "rms_norm: input must be 2D tensor");
+    CHECK_ARGUMENT(out->ndim() == 2, "rms_norm: output must be 2D tensor");
+    CHECK_ARGUMENT(weight->ndim() == 1, "rms_norm: weight must be 1D tensor");
+
+    size_t num_rows = in->shape()[0];
+    size_t row_dim = in->shape()[1];
+
+    CHECK_SAME_SHAPE(out->shape(), in->shape());
+    CHECK_ARGUMENT(weight->shape()[0] == row_dim,
+                   "rms_norm: weight dimension must match input's row dimension");
+
+    // All tensors must be contiguous.
+    ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(),
+           "rms_norm: all tensors must be contiguous.");
+
+    // Bind the device context before launching.
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::rms_norm(out->data(), in->data(), weight->data(),
+                             out->dtype(), num_rows, row_dim, eps);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        return nvidia::rms_norm(out->data(), in->data(), weight->data(),
+                                out->dtype(), num_rows, row_dim, eps);
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp
new file mode 100644
index 000000000..f47aa9474
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.cpp
@@ -0,0 +1,83 @@
+#include "rope_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+
+// RoPE (Rotary Position Embedding)
+// Each vector is split as x_i = [a_i, b_i] (first/second half of the head dim)
+// and rotated by the position-dependent angle:
+//   φ_{i,j} = p_i / θ^(2j/d)
+//   a'_{i,j} = a_{i,j}·cos(φ_{i,j}) - b_{i,j}·sin(φ_{i,j})
+//   b'_{i,j} = b_{i,j}·cos(φ_{i,j}) + a_{i,j}·sin(φ_{i,j})
+template <typename T>
+void rope_(T *out, const T *in, const int64_t *pos_ids,
+           size_t seq_len, size_t n_heads, size_t head_dim, float theta) {
+    size_t half_dim = head_dim / 2;
+
+    // For every position in the sequence
+    for (size_t s = 0; s < seq_len; s++) {
+        int64_t pos = pos_ids[s];
+
+        // For every attention head
+        for (size_t h = 0; h < n_heads; h++) {
+            // For every (a, b) dimension pair
+            for (size_t j = 0; j < half_dim; j++) {
+                // Angle φ = pos / θ^(2j/d)
+                float freq_exponent = (2.0f * j) / head_dim;
+                float freq = pos / std::pow(theta, freq_exponent);
+                float cos_freq = std::cos(freq);
+                float sin_freq = std::sin(freq);
+
+                // Load a and b from the input.
+                size_t idx = s * n_heads * head_dim + h * head_dim;
+                size_t a_idx = idx + j;
+                size_t b_idx = idx + half_dim + j;
+
+                float a_val, b_val;
+                if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                    a_val = llaisys::utils::cast<float>(in[a_idx]);
+                    b_val = llaisys::utils::cast<float>(in[b_idx]);
+                } else {
+                    a_val = static_cast<float>(in[a_idx]);
+                    b_val = static_cast<float>(in[b_idx]);
+                }
+
+                // Apply the rotation:
+                //   a' = a·cos(φ) - b·sin(φ)
+                //   b' = b·cos(φ) + a·sin(φ)
+                float a_prime = a_val * cos_freq - b_val * sin_freq;
+                float b_prime = b_val * cos_freq + a_val * sin_freq;
+
+                // Store the result, converting back to T.
+                if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                    out[a_idx] = llaisys::utils::cast<T>(a_prime);
+                    out[b_idx] = llaisys::utils::cast<T>(b_prime);
+                } else {
+                    out[a_idx] = static_cast<T>(a_prime);
+                    out[b_idx] = static_cast<T>(b_prime);
+                }
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids,
+          llaisysDataType_t type, size_t seq_len, size_t n_heads, size_t head_dim, float theta) {
+    const int64_t *pos_ptr = reinterpret_cast<const int64_t *>(pos_ids);
+
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return rope_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in),
+                     pos_ptr, seq_len, n_heads, head_dim, theta);
+    case LLAISYS_DTYPE_BF16:
+        return rope_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in),
+                     pos_ptr, seq_len, n_heads, head_dim, theta);
+    case LLAISYS_DTYPE_F16:
+        return rope_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in),
+                     pos_ptr, seq_len, n_heads, head_dim, theta);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp
new file mode 100644
index
000000000..6bb6c0da6 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, size_t seq_len, size_t n_heads, size_t head_dim, float theta); +} diff --git a/src/ops/rope/nvidia/rope_cuda.cu b/src/ops/rope/nvidia/rope_cuda.cu new file mode 100644 index 000000000..bfadf337b --- /dev/null +++ b/src/ops/rope/nvidia/rope_cuda.cu @@ -0,0 +1,115 @@ +#include "rope_cuda.cuh" + +#include +#include +#include + +// ───────────────────────────────────────────────────────────────────────────── +// RoPE kernel +// grid: [seq_len, n_heads, half_dim/BLOCK_DIM] — 三维网格充分并行 +// block: BLOCK_DIM threads 处理同一 (s, h) 的维度对 +// +// 对每个 (s, h, j): +// freq = pos_ids[s] / theta^(2j/d) +// a' = a*cos(freq) - b*sin(freq) +// b' = b*cos(freq) + a*sin(freq) +// ───────────────────────────────────────────────────────────────────────────── +template +__global__ void rope_kernel(T *__restrict__ out, + const T *__restrict__ in, + const int64_t *__restrict__ pos_ids, + size_t n_heads, + size_t head_dim, + float theta) { + size_t s = blockIdx.x; // 序列位置 + size_t h = blockIdx.y; // attention head + size_t j = blockIdx.z * blockDim.x + threadIdx.x; // 维度对索引 + + size_t half_dim = head_dim / 2; + if (j >= half_dim) { + return; + } + + int64_t pos = pos_ids[s]; + + // freq = pos / theta^(2j/d) + float freq_exp = (2.0f * static_cast(j)) / static_cast(head_dim); + float freq = static_cast(pos) / powf(theta, freq_exp); + float cos_freq, sin_freq; + sincosf(freq, &sin_freq, &cos_freq); + + size_t a_idx = s * n_heads * head_dim + h * head_dim + j; + size_t b_idx = a_idx + half_dim; + + float a_val, b_val; + if constexpr (std::is_same_v) { + a_val = __half2float(in[a_idx]); + b_val = __half2float(in[b_idx]); + } else if constexpr (std::is_same_v) { + a_val = __bfloat162float(in[a_idx]); + b_val = __bfloat162float(in[b_idx]); 
+ } else { + a_val = static_cast(in[a_idx]); + b_val = static_cast(in[b_idx]); + } + + float a_prime = a_val * cos_freq - b_val * sin_freq; + float b_prime = b_val * cos_freq + a_val * sin_freq; + + if constexpr (std::is_same_v) { + out[a_idx] = __float2half(a_prime); + out[b_idx] = __float2half(b_prime); + } else if constexpr (std::is_same_v) { + out[a_idx] = __float2bfloat16(a_prime); + out[b_idx] = __float2bfloat16(b_prime); + } else { + out[a_idx] = static_cast(a_prime); + out[b_idx] = static_cast(b_prime); + } +} + +namespace llaisys::ops::nvidia { + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, size_t seq_len, size_t n_heads, + size_t head_dim, float theta) { + if (seq_len == 0 || n_heads == 0 || head_dim == 0) { + return; + } + + size_t half_dim = head_dim / 2; + const int BLOCK = 32; // 每个 block 处理 32 个维度对(一个 warp) + unsigned z_dim = static_cast((half_dim + BLOCK - 1) / BLOCK); + + dim3 grid(static_cast(seq_len), + static_cast(n_heads), + z_dim); + dim3 blk(static_cast(BLOCK)); + + const int64_t *pos_ptr = reinterpret_cast(pos_ids); + + switch (type) { + case LLAISYS_DTYPE_F32: + rope_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(in), + pos_ptr, n_heads, head_dim, theta); + break; + case LLAISYS_DTYPE_F16: + rope_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(in), + pos_ptr, n_heads, head_dim, theta); + break; + case LLAISYS_DTYPE_BF16: + rope_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(in), + pos_ptr, n_heads, head_dim, theta); + break; + default: + throw std::runtime_error("rope CUDA: unsupported data type"); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/nvidia/rope_cuda.cuh b/src/ops/rope/nvidia/rope_cuda.cuh new file mode 100644 index 000000000..0f5b1c374 --- /dev/null +++ b/src/ops/rope/nvidia/rope_cuda.cuh @@ -0,0 +1,13 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace 
llaisys::ops::nvidia { +// RoPE CUDA 实现 +// 每个 thread 处理一个 (sequence_pos, head, dim_pair) 三元组 +// 充分利用 A100 的线程级并行度 +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, size_t seq_len, size_t n_heads, + size_t head_dim, float theta); +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..23d1d4da4 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,56 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rope_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rope_cuda.cuh" +#endif + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + // 检查设备一致性 + CHECK_SAME_DEVICE(out, in, pos_ids); + + // 检查数据类型 + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + CHECK_ARGUMENT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "rope: pos_ids must be int64 type"); + + // 检查维度 + CHECK_ARGUMENT(in->ndim() == 3, "rope: input must be 3D tensor [seqlen, nhead, d]"); + CHECK_ARGUMENT(out->ndim() == 3, "rope: output must be 3D tensor [seqlen, nhead, d]"); + CHECK_ARGUMENT(pos_ids->ndim() == 1, "rope: pos_ids must be 1D tensor"); + + // 检查形状兼容性 + size_t seq_len = in->shape()[0]; + size_t n_heads = in->shape()[1]; + size_t head_dim = in->shape()[2]; + + CHECK_SAME_SHAPE(out->shape(), in->shape()); + CHECK_ARGUMENT(pos_ids->shape()[0] == seq_len, + "rope: pos_ids length must match sequence length"); + CHECK_ARGUMENT(head_dim % 2 == 0, "rope: head dimension must be even"); + + // 检查所有张量都是连续的 + ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), + "rope: all tensors must be contiguous."); + + // 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), + out->dtype(), seq_len, n_heads, head_dim, 
theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rope(out->data(), in->data(), pos_ids->data(), + out->dtype(), seq_len, n_heads, head_dim, theta); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/flash_attention_cpu.cpp b/src/ops/self_attention/cpu/flash_attention_cpu.cpp new file mode 100644 index 000000000..8f2815ca3 --- /dev/null +++ b/src/ops/self_attention/cpu/flash_attention_cpu.cpp @@ -0,0 +1,232 @@ +#include "flash_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +// ── 类型转换 helpers ────────────────────────────────────────────────────────── +template +static inline float to_f32(T v) { + if constexpr (std::is_same_v || std::is_same_v) { + return llaisys::utils::cast(v); + } else { + return static_cast(v); + } +} + +template +static inline T from_f32(float v) { + if constexpr (std::is_same_v || std::is_same_v) { + return llaisys::utils::cast(v); + } else { + return static_cast(v); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// DECODE 路径 (seq_len == 1) +// +// 每次生成一个 token 时 seq_len=1,对全部 KV 位置做标准 attention。 +// 优化重点: +// • 直接访问原始 KV 布局,零复制(消除旧实现的 K_h/V_h 转置拷贝) +// • V 按行迭代(cache-friendly),消除旧实现的按列访问导致的 cache miss +// • 缓冲区在 head 循环外分配一次,所有 head 复用 +// ───────────────────────────────────────────────────────────────────────────── +template +static void decode_attention_(T *out, const T *q, const T *k, const T *v, + size_t total_len, + size_t n_heads, size_t n_kv_heads, size_t head_dim, + float scale) { + const size_t kv_stride = n_kv_heads * head_dim; + const size_t heads_per_kv = n_heads / n_kv_heads; + + // head 循环外分配,所有 head 复用 + std::vector scores(total_len); + std::vector out_f(head_dim); + + for (size_t h = 0; h < n_heads; h++) { + const size_t kv_h = h / heads_per_kv; + const T *q_row = q + h * head_dim; // q: [1, n_heads, head_dim] + + // ── QKᵀ 
─────────────────────────────────────────────────────────── + for (size_t j = 0; j < total_len; j++) { + const T *k_row = k + j * kv_stride + kv_h * head_dim; + float dot = 0.0f; + for (size_t d = 0; d < head_dim; d++) { + dot += to_f32(q_row[d]) * to_f32(k_row[d]); + } + scores[j] = dot * scale; + } + + // ── Softmax ─────────────────────────────────────────────────────── + float max_s = scores[0]; + for (size_t j = 1; j < total_len; j++) { + if (scores[j] > max_s) { + max_s = scores[j]; + } + } + float sum_e = 0.0f; + for (size_t j = 0; j < total_len; j++) { + scores[j] = std::exp(scores[j] - max_s); + sum_e += scores[j]; + } + const float inv_sum = 1.0f / sum_e; + for (size_t j = 0; j < total_len; j++) { + scores[j] *= inv_sum; + } + + // ── O = attn · V (按行迭代 V,顺序内存访问,L1 友好) ───────────── + std::fill(out_f.begin(), out_f.end(), 0.0f); + for (size_t j = 0; j < total_len; j++) { + const float a_j = scores[j]; + const T *v_row = v + j * kv_stride + kv_h * head_dim; + for (size_t d = 0; d < head_dim; d++) { + out_f[d] += a_j * to_f32(v_row[d]); + } + } + + // ── 写回 ────────────────────────────────────────────────────────── + T *out_row = out + h * head_dim; + for (size_t d = 0; d < head_dim; d++) { + out_row[d] = from_f32(out_f[d]); + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// PREFILL 路径 (seq_len > 1) +// +// 与原版 self_attention_cpu 算法完全一致(标准两遍 softmax), +// 仅做内存访问优化:V 按行迭代(cache-friendly),attn_scores 复用缓冲区。 +// • Causal mask:第 i 个 query 可见位置 ≤ total_len - seq_len + i +// ───────────────────────────────────────────────────────────────────────────── +template +static void prefill_attention_(T *out, const T *q, const T *k, const T *v, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, size_t head_dim, + float scale) { + const size_t q_stride = n_heads * head_dim; + const size_t kv_stride = n_kv_heads * head_dim; + const size_t heads_per_kv = n_heads / n_kv_heads; + + // 在 head 循环外分配,所有 head 复用 
+ std::vector attn_scores(seq_len * total_len); + std::vector out_f(head_dim); + + for (size_t h = 0; h < n_heads; h++) { + const size_t kv_h = h / heads_per_kv; + + // ── S = Q · Kᵀ * scale,应用 causal mask ───────────────────────── + for (size_t i = 0; i < seq_len; i++) { + const size_t query_pos = total_len - seq_len + i; // 该 query 的绝对位置 + const T *q_row = q + i * q_stride + h * head_dim; + + for (size_t j = 0; j < total_len; j++) { + if (j > query_pos) { + attn_scores[i * total_len + j] = -std::numeric_limits::infinity(); + continue; + } + const T *k_row = k + j * kv_stride + kv_h * head_dim; + float dot = 0.0f; + for (size_t d = 0; d < head_dim; d++) { + dot += to_f32(q_row[d]) * to_f32(k_row[d]); + } + attn_scores[i * total_len + j] = dot * scale; + } + } + + // ── 逐行 Softmax ────────────────────────────────────────────────── + for (size_t i = 0; i < seq_len; i++) { + float *row = attn_scores.data() + i * total_len; + + // 找最大值(数值稳定) + float max_score = -std::numeric_limits::infinity(); + for (size_t j = 0; j < total_len; j++) { + if (row[j] > max_score) { + max_score = row[j]; + } + } + + // exp & sum(-inf → 0) + float sum_exp = 0.0f; + for (size_t j = 0; j < total_len; j++) { + if (std::isinf(row[j]) && row[j] < 0.0f) { + row[j] = 0.0f; + } else { + row[j] = std::exp(row[j] - max_score); + sum_exp += row[j]; + } + } + + // 归一化 + const float inv_sum = (sum_exp > 0.0f) ? 
(1.0f / sum_exp) : 0.0f; + for (size_t j = 0; j < total_len; j++) { + row[j] *= inv_sum; + } + } + + // ── O = attn · V (按行迭代 V,L1 友好) ────────────────────────── + for (size_t i = 0; i < seq_len; i++) { + const float *a_row = attn_scores.data() + i * total_len; + std::fill(out_f.begin(), out_f.end(), 0.0f); + + for (size_t j = 0; j < total_len; j++) { + const float a_j = a_row[j]; + if (a_j == 0.0f) { + continue; + } + const T *v_row = v + j * kv_stride + kv_h * head_dim; + for (size_t d = 0; d < head_dim; d++) { + out_f[d] += a_j * to_f32(v_row[d]); + } + } + + // 写回 + T *out_row = out + i * q_stride + h * head_dim; + for (size_t d = 0; d < head_dim; d++) { + out_row[d] = from_f32(out_f[d]); + } + } + } +} + +namespace llaisys::ops::cpu { +void flash_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, size_t head_dim, float scale) { + +// 根据 seq_len 分路径:decode (seq_len==1) 与 prefill (seq_len>1) +#define FA_DISPATCH(CPP_TYPE) \ + do { \ + auto *out_ = reinterpret_cast(out); \ + const auto *q_ = reinterpret_cast(q); \ + const auto *k_ = reinterpret_cast(k); \ + const auto *v_ = reinterpret_cast(v); \ + if (seq_len == 1) \ + decode_attention_(out_, q_, k_, v_, \ + total_len, n_heads, n_kv_heads, head_dim, scale); \ + else \ + prefill_attention_(out_, q_, k_, v_, \ + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); \ + } while (0) + + switch (type) { + case LLAISYS_DTYPE_F32: + FA_DISPATCH(float); + return; + case LLAISYS_DTYPE_BF16: + FA_DISPATCH(llaisys::bf16_t); + return; + case LLAISYS_DTYPE_F16: + FA_DISPATCH(llaisys::fp16_t); + return; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +#undef FA_DISPATCH +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/flash_attention_cpu.hpp b/src/ops/self_attention/cpu/flash_attention_cpu.hpp new file mode 100644 index 000000000..1faad49c3 --- /dev/null 
+++ b/src/ops/self_attention/cpu/flash_attention_cpu.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +// Flash Attention 2 (CPU, tiled online-softmax) +// Supports Grouped Query Attention (GQA): n_heads % n_kv_heads == 0 +// Layout: Q[seq_len, n_heads, head_dim] +// K[total_len, n_kv_heads, head_dim] +// V[total_len, n_kv_heads, head_dim] +// out[seq_len, n_heads, head_dim] +void flash_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, size_t head_dim, float scale); + +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 000000000..5bf5c263f --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,144 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +// Self-Attention: Y = causal_softmax(Q * K^T * scale) * V +template +void self_attention_(T *attn_val, const T *q, const T *k, const T *v, + size_t seq_len, size_t total_len, size_t n_heads, + size_t n_kv_heads, size_t head_dim, float scale) { + // 计算每个 KV head 对应多少个 Q head (Grouped Query Attention) + size_t heads_per_kv = n_heads / n_kv_heads; + + // 为注意力分数分配临时缓冲区 + std::vector attn_scores(seq_len * total_len); + + // 对每个 head 处理 + for (size_t h = 0; h < n_heads; h++) { + size_t kv_head = h / heads_per_kv; // 对应的 KV head + + // 计算注意力分数: A = Q * K^T * scale + for (size_t i = 0; i < seq_len; i++) { + for (size_t j = 0; j < total_len; j++) { + float score = 0.0f; + + // 点积: Q[i] · K[j] + for (size_t d = 0; d < head_dim; d++) { + size_t q_idx = i * n_heads * head_dim + h * head_dim + d; + size_t k_idx = j * n_kv_heads * head_dim + kv_head * head_dim + d; + + float q_val, k_val; + if constexpr (std::is_same_v || std::is_same_v) { 
+ q_val = llaisys::utils::cast(q[q_idx]); + k_val = llaisys::utils::cast(k[k_idx]); + } else { + q_val = static_cast(q[q_idx]); + k_val = static_cast(k[k_idx]); + } + + score += q_val * k_val; + } + + score *= scale; + + // 应用因果掩码 (causal mask): 只能看到当前及之前的位置 + // 当前查询位置在序列中的绝对位置 + size_t query_pos = total_len - seq_len + i; + if (j > query_pos) { + score = -std::numeric_limits::infinity(); + } + + attn_scores[i * total_len + j] = score; + } + } + + // 对每一行应用 softmax + for (size_t i = 0; i < seq_len; i++) { + float *row = &attn_scores[i * total_len]; + + // 找到最大值(用于数值稳定性) + float max_score = -std::numeric_limits::infinity(); + for (size_t j = 0; j < total_len; j++) { + max_score = std::max(max_score, row[j]); + } + + // 计算 exp 和 sum + float sum_exp = 0.0f; + for (size_t j = 0; j < total_len; j++) { + if (std::isinf(row[j]) && row[j] < 0) { + row[j] = 0.0f; + } else { + row[j] = std::exp(row[j] - max_score); + sum_exp += row[j]; + } + } + + // 归一化 + for (size_t j = 0; j < total_len; j++) { + row[j] /= sum_exp; + } + } + + // 计算输出: Y = attention_weights * V + for (size_t i = 0; i < seq_len; i++) { + for (size_t d = 0; d < head_dim; d++) { + float output = 0.0f; + + for (size_t j = 0; j < total_len; j++) { + float attn_weight = attn_scores[i * total_len + j]; + size_t v_idx = j * n_kv_heads * head_dim + kv_head * head_dim + d; + + float v_val; + if constexpr (std::is_same_v || std::is_same_v) { + v_val = llaisys::utils::cast(v[v_idx]); + } else { + v_val = static_cast(v[v_idx]); + } + + output += attn_weight * v_val; + } + + size_t out_idx = i * n_heads * head_dim + h * head_dim + d; + if constexpr (std::is_same_v || std::is_same_v) { + attn_val[out_idx] = llaisys::utils::cast(output); + } else { + attn_val[out_idx] = static_cast(output); + } + } + } + } +} + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t seq_len, size_t total_len, + size_t n_heads, 
size_t n_kv_heads, size_t head_dim, float scale) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attention_(reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); + case LLAISYS_DTYPE_BF16: + return self_attention_(reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); + case LLAISYS_DTYPE_F16: + return self_attention_(reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 000000000..7b5ff5d8e --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, size_t head_dim, float scale); +} diff --git a/src/ops/self_attention/nvidia/flash_attention_cuda.cu b/src/ops/self_attention/nvidia/flash_attention_cuda.cu new file mode 100644 index 000000000..30632e598 --- /dev/null +++ b/src/ops/self_attention/nvidia/flash_attention_cuda.cu @@ -0,0 +1,263 @@ +// Flash Attention 2 CUDA 实现(修正版) +// +// 修正记录: +// 1. FA_BC_CUDA 32→16:smem = (2×16+2×16)×128×4 + 2×16×4 = 32896B ≈ 32KB +// 原 FA_BC=32 时 smem=49280B 超过 48KB 默认上限,kernel 静默失败不写出, +// 输出 tensor 保持未初始化状态 → 纯随机乱码 +// 2. 使用 __int_as_float(0xff800000) 作为真正的 IEEE-754 -inf(替代 -1e30f) +// 原 -1e30f:exp(-1e30-(-1e30))=exp(0)=1,masked 位置被错赋权重 1 +// 3. 
全 mask tile:m_new 仍为 -inf 时显式跳过,防止 -inf-(-inf)=NaN 传播 +// 4. m_old==-inf 时 alpha 显式置 0,防止 exp(-inf) 精度边界问题 +// 5. V 累积改为 j 外层 d 内层,V_s 顺序内存访问,减少 bank conflict + +#include "flash_attention_cuda.cuh" + +#include +#include +#include + +// ─── 精度辅助函数 ───────────────────────────────────────────────────────────── +__device__ __forceinline__ float to_f32_fa(__half v) { return __half2float(v); } +__device__ __forceinline__ float to_f32_fa(__nv_bfloat16 v) { return __bfloat162float(v); } +__device__ __forceinline__ float to_f32_fa(float v) { return v; } + +__device__ __forceinline__ __half from_f32_fa_h(float v) { return __float2half(v); } +__device__ __forceinline__ __nv_bfloat16 from_f32_fa_b(float v) { return __float2bfloat16(v); } + +// ─── tile 参数 ──────────────────────────────────────────────────────────────── +// FA_BC_CUDA=16 确保对 head_dim≤128 时 smem ≈ 32KB < 48KB(所有 CUDA GPU 安全) +// 原 FA_BC=32:smem = (16+64+16)×128×4 + 128 = 49280B > 48KB,kernel 静默失败 +static constexpr int FA_BR_CUDA = 16; // Q tile 行数(= blockDim.x) +static constexpr int FA_BC_CUDA = 16; // KV tile 列数(从 32 降为 16) + +// ─── Flash Attention 2 核心 kernel ─────────────────────────────────────────── +// grid = (n_heads, ceil(seq_len / FA_BR)) +// block = (FA_BR, 1) — 每线程负责 Q-tile 中一行 query +template +__global__ void flash_attn_kernel( + T *__restrict__ out, // [seq_len, n_heads, head_dim] + const T *__restrict__ Q, // [seq_len, n_heads, head_dim] + const T *__restrict__ K, // [total_len, n_kv_heads, head_dim] + const T *__restrict__ V, // [total_len, n_kv_heads, head_dim] + int seq_len, int total_len, + int n_heads, int n_kv_heads, int head_dim, float scale, + int heads_per_kv) { + // IEEE-754 负无穷,通过位模式构造,避免 __builtin_inff() 在某些 CUDA 版本的行为差异 + const float NEG_INF = __int_as_float(0xff800000u); + + const int h = blockIdx.x; + const int kv_h = h / heads_per_kv; + const int q_start = blockIdx.y * FA_BR_CUDA; + if (q_start >= seq_len) { + return; + } + + const int q_end = min(q_start + FA_BR_CUDA, seq_len); + 
const int q_len = q_end - q_start; + const int tx = threadIdx.x; // 0 .. FA_BR_CUDA-1 + + // ── Shared memory 布局 ─────────────────────────────────────────────────── + // Q_s [FA_BR, head_dim] K_s [FA_BC, head_dim] + // V_s [FA_BC, head_dim] O_s [FA_BR, head_dim] + // m_s [FA_BR] l_s [FA_BR] + extern __shared__ float smem[]; + float *Q_s = smem; + float *K_s = Q_s + FA_BR_CUDA * head_dim; + float *V_s = K_s + FA_BC_CUDA * head_dim; + float *O_s = V_s + FA_BC_CUDA * head_dim; + float *m_s = O_s + FA_BR_CUDA * head_dim; + float *l_s = m_s + FA_BR_CUDA; + + // ── 初始化:载入 Q tile,重置 O/m/l ───────────────────────────────────── + if (tx < q_len) { + m_s[tx] = NEG_INF; // 真正的 -inf,确保第一有效 tile 时 alpha=0 + l_s[tx] = 0.0f; + + const int abs_q = q_start + tx; + const T *q_ptr = Q + abs_q * n_heads * head_dim + h * head_dim; + float *q_row = Q_s + tx * head_dim; + float *o_row = O_s + tx * head_dim; + for (int d = 0; d < head_dim; d++) { + q_row[d] = to_f32_fa(q_ptr[d]); + o_row[d] = 0.0f; + } + } + __syncthreads(); + + // ── 遍历 KV tiles ───────────────────────────────────────────────────────── + for (int kv_start = 0; kv_start < total_len; kv_start += FA_BC_CUDA) { + const int kv_end = min(kv_start + FA_BC_CUDA, total_len); + const int kv_len = kv_end - kv_start; + + // 协作加载 K/V tile(FA_BR 个线程分摊 FA_BC 行) + for (int row = tx; row < kv_len; row += FA_BR_CUDA) { + const int abs_kv = kv_start + row; + const T *k_ptr = K + abs_kv * n_kv_heads * head_dim + kv_h * head_dim; + const T *v_ptr = V + abs_kv * n_kv_heads * head_dim + kv_h * head_dim; + float *k_row = K_s + row * head_dim; + float *v_row = V_s + row * head_dim; + for (int d = 0; d < head_dim; d++) { + k_row[d] = to_f32_fa(k_ptr[d]); + v_row[d] = to_f32_fa(v_ptr[d]); + } + } + __syncthreads(); + + // ── 每线程独立处理其负责的 query 行 ───────────────────────────────── + if (tx < q_len) { + const int abs_q = q_start + tx; + const int causal_lim = total_len - seq_len + abs_q; // causal mask 上界(含) + + // ① 计算 score tile 并找非 mask 位置最大值 + 
float s_local[FA_BC_CUDA]; // FA_BC_CUDA=16 → 64B,保持在寄存器中 + float local_max = NEG_INF; + + const float *q_row = Q_s + tx * head_dim; + for (int j = 0; j < kv_len; j++) { + const int abs_kv = kv_start + j; + if (abs_kv > causal_lim) { + s_local[j] = NEG_INF; // 精确 -inf,exp(-inf)=0 + continue; + } + const float *k_row = K_s + j * head_dim; + float dot = 0.0f; + for (int d = 0; d < head_dim; d++) { + dot += q_row[d] * k_row[d]; + } + s_local[j] = dot * scale; + if (s_local[j] > local_max) { + local_max = s_local[j]; + } + } + + // ② Online softmax 更新 + const float m_old = m_s[tx]; + const float m_new = (local_max > m_old) ? local_max : m_old; + + // 若 m_new 仍为 -inf:本 tile 全被 causal mask,且历史无有效分数 + // → O/l/m 保持不变,直接跳过,避免 -inf-(-inf)=NaN + if (!(__isinff(m_new) && m_new < 0.0f)) { + + // alpha:历史累积 O 的缩放系数 + // m_old==-inf 表示 O 尚未写入任何值,alpha 直接为 0 + const float alpha = __isinff(m_old) ? 0.0f : __expf(m_old - m_new); + + // ③ 计算 exp(s - m_new),mask 位置显式置 0 + float l_tile = 0.0f; + for (int j = 0; j < kv_len; j++) { + const float e = (__isinff(s_local[j]) && s_local[j] < 0.0f) + ? 0.0f + : __expf(s_local[j] - m_new); + s_local[j] = e; + l_tile += e; + } + + // ④ 更新 O:缩放旧值 + 累积本 tile 的 V 贡献 + float *o_row = O_s + tx * head_dim; + for (int d = 0; d < head_dim; d++) { + o_row[d] *= alpha; + } + // j 外 d 内:V_s[j*hd+d] 顺序读取,减少 bank conflict + for (int j = 0; j < kv_len; j++) { + const float sj = s_local[j]; + if (sj == 0.0f) { + continue; + } + const float *v_row = V_s + j * head_dim; + for (int d = 0; d < head_dim; d++) { + o_row[d] += sj * v_row[d]; + } + } + + l_s[tx] = alpha * l_s[tx] + l_tile; + m_s[tx] = m_new; + } + // else: 全 mask tile → O/l/m 不变 + } + __syncthreads(); + } + + // ── 最终归一化并写回全局内存 ───────────────────────────────────────────── + if (tx < q_len) { + const int abs_q = q_start + tx; + const float inv_l = (l_s[tx] > 0.0f) ? 
(1.0f / l_s[tx]) : 0.0f; + const float *o_row = O_s + tx * head_dim; + T *o_ptr = out + abs_q * n_heads * head_dim + h * head_dim; + for (int d = 0; d < head_dim; d++) { + if constexpr (std::is_same_v) { + o_ptr[d] = o_row[d] * inv_l; + } else if constexpr (std::is_same_v) { + o_ptr[d] = from_f32_fa_h(o_row[d] * inv_l); + } else { // __nv_bfloat16 + o_ptr[d] = from_f32_fa_b(o_row[d] * inv_l); + } + } + } +} + +// ─── kernel 启动封装 ────────────────────────────────────────────────────────── +template +static void launch_flash_attn(T *out, const T *q, const T *k, const T *v, + int seq_len, int total_len, + int n_heads, int n_kv_heads, int head_dim, float scale) { + const int heads_per_kv = n_heads / n_kv_heads; + const int num_q_tiles = (seq_len + FA_BR_CUDA - 1) / FA_BR_CUDA; + + dim3 grid(n_heads, num_q_tiles); + dim3 block(FA_BR_CUDA, 1); + + // smem = (2×BR + 2×BC) × head_dim × 4 + 2×BR × 4 + // BR=BC=16, head_dim=128: (32+32)×512 + 128 = 32896B ≈ 32KB < 48KB ✓ + const size_t smem = static_cast(2 * FA_BR_CUDA + 2 * FA_BC_CUDA) * head_dim * sizeof(float) + + static_cast(2 * FA_BR_CUDA) * sizeof(float); + + flash_attn_kernel<<>>( + out, q, k, v, seq_len, total_len, + n_heads, n_kv_heads, head_dim, scale, heads_per_kv); +} + +namespace llaisys::ops::nvidia { + +void flash_attention(std::byte *attn_val, + const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, + size_t head_dim, float scale) { + switch (type) { + case LLAISYS_DTYPE_F32: + launch_flash_attn( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + static_cast(seq_len), static_cast(total_len), + static_cast(n_heads), static_cast(n_kv_heads), + static_cast(head_dim), scale); + break; + case LLAISYS_DTYPE_F16: + launch_flash_attn( + reinterpret_cast<__half *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + static_cast(seq_len), 
static_cast(total_len), + static_cast(n_heads), static_cast(n_kv_heads), + static_cast(head_dim), scale); + break; + case LLAISYS_DTYPE_BF16: + launch_flash_attn( + reinterpret_cast<__nv_bfloat16 *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + static_cast(seq_len), static_cast(total_len), + static_cast(n_heads), static_cast(n_kv_heads), + static_cast(head_dim), scale); + break; + default: + throw std::runtime_error("flash_attention CUDA: unsupported data type"); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/flash_attention_cuda.cuh b/src/ops/self_attention/nvidia/flash_attention_cuda.cuh new file mode 100644 index 000000000..eea7c69a2 --- /dev/null +++ b/src/ops/self_attention/nvidia/flash_attention_cuda.cuh @@ -0,0 +1,17 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +// Flash Attention 2 CUDA 实现(shared memory tiling,online softmax) +// 支持 GQA (n_heads % n_kv_heads == 0) +// 支持 F32 / F16 / BF16 +void flash_attention(std::byte *attn_val, + const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, + size_t head_dim, float scale); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/self_attention_cuda.cu b/src/ops/self_attention/nvidia/self_attention_cuda.cu new file mode 100644 index 000000000..160340fea --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_cuda.cu @@ -0,0 +1,358 @@ +#include "self_attention_cuda.cuh" + +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include +#include + +#define CUBLAS_CHECK(call) \ + do { \ + cublasStatus_t _st = (call); \ + if (_st != CUBLAS_STATUS_SUCCESS) { \ + std::ostringstream _oss; \ + _oss << "cuBLAS error " << static_cast(_st) \ + << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(_oss.str()); \ + } \ + } 
while (0) + +// ───────────────────────────────────────────────────────────────────────────── +// Warp reduce sum(用于 softmax) +// ───────────────────────────────────────────────────────────────────────────── +__device__ __forceinline__ float warpReduceSum_attn(float val) { + constexpr unsigned FULL_MASK = 0xffffffff; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(FULL_MASK, val, offset); + } + return val; +} + +__device__ __forceinline__ float warpReduceMax_attn(float val) { + constexpr unsigned FULL_MASK = 0xffffffff; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_down_sync(FULL_MASK, val, offset)); + } + return val; +} + +// ───────────────────────────────────────────────────────────────────────────── +// 因果 Softmax kernel +// gridDim.x = seq_len * n_heads(每行一个 block) +// 每行长度 = total_len +// ───────────────────────────────────────────────────────────────────────────── +__global__ void causal_softmax_kernel(float *__restrict__ attn_scores, + size_t seq_len, + size_t total_len, + size_t n_heads) { + extern __shared__ float smem[]; + + size_t row = blockIdx.x; // row = query_pos * n_heads + head_id + size_t q_idx = row / n_heads; + // 当前 query 在序列中的绝对位置(total_len - seq_len + q_idx) + size_t query_pos = total_len - seq_len + q_idx; + + float *row_ptr = attn_scores + row * total_len; + + int num_warps = (blockDim.x + 31) / 32; + + // ── 阶段1:找最大值(因果掩码 j > query_pos 置 -inf)──────────────── + float local_max = -3.402823466e+38f; + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + float v = (j <= query_pos) ? row_ptr[j] : -3.402823466e+38f; + local_max = fmaxf(local_max, v); + } + local_max = warpReduceMax_attn(local_max); + int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; + if (lane_id == 0) { + smem[warp_id] = local_max; + } + __syncthreads(); + float block_max = (threadIdx.x < static_cast(num_warps)) ? 
smem[threadIdx.x] : -3.402823466e+38f; + if (threadIdx.x < 32) { + block_max = warpReduceMax_attn(block_max); + } + if (threadIdx.x == 0) { + smem[0] = block_max; + } + __syncthreads(); + float max_val = smem[0]; + + // ── 阶段2:计算 exp(x - max) 并求和 ────────────────────────────── + float local_sum = 0.0f; + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + float v; + if (j <= query_pos) { + v = __expf(row_ptr[j] - max_val); + row_ptr[j] = v; + } else { + row_ptr[j] = 0.0f; + v = 0.0f; + } + local_sum += v; + } + local_sum = warpReduceSum_attn(local_sum); + if (lane_id == 0) { + smem[warp_id] = local_sum; + } + __syncthreads(); + float block_sum = (threadIdx.x < static_cast(num_warps)) ? smem[threadIdx.x] : 0.0f; + if (threadIdx.x < 32) { + block_sum = warpReduceSum_attn(block_sum); + } + if (threadIdx.x == 0) { + smem[0] = block_sum; + } + __syncthreads(); + float inv_sum = 1.0f / (smem[0] + 1e-12f); + + // ── 阶段3:归一化 ──────────────────────────────────────────────── + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + row_ptr[j] *= inv_sum; + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// 将 T 精度的 Q/K/V 提升到 float(用于中间计算),或保持 float 不变 +// ───────────────────────────────────────────────────────────────────────────── +template +__global__ void cast_to_f32_kernel(float *__restrict__ dst, + const T *__restrict__ src, + size_t n) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + if constexpr (std::is_same_v) { + dst[idx] = __half2float(src[idx]); + } else if constexpr (std::is_same_v) { + dst[idx] = __bfloat162float(src[idx]); + } else { + dst[idx] = static_cast(src[idx]); + } +} + +template +__global__ void cast_from_f32_kernel(T *__restrict__ dst, + const float *__restrict__ src, + size_t n) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + if constexpr (std::is_same_v) { + dst[idx] = 
__float2half(src[idx]); + } else if constexpr (std::is_same_v) { + dst[idx] = __float2bfloat16(src[idx]); + } else { + dst[idx] = static_cast(src[idx]); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// GQA/MHA self-attention 核心逻辑(在 float 精度下计算) +// +// Q[seq_len, n_heads, head_dim] → 每个 Q head 对应一个 KV head +// K[total_len, n_kv_heads, head_dim] +// V[total_len, n_kv_heads, head_dim] +// O[seq_len, n_heads, head_dim] +// +// 对每个 head h: +// kv_head = h / (n_heads / n_kv_heads) +// Qh = Q[:, h, :] [seq_len, head_dim] +// Kh = K[:, kv_head, :] [total_len, head_dim] +// Vh = V[:, kv_head, :] [total_len, head_dim] +// Sh = Qh * Kh^T * scale [seq_len, total_len] +// Oh = softmax_causal(Sh) * Vh [seq_len, head_dim] +// ───────────────────────────────────────────────────────────────────────────── +template +static void self_attention_impl(float *out_f32, + const float *q_f32, + const float *k_f32, + const float *v_f32, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, + size_t head_dim, float scale) { + cublasHandle_t handle = llaisys::device::nvidia::getCublasHandle(); + size_t heads_per_kv = n_heads / n_kv_heads; + + // 临时显存:注意力分数矩阵 [seq_len * total_len](每 head 复用) + float *d_scores = nullptr; + cudaMalloc(&d_scores, seq_len * total_len * sizeof(float)); + + for (size_t h = 0; h < n_heads; h++) { + size_t kv_head = h / heads_per_kv; + + // Qh:步长 = n_heads * head_dim,偏移 = h * head_dim + const float *Qh = q_f32 + h * head_dim; + // Kh:步长 = n_kv_heads * head_dim,偏移 = kv_head * head_dim + const float *Kh = k_f32 + kv_head * head_dim; + // Vh:同 Kh 布局 + const float *Vh = v_f32 + kv_head * head_dim; + // Oh:步长 = n_heads * head_dim,偏移 = h * head_dim + float *Oh = out_f32 + h * head_dim; + + // ── S = Q * K^T * scale ──────────────────────────────────────────── + // Qh: [seq_len, head_dim] 行主序,stride = n_heads * head_dim + // Kh: [total_len, head_dim] 行主序,stride = n_kv_heads * head_dim + // S: [seq_len, 
total_len] 行主序 + // + // Row-major → cuBLAS 列主序变换: + // Kh (row-major) = Kh^T (列主序),列步长 = n_kv_heads * head_dim + // Qh (row-major) = Qh^T (列主序),列步长 = n_heads * head_dim + // + // S^T[total_len x seq_len] = op(Kh^T) * op(Qh^T) + // = CUBLAS_OP_T(Kh^T) * CUBLAS_OP_N(Qh^T) + // = Kh * Qh^T ✓ + // + // lda 约束:CUBLAS_OP_T 时 lda >= K = head_dim,n_kv_heads*head_dim >= head_dim ✓ + // ldb 约束:CUBLAS_OP_N 时 ldb >= K = head_dim,n_heads*head_dim >= head_dim ✓ + // ldc 约束:ldc >= M = total_len ✓(d_scores 连续分配) + // + int M = static_cast(total_len); + int N = static_cast(seq_len); + int K = static_cast(head_dim); + int lda = static_cast(n_kv_heads * head_dim); // Kh^T 列步长 + int ldb = static_cast(n_heads * head_dim); // Qh^T 列步长 + int ldc = static_cast(total_len); + + const float beta0 = 0.0f; + CUBLAS_CHECK(cublasSgemm(handle, + CUBLAS_OP_T, CUBLAS_OP_N, // ← 修正:原为 OP_N, OP_T + M, N, K, + &scale, + Kh, lda, + Qh, ldb, + &beta0, + d_scores, ldc)); + + // ── 因果 Softmax ─────────────────────────────────────────────────── + { + int blk = static_cast(total_len < 1024 ? 
total_len : 1024); + blk = ((blk + 31) / 32) * 32; + int num_warps = (blk + 31) / 32; + size_t smem = static_cast(num_warps) * sizeof(float); + // 每行 = 一个 query position × head = (q_idx * n_heads + h) 对应行 + // 但我们对每个 head 单独调用,scores 是 [seq_len, total_len] + causal_softmax_kernel<<(seq_len), blk, smem>>>( + d_scores, seq_len, total_len, 1 /*n_heads=1, 已 per-head*/); + } + + // ── O = softmax(S) * V ───────────────────────────────────────────── + // softmax(S): [seq_len, total_len] 行主序 + // Vh: [total_len, head_dim] 行主序,stride = n_kv_heads * head_dim + // Oh: [seq_len, head_dim] 行主序,stride = n_heads * head_dim + // + // cuBLAS:Oh^T = Vh * softmax(S)^T + // cublasSgemm(OP_N, OP_N, head_dim, seq_len, total_len, + // 1, Vh, lda_v, S, total_len, 0, Oh, ldo) + { + int M2 = static_cast(head_dim); + int N2 = static_cast(seq_len); + int K2 = static_cast(total_len); + int lda2 = static_cast(n_kv_heads * head_dim); // Vh row stride + int ldb2 = static_cast(total_len); // S row stride + int ldc2 = static_cast(n_heads * head_dim); // Oh row stride + + const float alpha1 = 1.0f; + CUBLAS_CHECK(cublasSgemm(handle, + CUBLAS_OP_N, CUBLAS_OP_N, + M2, N2, K2, + &alpha1, + Vh, lda2, + d_scores, ldb2, + &beta0, + Oh, ldc2)); + } + } + + cudaFree(d_scores); +} + +namespace llaisys::ops::nvidia { + +void self_attention(std::byte *attn_val, + const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, + size_t head_dim, float scale) { + size_t q_numel = seq_len * n_heads * head_dim; + size_t k_numel = total_len * n_kv_heads * head_dim; + size_t v_numel = k_numel; + size_t o_numel = q_numel; + + if (type == LLAISYS_DTYPE_F32) { + // f32 直接传入,无需拷贝 + self_attention_impl( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); + return; + } + + // f16/bf16:提升到 f32 后在中间计算,结果转回 + float *q_f32, 
*k_f32, *v_f32, *o_f32; + cudaMalloc(&q_f32, q_numel * sizeof(float)); + cudaMalloc(&k_f32, k_numel * sizeof(float)); + cudaMalloc(&v_f32, v_numel * sizeof(float)); + cudaMalloc(&o_f32, o_numel * sizeof(float)); + + constexpr int BLOCK = 256; + auto grid = [&](size_t n) { return static_cast((n + BLOCK - 1) / BLOCK); }; + + switch (type) { + case LLAISYS_DTYPE_F16: + cast_to_f32_kernel<__half><<>>( + q_f32, reinterpret_cast(q), q_numel); + cast_to_f32_kernel<__half><<>>( + k_f32, reinterpret_cast(k), k_numel); + cast_to_f32_kernel<__half><<>>( + v_f32, reinterpret_cast(v), v_numel); + break; + case LLAISYS_DTYPE_BF16: + cast_to_f32_kernel<__nv_bfloat16><<>>( + q_f32, reinterpret_cast(q), q_numel); + cast_to_f32_kernel<__nv_bfloat16><<>>( + k_f32, reinterpret_cast(k), k_numel); + cast_to_f32_kernel<__nv_bfloat16><<>>( + v_f32, reinterpret_cast(v), v_numel); + break; + default: + cudaFree(q_f32); + cudaFree(k_f32); + cudaFree(v_f32); + cudaFree(o_f32); + throw std::runtime_error("self_attention CUDA: unsupported data type"); + } + + self_attention_impl( + o_f32, q_f32, k_f32, v_f32, + seq_len, total_len, n_heads, n_kv_heads, head_dim, scale); + + switch (type) { + case LLAISYS_DTYPE_F16: + cast_from_f32_kernel<__half><<>>( + reinterpret_cast<__half *>(attn_val), o_f32, o_numel); + break; + case LLAISYS_DTYPE_BF16: + cast_from_f32_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(attn_val), o_f32, o_numel); + break; + default: + break; + } + + cudaFree(q_f32); + cudaFree(k_f32); + cudaFree(v_f32); + cudaFree(o_f32); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/self_attention_cuda.cuh b/src/ops/self_attention/nvidia/self_attention_cuda.cuh new file mode 100644 index 000000000..9a2ef2da5 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_cuda.cuh @@ -0,0 +1,18 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +// Self-Attention CUDA 实现(支持 GQA) +// 流程: +// 1. 
对每个 head,用 cuBLAS 计算 S = Q * K^T * scale +// 2. 用自定义 kernel 对 S 做因果 softmax +// 3. 用 cuBLAS 计算 O = softmax(S) * V +void self_attention(std::byte *attn_val, + const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, + size_t seq_len, size_t total_len, + size_t n_heads, size_t n_kv_heads, + size_t head_dim, float scale); +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..ff740a64f 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,88 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/flash_attention_cpu.hpp" +#include "cpu/self_attention_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/flash_attention_cuda.cuh" +#include "nvidia/self_attention_cuda.cuh" +#endif + +// 定义此宏以启用 Flash Attention 2 后端 +// 对序列较长时内存占用和速度均有提升 +#define USE_FLASH_ATTENTION + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + // 检查设备一致性 + CHECK_SAME_DEVICE(attn_val, q, k, v); + + // 检查数据类型一致性 + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + + // 检查维度 + CHECK_ARGUMENT(q->ndim() == 3, "self_attention: q must be 3D tensor [seqlen, nhead, d]"); + CHECK_ARGUMENT(k->ndim() == 3, "self_attention: k must be 3D tensor [total_len, nkvhead, d]"); + CHECK_ARGUMENT(v->ndim() == 3, "self_attention: v must be 3D tensor [total_len, nkvhead, dv]"); + CHECK_ARGUMENT(attn_val->ndim() == 3, "self_attention: attn_val must be 3D tensor [seqlen, nhead, dv]"); + + // 获取形状参数 + size_t seq_len = q->shape()[0]; + size_t n_heads = q->shape()[1]; + size_t head_dim = q->shape()[2]; + + size_t total_len = k->shape()[0]; + size_t n_kv_heads = k->shape()[1]; + size_t k_head_dim = k->shape()[2]; + + size_t v_total_len = v->shape()[0]; + size_t v_kv_heads = v->shape()[1]; + size_t v_head_dim = v->shape()[2]; + + 
// 检查形状兼容性 + CHECK_ARGUMENT(k_head_dim == head_dim, "self_attention: k and q must have same head dimension"); + CHECK_ARGUMENT(total_len == v_total_len, "self_attention: k and v must have same sequence length"); + CHECK_ARGUMENT(n_kv_heads == v_kv_heads, "self_attention: k and v must have same number of kv heads"); + CHECK_ARGUMENT(n_heads % n_kv_heads == 0, "self_attention: n_heads must be divisible by n_kv_heads"); + + CHECK_ARGUMENT(attn_val->shape()[0] == seq_len, "self_attention: attn_val seq_len must match q"); + CHECK_ARGUMENT(attn_val->shape()[1] == n_heads, "self_attention: attn_val n_heads must match q"); + CHECK_ARGUMENT(attn_val->shape()[2] == v_head_dim, "self_attention: attn_val head_dim must match v"); + + // 检查所有张量都是连续的 + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "self_attention: all tensors must be contiguous."); + + // 设置设备上下文 + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: +#ifdef USE_FLASH_ATTENTION + return cpu::flash_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), seq_len, total_len, + n_heads, n_kv_heads, v_head_dim, scale); +#else + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), seq_len, total_len, n_heads, n_kv_heads, v_head_dim, scale); +#endif +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: +#ifdef USE_FLASH_ATTENTION + return nvidia::flash_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), seq_len, total_len, + n_heads, n_kv_heads, v_head_dim, scale); +#else + return nvidia::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), seq_len, total_len, + n_heads, n_kv_heads, v_head_dim, scale); +#endif +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp 
b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 000000000..38d1d4d59 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,59 @@ +#include "swiglu_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +// SwiGLU: out_i = up_i * (gate_i / (1 + e^(-gate_i))) +// 等价于: out_i = up_i * sigmoid(gate_i) +template +void swiglu_(T *out, const T *gate, const T *up, size_t numel) { + for (size_t i = 0; i < numel; i++) { + float gate_val, up_val; + + if constexpr (std::is_same_v || std::is_same_v) { + gate_val = llaisys::utils::cast(gate[i]); + up_val = llaisys::utils::cast(up[i]); + } else { + gate_val = static_cast(gate[i]); + up_val = static_cast(up[i]); + } + + // 计算 sigmoid(gate) = gate_val / (1 + e^(-gate)) + float sigmoid = gate_val / (1.0f + std::exp(-gate_val)); + + // out = up * sigmoid(gate) + float result = up_val * sigmoid; + + if constexpr (std::is_same_v || std::is_same_v) { + out[i] = llaisys::utils::cast(result); + } else { + out[i] = static_cast(result); + } + } +} + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_BF16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_F16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 000000000..e2a71c68e --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte 
*up, + llaisysDataType_t type, size_t numel); +} diff --git a/src/ops/swiglu/nvidia/swiglu_cuda.cu b/src/ops/swiglu/nvidia/swiglu_cuda.cu new file mode 100644 index 000000000..f541c25b0 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_cuda.cu @@ -0,0 +1,118 @@ +#include "swiglu_cuda.cuh" + +#include +#include +#include + +// ───────────────────────────────────────────────────────────────────────────── +// SwiGLU kernel:out = up * silu(gate) = up * (gate / (1 + exp(-gate))) +// 使用 __expf 快速单精度指数,A100 上比双精度快 ~4× +// ───────────────────────────────────────────────────────────────────────────── +template +__global__ void swiglu_kernel(T *__restrict__ out, + const T *__restrict__ gate, + const T *__restrict__ up, + size_t n) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + + float g_val, u_val; + if constexpr (std::is_same_v) { + g_val = __half2float(gate[idx]); + u_val = __half2float(up[idx]); + } else if constexpr (std::is_same_v) { + g_val = __bfloat162float(gate[idx]); + u_val = __bfloat162float(up[idx]); + } else { + g_val = static_cast(gate[idx]); + u_val = static_cast(up[idx]); + } + + // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) + float silu_val = g_val / (1.0f + __expf(-g_val)); + float result = u_val * silu_val; + + if constexpr (std::is_same_v) { + out[idx] = __float2half(result); + } else if constexpr (std::is_same_v) { + out[idx] = __float2bfloat16(result); + } else { + out[idx] = static_cast(result); + } +} + +// float4 向量化版本(f32,numel 为 4 的倍数时使用,减少 kernel launch overhead) +__global__ void swiglu_f32x4_kernel(float *__restrict__ out, + const float *__restrict__ gate, + const float *__restrict__ up, + size_t n4) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= n4) { + return; + } + + float4 g4 = reinterpret_cast(gate)[idx]; + float4 u4 = reinterpret_cast(up)[idx]; + float4 r4; + r4.x = u4.x * (g4.x / (1.0f + __expf(-g4.x))); + r4.y = u4.y * (g4.y / (1.0f + 
__expf(-g4.y))); + r4.z = u4.z * (g4.z / (1.0f + __expf(-g4.z))); + r4.w = u4.w * (g4.w / (1.0f + __expf(-g4.w))); + reinterpret_cast(out)[idx] = r4; +} + +namespace llaisys::ops::nvidia { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel) { + constexpr int BLOCK = 256; + + switch (type) { + case LLAISYS_DTYPE_F32: { + size_t n4 = numel / 4; + size_t tail = numel % 4; + if (n4 > 0) { + int grid = static_cast((n4 + BLOCK - 1) / BLOCK); + swiglu_f32x4_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + n4); + } + if (tail > 0) { + int grid = static_cast((tail + BLOCK - 1) / BLOCK); + size_t offset = n4 * 4; + swiglu_kernel<<>>( + reinterpret_cast(out) + offset, + reinterpret_cast(gate) + offset, + reinterpret_cast(up) + offset, + tail); + } + break; + } + case LLAISYS_DTYPE_F16: { + int grid = static_cast((numel + BLOCK - 1) / BLOCK); + swiglu_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + } + case LLAISYS_DTYPE_BF16: { + int grid = static_cast((numel + BLOCK - 1) / BLOCK); + swiglu_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + } + default: + throw std::runtime_error("swiglu CUDA: unsupported data type"); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/nvidia/swiglu_cuda.cuh b/src/ops/swiglu/nvidia/swiglu_cuda.cuh new file mode 100644 index 000000000..66f44e19a --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_cuda.cuh @@ -0,0 +1,11 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +// SwiGLU CUDA 实现 +// out_i = up_i * silu(gate_i),向量化处理 f32/f16/bf16 +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel); +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/op.cpp 
b/src/ops/swiglu/op.cpp index 47edbcc97..7d9ebf3ac 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/swiglu_cuda.cuh" +#endif + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + // 检查设备一致性 + CHECK_SAME_DEVICE(out, gate, up); + + // 检查数据类型一致性 + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + + // 检查形状一致性 + CHECK_SAME_SHAPE(out->shape(), gate->shape(), up->shape()); + + // 检查所有张量都是连续的 + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), + "swiglu: all tensors must be contiguous."); + + // 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + size_t numel = out->numel(); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..04155b034 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,27 +164,134 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + // 检查张量的形状和步长,判断它在内存中是否连续 + if (_meta.shape.empty()) { + return true; + } + + // 从最后一个维度开始检查 + ptrdiff_t expected_stride = 1; + for (size_t i = _meta.shape.size(); i > 0; --i) { + size_t idx = i - 1; + if (_meta.strides[idx] != expected_stride) { + return false; + } + expected_stride *= _meta.shape[idx]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // 
创建一个新张量,改变原始张量维度的顺序。转置可以通过这个函数实现,而无需移动数据。 + CHECK_ARGUMENT(order.size() == _meta.shape.size(), + "permute order size must match tensor dimensions"); + + // 检查order是否是有效的排列(包含0到ndim-1的所有索引) + std::vector used(_meta.shape.size(), false); + for (size_t idx : order) { + CHECK_ARGUMENT(idx < _meta.shape.size(), "permute order index out of range"); + CHECK_ARGUMENT(!used[idx], "permute order contains duplicate indices"); + used[idx] = true; + } + + // 创建新的shape和strides + TensorMeta new_meta; + new_meta.dtype = _meta.dtype; + new_meta.shape.resize(order.size()); + new_meta.strides.resize(order.size()); + + for (size_t i = 0; i < order.size(); ++i) { + new_meta.shape[i] = _meta.shape[order[i]]; + new_meta.strides[i] = _meta.strides[order[i]]; + } + // shared_ptr 这种智能指针多人共享 是带有自动释放功能的指针 + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // 创建一个新张量,通过拆分或合并原始维度将原始张量重塑为给定形状。不涉及数据传输。例如,通过合并最后两个维度,将形状为(2, 3, 5)的张量更改为(2, 15)。 + + // 这个函数不是简单地改变张量的形状那么简单, + // 尽管测试会通过。如果新视图与原始张量不兼容, + // 它应该引发错误。想想一个形状为(2, 3, 5)、 + // 步长为(30, 10, 1)的张量。你还能在不传输数据的情况下将其重塑为(2, 15)吗? 
+ + // 打印输入形状 + // std::cout << "DEBUG: view called with shape: ["; + // for (size_t i = 0; i < shape.size(); ++i) { + // std::cout << shape[i]; + // if (i < shape.size() - 1) { + // std::cout << ", "; + // } + // } + // std::cout << "]" << std::endl; + // 计算新形状的总元素数 + size_t new_numel = 1; + for (size_t s : shape) { + new_numel *= s; + } + CHECK_ARGUMENT(new_numel == this->numel(), + "view: new shape must have the same number of elements"); + + // view操作要求张量必须是连续的 + CHECK_ARGUMENT(this->isContiguous(), + "view: tensor must be contiguous"); + + // 计算新的步长(行优先顺序) + std::vector new_strides(shape.size()); + ptrdiff_t stride = 1; + for (size_t i = shape.size(); i > 0; --i) { + size_t idx = i - 1; + new_strides[idx] = stride; + stride *= shape[idx]; + } + + TensorMeta new_meta{_meta.dtype, shape, new_strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // 创建一个新张量,沿给定维度,start(包含)和end(不包含)索引对原始张量进行切片操作。 + CHECK_ARGUMENT(dim < _meta.shape.size(), "slice: dimension out of range"); + CHECK_ARGUMENT(start < end, "slice: start must be less than end"); + CHECK_ARGUMENT(end <= _meta.shape[dim], "slice: end index out of range"); + + // 创建新的元数据 + TensorMeta new_meta; + new_meta.dtype = _meta.dtype; + new_meta.shape = _meta.shape; + new_meta.strides = _meta.strides; + + // 修改切片维度的大小 + new_meta.shape[dim] = end - start; + + // 计算新的偏移量(以字节为单位) + size_t new_offset = _offset + start * _meta.strides[dim] * this->elementSize(); + + return std::shared_ptr(new Tensor(new_meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + // 将主机(cpu)数据加载到张量(可以在设备上)。查看构造函数了解如何获取当前设备上下文的运行时API,并执行从主机到设备的内存复制。 + // 设置设备的上下文 + core::context().setDevice(this->deviceType(), this->deviceId()); + // 计算需要复制的字节数 + size_t bytes = this->numel() * this->elementSize(); + // 根据张量设备类型选择内存复制类型 + 
llaisysMemcpyKind_t copy_kind; + if (this->deviceType() == LLAISYS_DEVICE_CPU) { + copy_kind = LLAISYS_MEMCPY_H2H; // 主机到主机 + } else { + copy_kind = LLAISYS_MEMCPY_H2D; // 主机到设备 + } + + // 执行同步内存复制 + core::context().runtime().api()->memcpy_sync( + this->data(), // 目标地址(张量数据) + src_, // 源地址(主机数据) + bytes, // 复制字节数 + copy_kind // 复制类型 + ); } tensor_t Tensor::contiguous() const { diff --git a/src/tensor/tensor.hpp b/src/tensor/tensor.hpp index 35e340922..15ab76051 100644 --- a/src/tensor/tensor.hpp +++ b/src/tensor/tensor.hpp @@ -5,7 +5,7 @@ namespace llaisys { class Tensor; using tensor_t = std::shared_ptr; - +// 张量元数据结构 struct TensorMeta { llaisysDataType_t dtype; std::vector shape; diff --git a/test/test_runtime.py b/test/test_runtime.py index e2ac218a1..68b40e5d5 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -2,23 +2,49 @@ import torch from test_utils import * import argparse +import sys def test_basic_runtime_api(device_name: str = "cpu"): - api = llaisys.RuntimeAPI(llaisys_device(device_name)) + from llaisys.libllaisys import LIB_LLAISYS, llaisysDeviceType_t + from ctypes import c_int + device_type = llaisys_device(device_name) + + # Check whether this library was compiled with support for the requested device. + is_supported = LIB_LLAISYS.llaisysIsDeviceSupported(llaisysDeviceType_t(device_type.value)) + if not is_supported: + raise RuntimeError( + f"The llaisys library was NOT compiled with {device_name.upper()} support.\n" + f" Recompile with GPU support enabled:\n" + f" xmake f --nv-gpu=y && xmake -j$(nproc) && xmake install" + ) + + api = llaisys.RuntimeAPI(device_type) ndev = api.get_device_count() print(f"Found {ndev} {device_name} devices") + if ndev == 0: - print(" Skipped") - return + if device_name == "cpu": + raise RuntimeError( + "CPU device count is 0, which is unexpected. " + "Something is wrong with the runtime." 
+ ) + else: + raise RuntimeError( + f"No {device_name} devices were found (library has {device_name.upper()} support compiled in).\n" + " Possible causes:\n" + " 1. GPU drivers are not installed or the GPU is not accessible in this container.\n" + " 2. CUDA_VISIBLE_DEVICES is set to empty string (\"\") which hides all GPUs.\n" + " Use CUDA_VISIBLE_DEVICES=0 (or unset it) to expose GPUs.\n" + " Hint: run 'nvidia-smi' to verify GPU visibility." + ) for i in range(ndev): - print("Testing device {i}...") + print(f"Testing device {i}...") api.set_device(i) test_memcpy(api, 1024 * 1024) - print(" Passed") @@ -50,6 +76,9 @@ def test_memcpy(api, size_bytes: int): llaisys.MemcpyKind.D2H, ) + api.free_device(device_a) + api.free_device(device_b) + torch.testing.assert_close(a, b) @@ -58,5 +87,5 @@ def test_memcpy(api, size_bytes: int): parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) args = parser.parse_args() test_basic_runtime_api(args.device) - + print("\033[92mTest passed!\033[0m\n") diff --git a/xmake.lua b/xmake.lua index 1f65f7a95..7dd636de6 100644 --- a/xmake.lua +++ b/xmake.lua @@ -37,6 +37,8 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + -- Note: CUDA device files are compiled directly into the llaisys shared lib + -- to avoid __cudaRegisterLinkedBinary_* symbols being dropped by the linker. set_languages("cxx17") set_warnings("all", "error") @@ -83,6 +85,7 @@ target_end() target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + -- Note: CUDA ops files are compiled directly into the llaisys shared lib. 
set_languages("cxx17") set_warnings("all", "error") @@ -103,6 +106,20 @@ target("llaisys") add_deps("llaisys-tensor") add_deps("llaisys-ops") + -- 直接将 .cu 文件编译进共享库,避免中间静态库导致 + -- __cudaRegisterLinkedBinary_* 符号被链接器丢弃的问题(cuda.devlink 老版本不支持)。 + if has_config("nv-gpu") then + add_rules("cuda") + add_cuflags("--generate-code=arch=compute_80,code=sm_80", {force = true}) + add_cuflags("-std=c++17") + if not is_plat("windows") then + add_cuflags("-Xcompiler=-fPIC,-Wno-unknown-pragmas") + end + add_files("src/device/nvidia/*.cu") + add_files("src/ops/*/nvidia/*.cu") + add_links("cublas", "cuda", "cudart") + end + set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") @@ -110,7 +127,7 @@ target("llaisys") after_install(function (target) - -- copy shared library to python package + -- copy shared library to python package source tree print("Copying llaisys to python/llaisys/libllaisys/ ..") if is_plat("windows") then os.cp("bin/*.dll", "python/llaisys/libllaisys/") @@ -118,5 +135,15 @@ target("llaisys") if is_plat("linux") then os.cp("lib/*.so", "python/llaisys/libllaisys/") end + + -- (re-)install the Python package so that site-packages picks up the new .so + -- Using --no-build-isolation avoids re-running cmake/setup.py build steps. + print("Installing Python package (pip install -e python/) ..") + local ret = os.execv("pip", {"install", "-e", "python/", "--no-build-isolation", "-q"}) + if ret ~= 0 then + -- Fallback: try pip3 + os.execv("pip3", {"install", "-e", "python/", "--no-build-isolation", "-q"}) + end + print("Python package installed. 
You can now run the tests.") end) target_end() \ No newline at end of file diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 000000000..fc844b66a --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,7 @@ +-- CUDA 源文件直接由主目标 llaisys 编译,不再使用中间静态库。 +-- 这样可以避免 __cudaRegisterLinkedBinary_* 符号被链接器丢弃的问题: +-- 当 CUDA 代码通过静态库(.a)链接进共享库(.so)时,Linux 链接器默认 +-- 只引入被显式引用的符号,而 CUDA 的设备代码注册函数不会被 C++ 代码 +-- 直接引用,因此会被丢弃。直接编译则不存在此问题。 +-- +-- 编译配置(cuflags 等)集中在 xmake.lua 的 llaisys 目标中。