diff --git a/.clang-format b/.clang-format index a77ae97c..40ed0dad 100644 --- a/.clang-format +++ b/.clang-format @@ -4,27 +4,9 @@ IndentWidth: 4 # 缩进宽度,LLVM 默认值为 2,改 AccessModifierOffset: -4 # public/protected/private 访问控制符相对成员的偏移,与 IndentWidth 配合,LLVM 默认值为 -2 AlignOperands: AlignAfterOperator # 双目运算符的行间对齐,LLVM 默认值为 Align,改为带符号一起换行 BreakBeforeBinaryOperators: All # 在双目运算符之前换行,LLVM 默认值为 None,改为换行时总是把双目运算符放在行首,包括赋值(=) -ColumnLimit: 0 # 列宽限制,LLVM 默认值为 80,改为不限制 +ColumnLimit: 80 # 列宽限制,LLVM 默认值为 80,改为不限制 AllowShortBlocksOnASingleLine: Always # 是否允许短块(单个语句的块)不换行,LLVM 默认值为 Never,改为允许 AllowShortLoopsOnASingleLine: true # 是否允许短循环不换行,LLVM 默认值为 false,改为允许 -InsertBraces: true # 是否在 if/for/while/switch 等语句后插入大括号,LLVM 默认值为 false,改为允许 -BreakBeforeBraces: Custom # 大括号换行配置,LLVM 默认值为 LLVM,改为自定义以使 BraceWrapping 生效 -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: Never - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - BeforeLambdaBody: false - BeforeWhile: false - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true +InsertBraces: false # 是否在 if/for/while/switch 等语句后插入大括号,LLVM 默认值为 false,改为允许 + +BinPackParameters: OnePerLine \ No newline at end of file diff --git a/.gitignore b/.gitignore index e38cf574..5bc9f3c7 100644 --- a/.gitignore +++ b/.gitignore @@ -87,4 +87,6 @@ htmlcov/ # Windows Thumbs.db ehthumbs.db -desktop.ini \ No newline at end of file +desktop.ini + +data \ No newline at end of file diff --git a/REPORT.md b/REPORT.md new file mode 100644 index 00000000..59452d4f --- /dev/null +++ b/REPORT.md @@ -0,0 +1,25 @@ +## CUDA Backend for Chat Server + +首先设置环境变量 `LLAISYS_DEVICE=nvidia` 来启用 CUDA backend. 
+ +## 流式输出 + +然后运行 `python/chat_server.py` 启用 OpenAI 风格的 API + +然后可以用 `curl` 来测试流式输出 + +```bash +curl -N http://127.0.0.1:9108/v1/chat/completions \ +-H "Content-Type: application/json" \ +-H "Accept: text/event-stream" \ +-d '{"model":"qwen2","messages":[{"role":"user","content":"Hi who are you?"}],"stream":true,"max_tokens":64,"temperature":0.8,"top_p":0.9,"top_k":40}' +``` + +### TUI Chatting + +```bash +python -m llaisys.chat.tui --url http://127.0.0.1:9108 + +# 或者用 uv +uv run python -m llaisys.chat.tui --url http://127.0.0.1:9108 +``` diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d..8391a1bf 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -2,6 +2,8 @@ #define LLAISYS_MODELS_QWEN2_H #include "../tensor.h" +#include +#include __C { struct LlaisysQwen2Meta { @@ -31,12 +33,43 @@ __C { struct LlaisysQwen2Model; - __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); + __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate( + const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, + int *device_ids, int ndevice); __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model); - __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); + __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights( + struct LlaisysQwen2Model * model); - __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + /** + * @brief Inference function for Qwen2 Model. This function combines both + * prefill and decode through `prefill` flag. + * @note This function will reset KV Caches if `prefill` is true. 
+ * + * @param token_ids input token ids + * @param pos_ids input position ids, used for RoPE + */ + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, + int64_t *token_ids, int64_t *pos_ids, size_t ntoken, bool prefill); + + /** + * @brief Inference with configurable decoding strategy. + * + * This behaves the same as llaisysQwen2ModelInfer for prefill/decode flow, + * but selects the next token using temperature/top-k/top-p sampling. + * + * @param top_k Keep only top-k tokens before sampling (0 = disabled). + * @param top_p Nucleus threshold in (0, 1] (1.0 = disabled). + * @param temperature Positive temperature scalar. + */ + __export int64_t llaisysQwen2ModelInferSample(struct LlaisysQwen2Model * model, + int64_t *token_ids, int64_t *pos_ids, + size_t ntoken, bool prefill, + int top_k, float top_p, float temperature); + + __export void llaisysQwen2SetWeights(struct LlaisysQwen2Model * model, + int name, int layer_id, + llaisysTensor_t tensor); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h index ddb3be24..fb61d8b7 100644 --- a/include/llaisys/ops.h +++ b/include/llaisys/ops.h @@ -13,6 +13,20 @@ __C { __export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta); __export void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale); __export void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up); + + /** + * @brief Sample a token index from logits with temperature / top-k / top-p. + * @param out Shape {1}, dtype int64. Receives the sampled token index. + * @param logits Shape {vocab_size}, any float dtype. Must be contiguous. + * @param top_k Keep only top-k tokens (0 = disabled). + * @param top_p Nucleus threshold in (0, 1] (1.0 = disabled). + * @param temperature Positive temperature scalar. 
+ */ + __export void llaisysSample(llaisysTensor_t out, llaisysTensor_t logits, + int top_k, float top_p, float temperature); + + /** Set the per-thread RNG seed used by llaisysSample. */ + __export void llaisysSampleSetSeed(uint64_t seed); } #endif diff --git a/python/chat_server.py b/python/chat_server.py new file mode 100644 index 00000000..5b4a1afb --- /dev/null +++ b/python/chat_server.py @@ -0,0 +1,13 @@ +import uvicorn + +from llaisys.chat.server import build_runtime_from_env, create_app + + +def main() -> None: + runtime = build_runtime_from_env() + app = create_app(runtime) + uvicorn.run(app, host="0.0.0.0", port=9108) + + +if __name__ == "__main__": + main() diff --git a/python/llaisys/__init__.py b/python/llaisys/__init__.py index de8d99f4..c95ac7db 100644 --- a/python/llaisys/__init__.py +++ b/python/llaisys/__init__.py @@ -7,6 +7,7 @@ from .ops import Ops from . import models from .models import * +from .libllaisys import LIB_LLAISYS __all__ = [ "RuntimeAPI", @@ -17,4 +18,5 @@ "Tensor", "Ops", "models", + "LIB_LLAISYS" ] diff --git a/python/llaisys/chat/__init__.py b/python/llaisys/chat/__init__.py new file mode 100644 index 00000000..a440a378 --- /dev/null +++ b/python/llaisys/chat/__init__.py @@ -0,0 +1,3 @@ +from .server import create_app + +__all__ = ["create_app"] diff --git a/python/llaisys/chat/server.py b/python/llaisys/chat/server.py new file mode 100644 index 00000000..e230919a --- /dev/null +++ b/python/llaisys/chat/server.py @@ -0,0 +1,209 @@ +import os +import time +import json +import threading +from dataclasses import dataclass, field +from typing import Literal + +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field +from transformers import AutoTokenizer + +from ..libllaisys import DeviceType +from ..models.qwen2 import Qwen2 + + +class ChatMessage(BaseModel): + role: Literal["system", "user", "assistant"] + content: str + + +class 
ChatCompletionRequest(BaseModel): + model: str = Field(default="qwen2") + messages: list[ChatMessage] + max_tokens: int = Field(default=128, ge=1, le=4096) + temperature: float = Field(default=0.8, gt=0.0) + top_p: float = Field(default=0.9, gt=0.0, le=1.0) + top_k: int = Field(default=40, ge=0) + stream: bool = False + + +class ChoiceMessage(BaseModel): + role: Literal["assistant"] = "assistant" + content: str + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChoiceMessage + finish_reason: Literal["stop", "length"] + + +class ChatCompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + choices: list[ChatCompletionChoice] + usage: ChatCompletionUsage + + +@dataclass +class ChatRuntime: + tokenizer: AutoTokenizer + model: Qwen2 + lock: threading.Lock = field(default_factory=threading.Lock) + + +def _render_prompt(messages: list[ChatMessage], tokenizer: AutoTokenizer) -> str: + role_map = [{"role": m.role, "content": m.content} for m in messages] + if hasattr(tokenizer, "apply_chat_template"): + return tokenizer.apply_chat_template( + role_map, + tokenize=False, + add_generation_prompt=True, + ) + + chunks = [] + for message in role_map: + chunks.append(f"<{message['role']}>\n{message['content']}\n") + chunks.append("\n") + return "".join(chunks) + + +def _decode_new_text(tokenizer: AutoTokenizer, all_tokens: list[int], prompt_len: int) -> str: + return tokenizer.decode(all_tokens[prompt_len:], skip_special_tokens=True) + + +def _sse(event: dict) -> str: + return f"data: {json.dumps(event, ensure_ascii=False)}\n\n" + + +def create_app(runtime: ChatRuntime) -> FastAPI: + app = FastAPI(title="LLAISYS Chat Server", version="0.1.0") + + @app.get("/health") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.post("/v1/chat/completions", 
response_model=ChatCompletionResponse) + def chat_completions(req: ChatCompletionRequest): + if not req.messages: + raise HTTPException(status_code=400, detail="messages must not be empty") + + prompt = _render_prompt(req.messages, runtime.tokenizer) + prompt_tokens = runtime.tokenizer.encode(prompt) + + created = int(time.time()) + response_id = f"chatcmpl-{int(time.time() * 1000)}" + + if req.stream: + def event_stream(): + with runtime.lock: + yield _sse({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": req.model, + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + }], + }) + + generated = [] + last_text = "" + for token in runtime.model.generate_stream( + prompt_tokens, + max_new_tokens=req.max_tokens, + top_k=req.top_k, + top_p=req.top_p, + temperature=req.temperature, + ): + generated.append(token) + current_text = runtime.tokenizer.decode(generated, skip_special_tokens=True) + delta_text = current_text[len(last_text):] + if not delta_text: + continue + last_text = current_text + yield _sse({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": req.model, + "choices": [{ + "index": 0, + "delta": {"content": delta_text}, + "finish_reason": None, + }], + }) + + yield _sse({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": req.model, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop", + }], + }) + yield "data: [DONE]\n\n" + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + with runtime.lock: + generated_tokens = runtime.model.generate( + prompt_tokens, + max_new_tokens=req.max_tokens, + top_k=req.top_k, + top_p=req.top_p, + temperature=req.temperature, + ) + + answer_text = _decode_new_text(runtime.tokenizer, generated_tokens, len(prompt_tokens)) + completion_tokens = max(0, len(generated_tokens) - len(prompt_tokens)) + + return ChatCompletionResponse( + 
id=response_id, + created=created, + model=req.model, + choices=[ + ChatCompletionChoice( + index=0, + message=ChoiceMessage(content=answer_text), + finish_reason="stop", + ) + ], + usage=ChatCompletionUsage( + prompt_tokens=len(prompt_tokens), + completion_tokens=completion_tokens, + total_tokens=len(prompt_tokens) + completion_tokens, + ), + ) + + return app + + +def _parse_device(device_name: str) -> DeviceType: + if device_name.lower() == "nvidia": + return DeviceType.NVIDIA + return DeviceType.CPU + + +def build_runtime_from_env() -> ChatRuntime: + model_path = os.environ.get("LLAISYS_MODEL_PATH", "./data") + device = _parse_device(os.environ.get("LLAISYS_DEVICE", "cpu")) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = Qwen2(model_path=model_path, device=device) + return ChatRuntime(tokenizer=tokenizer, model=model) diff --git a/python/llaisys/chat/tui.py b/python/llaisys/chat/tui.py new file mode 100644 index 00000000..642a4308 --- /dev/null +++ b/python/llaisys/chat/tui.py @@ -0,0 +1,237 @@ +import argparse +import json +import sys +from typing import Any +from urllib import error, request + + +class ChatTUI: + def __init__( + self, + base_url: str, + model: str, + max_tokens: int, + temperature: float, + top_p: float, + top_k: int, + stream: bool, + system_prompt: str | None = None, + ) -> None: + self.base_url = base_url.rstrip("/") + self.model = model + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.stream = stream + self.messages: list[dict[str, str]] = [] + self.system_prompt = "" + + if system_prompt: + self.set_system_prompt(system_prompt) + + def set_system_prompt(self, text: str) -> None: + self.system_prompt = text + self.messages = [m for m in self.messages if m.get("role") != "system"] + self.messages.insert(0, {"role": "system", "content": text}) + + def clear_history(self) -> None: + self.messages = [] + if self.system_prompt: + 
self.messages.append({"role": "system", "content": self.system_prompt}) + + def _post_json(self, path: str, payload: dict[str, Any]): + body = json.dumps(payload).encode("utf-8") + req = request.Request( + url=f"{self.base_url}{path}", + data=body, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream" if payload.get("stream") else "application/json", + }, + method="POST", + ) + return request.urlopen(req, timeout=600) + + def _chat_non_stream(self) -> str: + payload = { + "model": self.model, + "messages": self.messages, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": False, + } + with self._post_json("/v1/chat/completions", payload) as resp: + data = json.loads(resp.read().decode("utf-8")) + return data["choices"][0]["message"]["content"] + + def _chat_stream(self) -> str: + payload = { + "model": self.model, + "messages": self.messages, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": True, + } + + chunks: list[str] = [] + with self._post_json("/v1/chat/completions", payload) as resp: + for raw in resp: + line = raw.decode("utf-8", errors="ignore").strip() + if not line or not line.startswith("data: "): + continue + + event = line[6:] + if event == "[DONE]": + break + + try: + data = json.loads(event) + except json.JSONDecodeError: + continue + + delta = data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content") + if not content: + continue + + chunks.append(content) + print(content, end="", flush=True) + + print("") + return "".join(chunks) + + def _chat(self) -> str: + if self.stream: + return self._chat_stream() + answer = self._chat_non_stream() + print(answer) + return answer + + def send_user_message(self, text: str) -> None: + self.messages.append({"role": "user", "content": text}) + print("assistant> ", end="", flush=True) + + try: + answer = self._chat() 
+ except error.HTTPError as ex: + detail = ex.read().decode("utf-8", errors="ignore") + print(f"\n[HTTP {ex.code}] {detail}") + if self.messages and self.messages[-1]["role"] == "user" and self.messages[-1]["content"] == text: + self.messages.pop() + return + except error.URLError as ex: + print(f"\n[Network error] {ex}") + if self.messages and self.messages[-1]["role"] == "user" and self.messages[-1]["content"] == text: + self.messages.pop() + return + + self.messages.append({"role": "assistant", "content": answer}) + + def retry_last(self) -> None: + if not self.messages: + print("No history to retry.") + return + + if self.messages[-1]["role"] == "assistant": + self.messages.pop() + + last_user_idx = -1 + for i in range(len(self.messages) - 1, -1, -1): + if self.messages[i]["role"] == "user": + last_user_idx = i + break + + if last_user_idx == -1: + print("No user message found to retry.") + return + + print("assistant> ", end="", flush=True) + try: + answer = self._chat() + except Exception as ex: # keep retry resilient for terminal usage + print(f"\n[Retry failed] {ex}") + return + + self.messages.append({"role": "assistant", "content": answer}) + + def print_help(self) -> None: + print("Commands:") + print(" /help Show this help") + print(" /exit Exit TUI") + print(" /clear Clear local conversation history") + print(" /retry Regenerate assistant answer for last user turn") + print(" /system Set or replace system prompt") + + def repl(self) -> None: + print("LLAISYS Chat TUI") + print(f"Server: {self.base_url}") + print("Type /help for commands.") + + while True: + try: + line = input("you> ").strip() + except (EOFError, KeyboardInterrupt): + print("\nBye.") + break + + if not line: + continue + + if line == "/exit": + print("Bye.") + break + if line == "/help": + self.print_help() + continue + if line == "/clear": + self.clear_history() + print("History cleared.") + continue + if line == "/retry": + self.retry_last() + continue + if line.startswith("/system 
"): + prompt = line[len("/system ") :].strip() + if not prompt: + print("Usage: /system ") + continue + self.set_system_prompt(prompt) + print("System prompt updated.") + continue + + self.send_user_message(line) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="LLAISYS chat TUI client") + parser.add_argument("--url", default="http://127.0.0.1:9108", help="Base URL of chat server") + parser.add_argument("--model", default="qwen2", help="Model name in request payload") + parser.add_argument("--max-tokens", type=int, default=128) + parser.add_argument("--temperature", type=float, default=0.8) + parser.add_argument("--top-p", type=float, default=0.9) + parser.add_argument("--top-k", type=int, default=40) + parser.add_argument("--system", default="", help="Optional system prompt") + parser.add_argument("--no-stream", action="store_true", help="Disable SSE streaming") + args = parser.parse_args(argv) + + tui = ChatTUI( + base_url=args.url, + model=args.model, + max_tokens=args.max_tokens, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + stream=not args.no_stream, + system_prompt=args.system or None, + ) + tui.repl() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb52..f643111d 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -9,6 +9,7 @@ from .llaisys_types import llaisysDataType_t, DataType from .llaisys_types import llaisysMemcpyKind_t, MemcpyKind from .llaisys_types import llaisysStream_t +from .qwen2 import LlaisysQwen2Meta, load_qwen2_model from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops @@ -31,13 +32,15 @@ def load_shared_library(): if not os.path.isfile(lib_path): raise FileNotFoundError(f"Shared library not found: {lib_path}") - return ctypes.CDLL(str(lib_path)) + lib = 
ctypes.CDLL(str(lib_path)) + return lib LIB_LLAISYS = load_shared_library() load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_qwen2_model(LIB_LLAISYS) __all__ = [ @@ -52,4 +55,5 @@ def load_shared_library(): "llaisysMemcpyKind_t", "MemcpyKind", "llaisysStream_t", + "LlaisysQwen2Meta", ] diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py index 5be095ef..9e8dafd0 100644 --- a/python/llaisys/libllaisys/ops.py +++ b/python/llaisys/libllaisys/ops.py @@ -34,3 +34,10 @@ def load_ops(lib): lib.llaisysSwiGLU.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] lib.llaisysSwiGLU.restype = None + + from ctypes import c_int, c_uint64 + lib.llaisysSample.argtypes = [llaisysTensor_t, llaisysTensor_t, c_int, c_float, c_float] + lib.llaisysSample.restype = None + + lib.llaisysSampleSetSeed.argtypes = [c_uint64] + lib.llaisysSampleSetSeed.restype = None diff --git a/python/llaisys/libllaisys/qwen2.py b/python/llaisys/libllaisys/qwen2.py new file mode 100644 index 00000000..c706dd4c --- /dev/null +++ b/python/llaisys/libllaisys/qwen2.py @@ -0,0 +1,77 @@ +from ctypes import ( + Structure, + POINTER, + c_int, + c_size_t, + c_float, + c_int64, + c_void_p, + c_char_p, + c_bool, +) +import sys + +from .llaisys_types import llaisysDeviceType_t, llaisysDataType_t +from .tensor import llaisysTensor_t + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + + +def load_qwen2_model(lib): + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + llaisysDeviceType_t, + POINTER(c_int), + c_int, + ] + lib.llaisysQwen2ModelCreate.restype = c_void_p + + lib.llaisysQwen2SetWeights.argtypes = [c_void_p, c_int, c_int, llaisysTensor_t] + 
lib.llaisysQwen2SetWeights.restype = None + + lib.llaisysQwen2ModelInfer.argtypes = [ + c_void_p, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + c_bool, + ] + lib.llaisysQwen2ModelInfer.restype = c_int64 + + lib.llaisysQwen2ModelInferSample.argtypes = [ + c_void_p, + POINTER(c_int64), + POINTER(c_int64), + c_size_t, + c_bool, + c_int, + c_float, + c_float, + ] + lib.llaisysQwen2ModelInferSample.restype = c_int64 + + lib.llaisysQwen2ModelWeights.argtypes = [c_void_p] + lib.llaisysQwen2ModelWeights.restype = c_void_p + + lib.llaisysQwen2ModelDestroy.argtypes = [c_void_p] + lib.llaisysQwen2ModelDestroy.restype = None + + return lib + + +__all__ = ["LlaisysQwen2Meta", "load_qwen2_model"] diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..d8ff27e0 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,223 @@ -from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType +import ctypes +import json +from typing import Iterator, Sequence +from ..libllaisys import LIB_LLAISYS, LlaisysQwen2Meta +from ..libllaisys import DeviceType, DataType +from ..tensor import Tensor from pathlib import Path import safetensors +import torch +import numpy as np +from tqdm import tqdm -class Qwen2: - def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor +DEFAULT_MODEL_PATH = "./data" - model_path = Path(model_path) - for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass +class Qwen2: + def __init__( + self, + model_path=DEFAULT_MODEL_PATH, + device: DeviceType = DeviceType.CPU, + ): + self.device = device + self._backend = None + + model_path = Path(model_path) + self.__load_config(model_path / "config.json") + self.__load_weights(model_path) def generate( self, inputs: 
Sequence[int], - max_new_tokens: int = None, + max_new_tokens: int = 128, top_k: int = 1, top_p: float = 0.8, temperature: float = 0.8, ): + answer = list(inputs) + for token in self.generate_stream( + inputs, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ): + answer.append(token) + + return answer + + def generate_stream( + self, + inputs: Sequence[int], + max_new_tokens: int = 128, + top_k: int = 1, + top_p: float = 0.8, + temperature: float = 0.8, + ) -> Iterator[int]: + if len(inputs) == 0 or max_new_tokens <= 0: + return + + array = (ctypes.c_int64 * len(inputs))(*inputs) + pos_ids = (ctypes.c_int64 * len(inputs))(*range(len(inputs))) + + # Prefill and yield first token. + output_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._backend, + array, + pos_ids, + len(inputs), + True, + int(top_k), + float(top_p), + float(temperature), + ) + yield output_token + + # Decode loop: one token per step. + current_token = output_token + current_pos = len(inputs) + for _ in range(max_new_tokens - 1): + if current_token == self.meta.end_token: + break + + array = (ctypes.c_int64 * 1)(current_token) + pos_ids = (ctypes.c_int64 * 1)(current_pos) + output_token = LIB_LLAISYS.llaisysQwen2ModelInferSample( + self._backend, + array, + pos_ids, + 1, + False, + int(top_k), + float(top_p), + float(temperature), + ) + yield output_token + current_token = output_token + current_pos += 1 + + def __load_config(self, config_path: Path): + with open(config_path, "r") as f: + config = json.load(f) + + meta = LlaisysQwen2Meta() + match config.get("torch_dtype", ""): + case "bfloat16": + meta.dtype = DataType.BF16 + case "float16": + meta.dtype = DataType.F16 + case "float32": + meta.dtype = DataType.F32 + case _: + raise ValueError( + f"Unsupported data type: {config.get('torch_dtype', '')}" + ) + meta.dtype = DataType.F32 # always use fp32 for now + meta.nlayer = config.get("num_hidden_layers", 0) + meta.nh = 
config.get("num_attention_heads", 0) + meta.hs = config.get("hidden_size", 0) + meta.nkvh = config.get("num_key_value_heads", 0) + meta.dh = config.get("head_dim", int(meta.hs / meta.nh) if meta.nh else 0) + meta.di = config.get("intermediate_size", 0) + meta.maxseq = config.get("max_position_embeddings", 0) + meta.voc = config.get("vocab_size", 0) + meta.epsilon = config.get("layer_norm_epsilon", 1e-5) + meta.theta = config.get("rope_theta", 1000000.0) + meta.end_token = config.get("eos_token_id", 0) + + self.meta = meta + + # Init model + self._backend = LIB_LLAISYS.llaisysQwen2ModelCreate( + ctypes.byref(self.meta), + self.device, + None, + 0, + ) + + def __get_name_mapping(self, weights_folder: Path): + self.name_mapping: dict[str, tuple[str, int, int]] = { + # (short name, indicator of weight type, optional layer index) + # see: llaisys/qwen2.cc:setWeights() + "model.embed_tokens.weight": ("in_embed", 0, -1), + "lm_head.weight": ("out_embed", 1, -1), + "model.norm.weight": ("out_norm_w", 2, -1), + } + + for layer_idx in range(self.meta.nlayer): + prefix = f"model.layers.{layer_idx}" + self.name_mapping.update( + { + f"{prefix}.input_layernorm.weight": ("attn_norm_w", 3, layer_idx), + f"{prefix}.self_attn.q_proj.weight": ("attn_q_w", 4, layer_idx), + f"{prefix}.self_attn.q_proj.bias": ("attn_q_b", 5, layer_idx), + f"{prefix}.self_attn.k_proj.weight": ("attn_k_w", 6, layer_idx), + f"{prefix}.self_attn.k_proj.bias": ("attn_k_b", 7, layer_idx), + f"{prefix}.self_attn.v_proj.weight": ("attn_v_w", 8, layer_idx), + f"{prefix}.self_attn.v_proj.bias": ("attn_v_b", 9, layer_idx), + f"{prefix}.self_attn.o_proj.weight": ("attn_o_w", 10, layer_idx), + f"{prefix}.post_attention_layernorm.weight": ( + "mlp_norm_w", + 11, + layer_idx, + ), + f"{prefix}.mlp.gate_proj.weight": ("mlp_gate_w", 12, layer_idx), + f"{prefix}.mlp.up_proj.weight": ("mlp_up_w", 13, layer_idx), + f"{prefix}.mlp.down_proj.weight": ("mlp_down_w", 14, layer_idx), + } + ) + + def __load_weights(self, 
weights_folder: Path): + self.__get_name_mapping(weights_folder) + + for file in sorted(weights_folder.glob("*.safetensors")): + with safetensors.safe_open(file, framework="pt", device="cpu") as data_: + for name_ in data_.keys(): + if name_ not in self.name_mapping: + raise ValueError(f"Unknown weight name: {name_}") + short_name, weight_type, layer_idx = self.name_mapping[name_] + weight_data = data_.get_tensor(name_) # load as torch + + # Convert to target dtype before numpy conversion + if self.meta.dtype == DataType.BF16: + weight_data = ( + weight_data.to(torch.bfloat16).view(torch.uint16).numpy() + ) + elif self.meta.dtype == DataType.F16: + weight_data = ( + weight_data.to(torch.float16).view(torch.uint16).numpy() + ) + else: # F32 + weight_data = weight_data.float().numpy() + + tensor = Tensor(weight_data.shape, self.meta.dtype, self.device) + tensor.load(weight_data.ctypes.data) + + # Set weights in the backend + LIB_LLAISYS.llaisysQwen2SetWeights( + self._backend, + weight_type, + layer_idx, + tensor.lib_tensor(), + ) + + def generate_no_decode(self, inputs: Sequence[int], max_new_tokens: int): + answer = list(inputs) + for step in tqdm(range(max_new_tokens), desc="Generating"): + array = (ctypes.c_int64 * len(answer))(*answer) + pos_ids = (ctypes.c_int64 * len(answer))(*range(len(answer))) - # TODO: Implement generate function + # Prefill and get first token + output_token = LIB_LLAISYS.llaisysQwen2ModelInfer( + self._backend, + array, + pos_ids, + len(answer), + True, # prefill + ) + answer.append(output_token) - return [] + return answer diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py index ed0180bc..6d4082af 100644 --- a/python/llaisys/ops.py +++ b/python/llaisys/ops.py @@ -1,6 +1,6 @@ from .libllaisys import LIB_LLAISYS from .tensor import Tensor -from ctypes import c_float, c_int +from ctypes import c_float, c_int, c_uint64 class Ops: @@ -53,3 +53,16 @@ def self_attention(attn_val: Tensor, q: Tensor, k: Tensor, v: Tensor, scale: flo 
@staticmethod def swiglu(out: Tensor, gate: Tensor, up: Tensor): LIB_LLAISYS.llaisysSwiGLU(out.lib_tensor(), gate.lib_tensor(), up.lib_tensor()) + + @staticmethod + def sample(out: Tensor, logits: Tensor, top_k: int = 0, top_p: float = 1.0, temperature: float = 1.0): + """Sample a token from logits. out must be shape (1,) int64.""" + LIB_LLAISYS.llaisysSample( + out.lib_tensor(), logits.lib_tensor(), + c_int(top_k), c_float(top_p), c_float(temperature), + ) + + @staticmethod + def sample_set_seed(seed: int): + """Set the per-thread RNG seed used by Ops.sample for reproducible results.""" + LIB_LLAISYS.llaisysSampleSetSeed(c_uint64(seed)) diff --git a/python/setup.cfg b/python/setup.cfg index b35fc65f..76f429b1 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -13,6 +13,8 @@ install_requires = torch>=2.4.0 transformers accelerate + fastapi + uvicorn [options.package_data] llaisys = diff --git a/src/device/nvidia/nvidia_resource.cuh b/src/device/nvidia/nvidia_resource.cuh index a3002170..709d073e 100644 --- a/src/device/nvidia/nvidia_resource.cuh +++ b/src/device/nvidia/nvidia_resource.cuh @@ -6,6 +6,6 @@ namespace llaisys::device::nvidia { class Resource : public llaisys::device::DeviceResource { public: Resource(int device_id); - ~Resource(); + ~Resource() = default; }; } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab92826..6c5d0312 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -2,55 +2,64 @@ #include #include +#include namespace llaisys::device::nvidia { namespace runtime_api { int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count; + cudaGetDeviceCount(&count); + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + cudaSetDevice(device_id); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + cudaDeviceSynchronize(); } llaisysStream_t createStream() { - 
TO_BE_IMPLEMENTED(); + cudaStream_t stream; + cudaStreamCreate(&stream); + return reinterpret_cast(stream); } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cudaStreamDestroy(reinterpret_cast(stream)); } void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cudaStreamSynchronize(reinterpret_cast(stream)); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr; + cudaMalloc(&ptr, size); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + cudaFree(ptr); } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr; + cudaMallocHost(&ptr, size); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + cudaFreeHost(ptr); } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + cudaMemcpy(dst, src, size, static_cast(kind)); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaMemcpyAsync(dst, src, size, static_cast(kind), reinterpret_cast(stream)); } static const LlaisysRuntimeAPI RUNTIME_API = { diff --git a/src/kvcache/simple.cc b/src/kvcache/simple.cc new file mode 100644 index 00000000..c241ef1b --- /dev/null +++ b/src/kvcache/simple.cc @@ -0,0 +1,107 @@ +#include "simple.hpp" +#include "../core/context/context.hpp" +#include +#include + +namespace llaisys::kvcache::simple { + +KVCache::KVCache(usize capacity, + usize num_kv_head, + usize head_dim, + usize vdim, + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id) + : capacity(capacity), cache_size(0), num_kv_head(num_kv_head), + head_dim(head_dim), vdim(vdim), dtype(dtype), device(device), + device_id(device_id), + keys(Tensor::create( + {capacity, num_kv_head, head_dim}, dtype, device, device_id)), + values(Tensor::create( + {capacity, num_kv_head, vdim}, dtype, device, 
device_id)) { + + // Initialize KV cache to zero + usize keys_size = capacity * num_kv_head * head_dim * keys->elementSize(); + usize values_size = capacity * num_kv_head * vdim * values->elementSize(); + + if (device == LLAISYS_DEVICE_CPU) { + // For CPU, use std::memset directly + std::memset(static_cast(keys->data()), 0, keys_size); + std::memset(static_cast(values->data()), 0, values_size); + } else { + // For other devices (CUDA, etc.), use runtime API + core::context().setDevice(device, device_id); + auto *api = core::context().runtime().api(); + + // Create zero buffers on host and copy to device + std::vector zero_keys(keys_size, 0); + std::vector zero_values(values_size, 0); + + api->memcpy_sync(keys->data(), zero_keys.data(), keys_size, LLAISYS_MEMCPY_H2D); + api->memcpy_sync(values->data(), zero_values.data(), values_size, LLAISYS_MEMCPY_H2D); + } +} + +void KVCache::reset() { cache_size = 0; } + +void KVCache::insert(const tensor &new_keys, + const tensor &new_values, + usize n_new) { + if (cache_size + n_new > capacity) + throw std::runtime_error("KVCache insert position exceeds capacity."); + if (new_keys->shape()[0] != n_new || new_keys->shape()[1] != num_kv_head + || new_keys->shape()[2] != head_dim) + throw std::runtime_error("New keys tensor shape mismatch."); + if (new_values->shape()[0] != n_new || new_values->shape()[1] != num_kv_head + || new_values->shape()[2] != vdim) + throw std::runtime_error("New values tensor shape mismatch."); + if (new_keys->deviceType() != device || new_values->deviceType() != device) + throw std::runtime_error("Device mismatch"); + if (!new_keys->isContiguous() || !new_values->isContiguous()) + throw std::runtime_error("New K/V must be contiguous"); + if (!keys->isContiguous() || !values->isContiguous()) + throw std::runtime_error("Cache storage must be contiguous"); + + core::context().setDevice(device, device_id); + auto *api = core::context().runtime().api(); + + // Copy keys + do { + auto begin = 
static_cast(keys->data()) + + cache_size * num_kv_head * head_dim * keys->elementSize(); + auto numel = n_new * num_kv_head * head_dim; + auto size_bytes = numel * new_keys->elementSize(); + + if (device == LLAISYS_DEVICE_CPU) { + std::memcpy(begin, static_cast(new_keys->data()), size_bytes); + } else { + api->memcpy_sync(begin, new_keys->data(), size_bytes, LLAISYS_MEMCPY_D2D); + } + } while (false); + + // Copy values + do { + auto begin = static_cast(values->data()) + + cache_size * num_kv_head * vdim * values->elementSize(); + auto numel = n_new * num_kv_head * vdim; + auto size_bytes = numel * new_values->elementSize(); + + if (device == LLAISYS_DEVICE_CPU) { + std::memcpy(begin, static_cast(new_values->data()), size_bytes); + } else { + api->memcpy_sync(begin, new_values->data(), size_bytes, LLAISYS_MEMCPY_D2D); + } + } while (false); + + cache_size += n_new; +} + +KVCache::tensor KVCache::getKeysSlice() { + return keys->slice(0, 0, cache_size); +} + +KVCache::tensor KVCache::getValuesSlice() { + return values->slice(0, 0, cache_size); +} + +} // namespace llaisys::kvcache::simple \ No newline at end of file diff --git a/src/kvcache/simple.hpp b/src/kvcache/simple.hpp new file mode 100644 index 00000000..e94ac242 --- /dev/null +++ b/src/kvcache/simple.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include "../tensor/tensor.hpp" +#include "llaisys.h" +#include + +namespace llaisys::kvcache::simple { + +/** + * @brief A simple per-layer KV cache implementation for transformer models. 
+ */ +struct KVCache { + using usize = size_t; + using tensor = tensor_t; + + usize capacity; + usize cache_size; + + usize num_kv_head; + usize head_dim; + usize vdim; + llaisysDataType_t dtype; + llaisysDeviceType_t device; + int device_id; + + tensor keys; + tensor values; + + KVCache(usize capacity, + usize num_kv_head, + usize head_dim, + usize vdim, + llaisysDataType_t dtype, + llaisysDeviceType_t device, + int device_id); + ~KVCache() = default; + + /** + * @brief Reset the KV cache to empty state. + */ + void reset(); + + /** + * @brief Append new keys and values at the end of the cache. + * @note This function might overwrite existing entries. + * + * @param new_keys The new keys to insert. Shape: [n_new, num_kv_head, head_dim] + * @param new_values The new values to insert. Shape: [n_new, num_kv_head, vdim] + * @param n_new Number of new key-value pairs to insert. + * @note New entries are written starting at the current cache_size offset. + */ + void insert(const tensor &new_keys, const tensor &new_values, usize n_new); + + /** + * @brief get a slice of keys tensor up to the current cache size. + */ + tensor getKeysSlice(); + /** + * @brief get a slice of values tensor up to the current cache size. 
+ */ + tensor getValuesSlice(); + + usize getCacheSize() const { return cache_size; } +}; + +} // namespace llaisys::kvcache::simple \ No newline at end of file diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc index c99fbc32..379a66fc 100644 --- a/src/llaisys/ops.cc +++ b/src/llaisys/ops.cc @@ -9,6 +9,7 @@ #include "../ops/rearrange/op.hpp" #include "../ops/rms_norm/op.hpp" #include "../ops/rope/op.hpp" +#include "../ops/sample/op.hpp" #include "../ops/self_attention/op.hpp" #include "../ops/swiglu/op.hpp" @@ -40,4 +41,11 @@ __C { void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up) { llaisys::ops::swiglu(out->tensor, gate->tensor, up->tensor); } + void llaisysSample(llaisysTensor_t out, llaisysTensor_t logits, + int top_k, float top_p, float temperature) { + llaisys::ops::sample(out->tensor, logits->tensor, top_k, top_p, temperature); + } + void llaisysSampleSetSeed(uint64_t seed) { + llaisys::ops::sample_set_seed(seed); + } } diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 00000000..5da50337 --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,605 @@ +#include "llaisys/models/qwen2.h" +#include "../kvcache/simple.hpp" +#include "../ops/ops.hpp" +#include "llaisys.h" +#include "llaisys/ops.h" +#include "llaisys/tensor.h" +#include "llaisys_tensor.hpp" + +#include +#include +#include +#include + +#define DBG_LOG false + +#define createTensor(...) 
\ + new LlaisysTensor { llaisys::Tensor::create(__VA_ARGS__) } + +#define loadTensor(ts) \ + [&]() { \ + auto t = new LlaisysTensor{ts}; \ + return t; \ + }() + +#define CASE(id, name, val) \ + case id: \ + do { \ + if (layer_id != -1) { \ + std::cerr << "[qwen2.cc:setWeights()] " #name \ + " should not have layer_id" \ + << std::endl; \ + exit(1); \ + } \ + auto ts = val->tensor; \ + model->weights.name = createTensor( \ + ts->shape(), ts->dtype(), model->device, model->device_id); \ + model->weights.name->tensor = ts; \ + if constexpr (DBG_LOG) { \ + std::cerr << "[qwen2.cc:setWeights()] Set " #name \ + << std::endl; \ + } \ + } while (0); \ + break; + +#define CASE_ARRAY(id, name, val) \ + case id: \ + do { \ + if (layer_id < 0 || layer_id >= static_cast(nlayer)) { \ + std::cerr << "[qwen2.cc:setWeights()] " #name \ + " layer_id out of range" \ + << std::endl; \ + exit(1); \ + } \ + auto ts = val->tensor; \ + model->weights.name[layer_id] = createTensor( \ + ts->shape(), ts->dtype(), model->device, model->device_id); \ + model->weights.name[layer_id]->tensor = ts; \ + if constexpr (DBG_LOG) { \ + std::cerr << "[qwen2.cc:setWeights()] Set " #name \ + << " for layer " << layer_id << std::endl; \ + } \ + } while (0); \ + break; + +#define MODEL_VALIDITY_CHECK(model) \ + do { \ + if (!model) { \ + std::cerr << "[qwen2.cc:infer()] Model is null, cannot perform " \ + "inference." \ + << std::endl; \ + return -1; \ + } \ + } while (0) + +#define LOG_SHAPE(stage, tensr, name) \ + do { \ + if constexpr (!DBG_LOG) \ + break; \ + std::cerr << "[qwen2.cc:" << stage << "] " << name << " shape: "; \ + for (int i = 0, l = int(tensr->shape().size()); i < l; ++i) { \ + std::cerr << tensr->shape()[i]; \ + if (i != l - 1) \ + std::cerr << " x "; \ + } \ + std::cerr << std::endl; \ + } while (0) +// Define some helper functions here, init model weights array/kvcache +// array, etc. 
+static void initializeArrays(LlaisysQwen2Model *model); + +static int64_t qwen2_infer_impl(struct LlaisysQwen2Model *model, + int64_t *token_ids, + int64_t *pos_ids, + size_t ntoken, + bool prefill, + int top_k, + float top_p, + float temperature); + +__C { + + struct LlaisysQwen2Model { + LlaisysQwen2Meta meta; + LlaisysQwen2Weights weights; + llaisys::kvcache::simple::KVCache **kvcaches; + llaisysDeviceType_t device; + int device_id; + }; + + __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate( + const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, + int *device_ids, int ndevice) { + if (!meta) { + std::cerr + << "[qwen2.cc:create()] Meta is null, cannot create model." + << std::endl; + return nullptr; + } + + LlaisysQwen2Model *model = new LlaisysQwen2Model(); + model->meta = *meta; + model->device = device; + model->device_id = device_ids ? device_ids[0] : 0; + + initializeArrays(model); + std::cerr << "[qwen2.cc:create()] Model created." << std::endl; + + return model; + } + + __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) { + if (!model) { + std::cerr + << "[qwen2.cc:destroy()] Model is null, nothing to destroy." + << std::endl; + return; + } + + // Free KV caches + for (size_t i = 0; i < model->meta.nlayer; ++i) { + delete model->kvcaches[i]; + } + delete[] model->kvcaches; + std::cerr << "[qwen2.cc:destroy()] Destroyed all KV caches." + << std::endl; + + // Free weight arrays + delete[] model->weights.attn_norm_w; + delete[] model->weights.attn_q_w; + delete[] model->weights.attn_q_b; + delete[] model->weights.attn_k_w; + delete[] model->weights.attn_k_b; + delete[] model->weights.attn_v_w; + delete[] model->weights.attn_v_b; + delete[] model->weights.attn_o_w; + delete[] model->weights.mlp_norm_w; + delete[] model->weights.mlp_gate_w; + delete[] model->weights.mlp_up_w; + delete[] model->weights.mlp_down_w; + std::cerr << "[qwen2.cc:destroy()] Destroyed all weight arrays." 
+ << std::endl; + + delete model; + std::cerr << "[qwen2.cc:destroy()] Model destroyed." << std::endl; + } + + __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights( + struct LlaisysQwen2Model * model) { + if (!model) { + std::cerr + << "[qwen2.cc:weights()] Model is null, cannot get weights." + << std::endl; + return nullptr; + } + /** + * Display all tensor shape here for debugging + */ + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:weights()] Model weights shapes:" + << std::endl; + + LOG_SHAPE("weights()", model->weights.in_embed->tensor, "in_embed"); + LOG_SHAPE("weights()", model->weights.out_embed->tensor, "out_embed"); + LOG_SHAPE("weights()", model->weights.out_norm_w->tensor, "out_norm_w"); + + auto nlayer = model->meta.nlayer; + for (size_t i = 0; i < nlayer; ++i) { + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:weights()] Layer " << i + << " weights:" << std::endl; + LOG_SHAPE("weights()", model->weights.attn_norm_w[i]->tensor, + "attn_norm_w"); + LOG_SHAPE("weights()", model->weights.attn_q_w[i]->tensor, + "attn_q_w"); + LOG_SHAPE("weights()", model->weights.attn_q_b[i]->tensor, + "attn_q_b"); + LOG_SHAPE("weights()", model->weights.attn_k_w[i]->tensor, + "attn_k_w"); + LOG_SHAPE("weights()", model->weights.attn_k_b[i]->tensor, + "attn_k_b"); + LOG_SHAPE("weights()", model->weights.attn_v_w[i]->tensor, + "attn_v_w"); + LOG_SHAPE("weights()", model->weights.attn_v_b[i]->tensor, + "attn_v_b"); + LOG_SHAPE("weights()", model->weights.attn_o_w[i]->tensor, + "attn_o_w"); + LOG_SHAPE("weights()", model->weights.mlp_norm_w[i]->tensor, + "mlp_norm_w"); + LOG_SHAPE("weights()", model->weights.mlp_gate_w[i]->tensor, + "mlp_gate_w"); + LOG_SHAPE("weights()", model->weights.mlp_up_w[i]->tensor, + "mlp_up_w"); + LOG_SHAPE("weights()", model->weights.mlp_down_w[i]->tensor, + "mlp_down_w"); + } + + return &model->weights; + } + + __export void llaisysQwen2SetWeights(struct LlaisysQwen2Model * model, + int name, int layer_id, + llaisysTensor_t tensor) 
{ + if (!model) { + std::cerr + << "[qwen2.cc:setWeights()] Model is null, cannot set weights." + << std::endl; + return; + } + + size_t nlayer = model->meta.nlayer; + switch (name) { + CASE(0, in_embed, tensor) // in_embed + CASE(1, out_embed, tensor) // out_embed + CASE(2, out_norm_w, tensor) // out_norm_w + CASE_ARRAY(3, attn_norm_w, tensor) // attn_norm_w + CASE_ARRAY(4, attn_q_w, tensor) // attn_q_w + CASE_ARRAY(5, attn_q_b, tensor) // attn_q_b + CASE_ARRAY(6, attn_k_w, tensor) // attn_k_w + CASE_ARRAY(7, attn_k_b, tensor) // attn_k_b + CASE_ARRAY(8, attn_v_w, tensor) // attn_v_w + CASE_ARRAY(9, attn_v_b, tensor) // attn_v_b + CASE_ARRAY(10, attn_o_w, tensor) // attn_o_w + CASE_ARRAY(11, mlp_norm_w, tensor) // mlp_norm_w + CASE_ARRAY(12, mlp_gate_w, tensor) // mlp_gate_w + CASE_ARRAY(13, mlp_up_w, tensor) // mlp_up_w + CASE_ARRAY(14, mlp_down_w, tensor) // mlp_down_w + default: + std::cerr << "[qwen2.cc:setWeights()] Unknown weight name: " << name + << ", cannot set weight." << std::endl; + exit(1); + } + + LOG_SHAPE("setWeights()", tensor->tensor, "weight tensor from Python"); + } + + __export int64_t llaisysQwen2ModelInfer( + struct LlaisysQwen2Model * model, int64_t *token_ids, int64_t *pos_ids, + size_t ntoken, bool prefill) { + // Keep old API behavior: greedy decoding via top_k=1. + return qwen2_infer_impl(model, token_ids, pos_ids, ntoken, prefill, + 1, 1.0f, 1.0f); + } + + __export int64_t llaisysQwen2ModelInferSample( + struct LlaisysQwen2Model * model, int64_t *token_ids, int64_t *pos_ids, + size_t ntoken, bool prefill, int top_k, float top_p, + float temperature) { + return qwen2_infer_impl(model, token_ids, pos_ids, ntoken, prefill, + top_k, top_p, temperature); + } +} + +static int64_t qwen2_infer_impl(struct LlaisysQwen2Model * model, + int64_t *token_ids, int64_t *pos_ids, + size_t ntoken, bool prefill, int top_k, + float top_p, float temperature) { + //* -1. 
Do checking + MODEL_VALIDITY_CHECK(model); + if (ntoken == 0) { + std::cerr << "[qwen2.cc:infer()] ntoken must be > 0." << std::endl; + return -1; + } + if (top_k < 0 || top_p <= 0.0f || top_p > 1.0f || temperature <= 0.0f) { + std::cerr << "[qwen2.cc:infer()] Invalid sampling params: top_k=" + << top_k << ", top_p=" << top_p + << ", temperature=" << temperature << std::endl; + return -1; + } + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Start inference." << std::endl; + + //* 0. If prefill, clean KV Caches + if (prefill) { + if constexpr (DBG_LOG) + std::cerr + << "[qwen2.cc:infer()] Prefill mode: resetting KV caches." + << std::endl; + for (size_t i = 0; i < model->meta.nlayer; ++i) + model->kvcaches[i]->reset(); + } + + //* 1. Copy inputs into tensor + using namespace llaisys; + using tensor = tensor_t; + using usize = size_t; + using i64 = int64_t; + + tensor pos_ids_tensor = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, + model->device, model->device_id); + pos_ids_tensor->load(pos_ids); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Loaded position ids." << std::endl; + LOG_SHAPE("infer()", pos_ids_tensor, "pos_ids"); + + tensor input_tokens = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, + model->device, model->device_id); + input_tokens->load(token_ids); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Loaded input token ids." + << std::endl; + LOG_SHAPE("infer()", input_tokens, "input_tokens"); + + //* 2. Token Embedding + tensor hidden_states + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::embedding(hidden_states, input_tokens, + model->weights.in_embed->tensor); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Completed token embedding." + << std::endl; + LOG_SHAPE("infer()", hidden_states, "hidden_states"); + + //* 3. 
Attention Layers + for (usize layer = 0; layer < model->meta.nlayer; ++layer) { + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": (Mock) Completed layer operation." << std::endl; + + //* 3.a Record a residual + tensor residual = hidden_states; + + //* 3.b RMS Norm before Attention + tensor attn_normed + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::rms_norm(attn_normed, hidden_states, + model->weights.attn_norm_w[layer]->tensor, + model->meta.epsilon); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed RMS norm before attention." + << std::endl; + + //* 3.c QKV Projection + tensor q_proj = Tensor::create( + {ntoken, model->meta.nh * model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + tensor k_proj = Tensor::create( + {ntoken, model->meta.nkvh * model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + tensor v_proj = Tensor::create( + {ntoken, model->meta.nkvh * model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + ops::linear(q_proj, attn_normed, + model->weights.attn_q_w[layer]->tensor, + model->weights.attn_q_b[layer]->tensor); + ops::linear(k_proj, attn_normed, + model->weights.attn_k_w[layer]->tensor, + model->weights.attn_k_b[layer]->tensor); + ops::linear(v_proj, attn_normed, + model->weights.attn_v_w[layer]->tensor, + model->weights.attn_v_b[layer]->tensor); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed QKV projection." 
<< std::endl; + + //* 3.c.1 Reshape to (S, H, D) + tensor qview + = q_proj->view({ntoken, model->meta.nh, model->meta.dh}); + tensor kview + = k_proj->view({ntoken, model->meta.nkvh, model->meta.dh}); + tensor vview + = v_proj->view({ntoken, model->meta.nkvh, model->meta.dh}); + + //* 3.d RoPE Encoding for Q, K + tensor pos_q = Tensor::create( + {ntoken, model->meta.nh, model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + tensor pos_k = Tensor::create( + {ntoken, model->meta.nkvh, model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + ops::rope(pos_q, qview, pos_ids_tensor, model->meta.theta); + ops::rope(pos_k, kview, pos_ids_tensor, model->meta.theta); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed RoPE encoding for Q and K." + << std::endl; + + //* 3.e Update KV Cache + model->kvcaches[layer]->insert(pos_k, vview, ntoken); + + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Updated KV cache." 
<< std::endl; + + //* 3.f Self attention + float scale = 1.0f / std::sqrt(static_cast(model->meta.dh)); + tensor kcache = model->kvcaches[layer]->getKeysSlice(); + if (prefill) { + // ASSERT(kcache->shape() == pos_k->shape(), + // "K cache shape mismatch!"); + } else { + if constexpr (DBG_LOG) + std::cerr + << "[qwen2.cc:infer()] Decode mode - pos_k shape: [" + << pos_k->shape()[0] << ", " << pos_k->shape()[1] + << ", " << pos_k->shape()[2] << "], " + << "kcache shape: [" << kcache->shape()[0] << ", " + << kcache->shape()[1] << ", " << kcache->shape()[2] + << "]" << std::endl; + } + tensor vcache = model->kvcaches[layer]->getValuesSlice(); + + tensor attn_out = Tensor::create( + {ntoken, model->meta.nh, model->meta.dh}, model->meta.dtype, + model->device, model->device_id); + ops::self_attention(attn_out, pos_q, kcache, vcache, scale); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed self-attention computation." + << std::endl; + + //* 3.g Output Projection + tensor attn_proj + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::linear( + attn_proj, + attn_out->view({ntoken, model->meta.nh * model->meta.dh}), + model->weights.attn_o_w[layer]->tensor, nullptr); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed attention output projection." + << std::endl; + + //* 3.h Residual after attention + ops::add(hidden_states, residual, attn_proj); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed residual addition after attention." 
+ << std::endl; + residual = hidden_states; + + //* 3.i MLP block + tensor mlp_normed + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::rms_norm(mlp_normed, hidden_states, + model->weights.mlp_norm_w[layer]->tensor, + model->meta.epsilon); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed RMS norm before MLP." << std::endl; + + //* 3.j MLP projections + tensor mlp_gate + = Tensor::create({ntoken, model->meta.di}, model->meta.dtype, + model->device, model->device_id); + tensor mlp_up + = Tensor::create({ntoken, model->meta.di}, model->meta.dtype, + model->device, model->device_id); + ops::linear(mlp_gate, mlp_normed, + model->weights.mlp_gate_w[layer]->tensor, nullptr); + ops::linear(mlp_up, mlp_normed, + model->weights.mlp_up_w[layer]->tensor, nullptr); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed MLP gate and up projections." + << std::endl; + + tensor mlp_down + = Tensor::create({ntoken, model->meta.di}, model->meta.dtype, + model->device, model->device_id); + ops::swiglu(mlp_down, mlp_gate, mlp_up); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed SwiGLU activation." << std::endl; + + //* 3.k Final MLP output projection + tensor mlp_out + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::linear(mlp_out, mlp_down, + model->weights.mlp_down_w[layer]->tensor, nullptr); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed final MLP output projection." + << std::endl; + + //* 3.l Final residual addition + ops::add(hidden_states, residual, mlp_out); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Layer " << layer + << ": Completed final residual addition." + << std::endl; + } + + //* 4. 
Final transform and output projection + tensor final_norm + = Tensor::create({ntoken, model->meta.hs}, model->meta.dtype, + model->device, model->device_id); + ops::rms_norm(final_norm, hidden_states, + model->weights.out_norm_w->tensor, model->meta.epsilon); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Completed final RMS norm." + << std::endl; + + //* 5. Get logits + tensor logits + = Tensor::create({ntoken, model->meta.voc}, model->meta.dtype, + model->device, model->device_id); + ops::linear(logits, final_norm, model->weights.out_embed->tensor, + nullptr); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Completed output projection " + "to logits." + << std::endl; + + //* 6. Get the last token's logits and argmax + tensor last_token_logits + = logits->slice(0, ntoken - 1, ntoken); // last token + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Sliced out last token logits." + << std::endl; + LOG_SHAPE("infer()", logits, "logits"); + LOG_SHAPE("infer()", last_token_logits, "last_token_logits"); + + //* 7. Sample to get next token id (top_k=1 equals argmax) + // sample() currently supports CPU only; move logits to host when needed. 
+ tensor sample_logits = last_token_logits; + if (last_token_logits->deviceType() != LLAISYS_DEVICE_CPU) { + sample_logits = Tensor::create( + last_token_logits->shape(), last_token_logits->dtype(), + LLAISYS_DEVICE_CPU, 0); + llaisys::core::context().setDevice(model->device, model->device_id); + llaisys::core::context().runtime().api()->memcpy_sync( + sample_logits->data(), last_token_logits->data(), + last_token_logits->numel() * last_token_logits->elementSize(), + LLAISYS_MEMCPY_D2H); + } + + tensor next_token_id_tensor + = Tensor::create({1}, LLAISYS_DTYPE_I64, LLAISYS_DEVICE_CPU, 0); + ops::sample(next_token_id_tensor, sample_logits, top_k, top_p, + temperature); + if constexpr (DBG_LOG) + std::cerr << "[qwen2.cc:infer()] Completed sampling to get next " + "token id: " + << *((i64 *)next_token_id_tensor->data()) << std::endl; + + // NOTE: keep a safe read path for both host/device output tensors. + i64 next_token_id = -1; + if (next_token_id_tensor->deviceType() == LLAISYS_DEVICE_CPU) { + next_token_id = *((i64 *)next_token_id_tensor->data()); + } else { + llaisys::core::context().setDevice(model->device, model->device_id); + llaisys::core::context().runtime().api()->memcpy_sync( + &next_token_id, next_token_id_tensor->data(), sizeof(i64), + LLAISYS_MEMCPY_D2H); + } + return next_token_id; +} + +static void initializeArrays(LlaisysQwen2Model *model) { + size_t nlayer = model->meta.nlayer; + + // KV Cache init + model->kvcaches = new llaisys::kvcache::simple::KVCache *[nlayer]; + for (size_t i = 0; i < nlayer; ++i) { + model->kvcaches[i] = new llaisys::kvcache::simple::KVCache( + model->meta.maxseq, model->meta.nkvh, model->meta.dh, + model->meta.dh, model->meta.dtype, model->device, model->device_id); + } + + std::cerr << "[qwen2.cc:initializeArrays()] Initialized KV caches for " + << nlayer << " layers." 
<< std::endl; + + // Weight init + model->weights.attn_norm_w = new llaisysTensor_t[nlayer]; + model->weights.attn_q_w = new llaisysTensor_t[nlayer]; + model->weights.attn_q_b = new llaisysTensor_t[nlayer]; + model->weights.attn_k_w = new llaisysTensor_t[nlayer]; + model->weights.attn_k_b = new llaisysTensor_t[nlayer]; + model->weights.attn_v_w = new llaisysTensor_t[nlayer]; + model->weights.attn_v_b = new llaisysTensor_t[nlayer]; + model->weights.attn_o_w = new llaisysTensor_t[nlayer]; + model->weights.mlp_norm_w = new llaisysTensor_t[nlayer]; + model->weights.mlp_gate_w = new llaisysTensor_t[nlayer]; + model->weights.mlp_up_w = new llaisysTensor_t[nlayer]; + model->weights.mlp_down_w = new llaisysTensor_t[nlayer]; + + std::cerr << "[qwen2.cc:initializeArrays()] Initialized Qwen2 model with " + << nlayer << " layers." << std::endl; +} \ No newline at end of file diff --git a/src/ops/add/nvidia/add_cu.cu b/src/ops/add/nvidia/add_cu.cu new file mode 100644 index 00000000..035a4e9b --- /dev/null +++ b/src/ops/add/nvidia/add_cu.cu @@ -0,0 +1,76 @@ +#include +#include +#include + +#include "../../../utils/check.hpp" +#include "../../../utils/types.hpp" +#include "llaisys.h" + +namespace addops::nvidia { + +template +__global__ void add_kernel(T *c, const T *a, const T *b, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +template <> +__global__ void add_kernel<__half>(__half *c, const __half *a, const __half *b, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +template <> +__global__ void add_kernel<__nv_bfloat16>(__nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b, size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +} // namespace addops::nvidia + +namespace llaisys::ops::nvidia { + +void add(std::byte *c, + const std::byte *a, + 
const std::byte *b, + llaisysDataType_t type, + size_t n) { + int threadsPerBlock = 256; + int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock; + + switch (type) { + case LLAISYS_DTYPE_F32: + addops::nvidia::add_kernel<<>>( + reinterpret_cast(c), + reinterpret_cast(a), + reinterpret_cast(b), + n); + break; + case LLAISYS_DTYPE_F16: + // fp16_t has the same 16-bit IEEE 754 layout as __half + addops::nvidia::add_kernel<__half><<>>( + reinterpret_cast<__half *>(c), + reinterpret_cast(a), + reinterpret_cast(b), + n); + break; + case LLAISYS_DTYPE_BF16: + // bf16_t has the same 16-bit bfloat layout as __nv_bfloat16 + addops::nvidia::add_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(c), + reinterpret_cast(a), + reinterpret_cast(b), + n); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/add/nvidia/add_cu.cuh b/src/ops/add/nvidia/add_cu.cuh new file mode 100644 index 00000000..c90eb1a7 --- /dev/null +++ b/src/ops/add/nvidia/add_cu.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t n); +} // namespace llaisys::ops::nvidia diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d..176e856e 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -4,6 +4,9 @@ #include "../../utils.hpp" #include "cpu/add_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/add_cu.cuh" +#endif namespace llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { @@ -25,8 +28,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #endif default: 
EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..89dbe05e --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,46 @@ +#include "argmax_cpu.hpp" +#include "../../../utils.hpp" +#include + +template +static void argmax_impl(int64_t *max_idx, T *max_val, const T *vals, size_t numel) { + T current_max = vals[0]; + int64_t current_max_idx = 0; + for (size_t i = 1; i < numel; ++i) { + if (casting(float, vals[i]) > casting(float, current_max)) { + current_max = vals[i]; + current_max_idx = static_cast(i); + } + } + *max_val = current_max; + *max_idx = current_max_idx; +} + +namespace llaisys::ops::cpu { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel, llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + argmax_impl(reinterpret_cast(max_idx), + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + case LLAISYS_DTYPE_F16: + argmax_impl(reinterpret_cast(max_idx), + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + case LLAISYS_DTYPE_BF16: + argmax_impl(reinterpret_cast(max_idx), + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..61d891bc --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel, llaisysDataType_t dtype); + +} \ No newline at end of file diff --git a/src/ops/argmax/nvidia/argmax_cu.cu b/src/ops/argmax/nvidia/argmax_cu.cu new file mode 100644 index 00000000..a90ac959 --- /dev/null +++ 
namespace argmaxops::nvidia {

__device__ inline float to_float(float x) { return x; }
__device__ inline float to_float(__half x) { return __half2float(x); }
__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); }

// Single-block argmax: each thread strides through the input keeping a local
// best, then the block reduces (value, index) pairs through shared memory.
// One block is sufficient for typical vocabulary sizes.
template <typename T>
__global__ void
argmax_kernel(int64_t *max_idx, T *max_val, const T *vals, size_t numel) {
    extern __shared__ char smem_raw[];
    float *best_val = reinterpret_cast<float *>(smem_raw);
    int64_t *best_idx = reinterpret_cast<int64_t *>(best_val + blockDim.x);

    const unsigned tid = threadIdx.x;
    float local_val = -FLT_MAX;
    int64_t local_idx = 0;

    for (size_t i = tid; i < numel; i += blockDim.x) {
        const float v = to_float(vals[i]);
        if (v > local_val) {
            local_val = v;
            local_idx = static_cast<int64_t>(i);
        }
    }

    best_val[tid] = local_val;
    best_idx[tid] = local_idx;
    __syncthreads();

    // Tree reduction; blockDim.x must be a power of two.
    for (unsigned stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride && best_val[tid + stride] > best_val[tid]) {
            best_val[tid] = best_val[tid + stride];
            best_idx[tid] = best_idx[tid + stride];
        }
        __syncthreads();
    }

    if (tid == 0) {
        *max_val = static_cast<T>(best_val[0]);
        *max_idx = best_idx[0];
    }
}

} // namespace argmaxops::nvidia

namespace llaisys::ops::nvidia {

// Launches the single-block reduction; results land in device memory
// (one int64 index plus one value of the input dtype).
void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
            size_t numel, llaisysDataType_t dtype) {
    constexpr int kThreads = 1024; // power of two, required by the reduction
    const size_t smem = kThreads * (sizeof(float) + sizeof(int64_t));

    switch (dtype) {
    case LLAISYS_DTYPE_F32:
        argmaxops::nvidia::argmax_kernel<float><<<1, kThreads, smem>>>(
            reinterpret_cast<int64_t *>(max_idx),
            reinterpret_cast<float *>(max_val),
            reinterpret_cast<const float *>(vals), numel);
        break;
    case LLAISYS_DTYPE_F16:
        argmaxops::nvidia::argmax_kernel<__half><<<1, kThreads, smem>>>(
            reinterpret_cast<int64_t *>(max_idx),
            reinterpret_cast<__half *>(max_val),
            reinterpret_cast<const __half *>(vals), numel);
        break;
    case LLAISYS_DTYPE_BF16:
        argmaxops::nvidia::argmax_kernel<__nv_bfloat16><<<1, kThreads, smem>>>(
            reinterpret_cast<int64_t *>(max_idx),
            reinterpret_cast<__nv_bfloat16 *>(max_val),
            reinterpret_cast<const __nv_bfloat16 *>(vals), numel);
        break;
    default:
        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
    }
}

} // namespace llaisys::ops::nvidia
return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), + vals->numel(), vals->dtype()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..42089249 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,46 @@ +#include "embedding_cpu.hpp" +#include "../../../utils.hpp" +#include "llaisys.h" +#include +#include +#include +#include + +template +static void +embedding_impl(T *output, const int64_t *indices, const T *weights, size_t num_indices, size_t embedding_dim) { +#pragma omp parallel for + for (size_t i = 0; i < num_indices; i++) { + int64_t idx = indices[i]; + const T *weight_begin = weights + idx * embedding_dim; + T *output_begin = output + i * embedding_dim; + std::memcpy(output_begin, weight_begin, embedding_dim * sizeof(T)); + } +} + +namespace llaisys::ops::cpu { + +void embedding(std::byte *output, + const std::byte *indices, + const std::byte *weights, + size_t num_indices, + size_t embedding_dim, + llaisysDataType_t dtype) { + using namespace llaisys::utils; + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return embedding_impl(recast(float *, output), recast(const int64_t *, indices), recast(const float *, weights), + num_indices, embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_impl(recast(fp16_t *, output), recast(const int64_t *, indices), + recast(const fp16_t *, weights), num_indices, embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_impl(recast(bf16_t *, output), recast(const int64_t *, indices), + recast(const bf16_t *, weights), num_indices, embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..4e79ad9b --- 
namespace embeddingops::nvidia {

// One block per output token; threads stride across the embedding dimension.
template <typename T>
__global__ void
embedding_kernel(T *output, const int64_t *indices, const T *weights,
                 size_t num_indices, size_t embedding_dim) {
    const size_t token = blockIdx.x;
    if (token >= num_indices) return;
    // NOTE(review): token ids are trusted to be in range; no bounds check.
    const int64_t w_idx = indices[token];
    for (size_t d = threadIdx.x; d < embedding_dim; d += blockDim.x)
        output[token * embedding_dim + d] = weights[w_idx * embedding_dim + d];
}

} // namespace embeddingops::nvidia

namespace llaisys::ops::nvidia {

// GPU gather: output[i, :] = weights[indices[i], :] for each of num_indices rows.
void embedding(std::byte *output,
               const std::byte *indices,
               const std::byte *weights,
               size_t num_indices,
               size_t embedding_dim,
               llaisysDataType_t dtype) {
    // A launch with 0 blocks or 0 threads is an invalid CUDA configuration,
    // so bail out early on degenerate inputs (nothing to copy anyway).
    if (num_indices == 0 || embedding_dim == 0)
        return;

    const int threads = static_cast<int>(embedding_dim < 1024u ? embedding_dim : 1024u);
    const unsigned blocks = static_cast<unsigned>(num_indices);

    switch (dtype) {
    case LLAISYS_DTYPE_F32:
        embeddingops::nvidia::embedding_kernel<float><<<blocks, threads>>>(
            reinterpret_cast<float *>(output),
            reinterpret_cast<const int64_t *>(indices),
            reinterpret_cast<const float *>(weights),
            num_indices, embedding_dim);
        break;
    case LLAISYS_DTYPE_F16:
        embeddingops::nvidia::embedding_kernel<__half><<<blocks, threads>>>(
            reinterpret_cast<__half *>(output),
            reinterpret_cast<const int64_t *>(indices),
            reinterpret_cast<const __half *>(weights),
            num_indices, embedding_dim);
        break;
    case LLAISYS_DTYPE_BF16:
        embeddingops::nvidia::embedding_kernel<__nv_bfloat16><<<blocks, threads>>>(
            reinterpret_cast<__nv_bfloat16 *>(output),
            reinterpret_cast<const int64_t *>(indices),
            reinterpret_cast<const __nv_bfloat16 *>(weights),
            num_indices, embedding_dim);
        break;
    default:
        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
    }
}

} // namespace llaisys::ops::nvidia
ASSERT(out->isContiguous(), "embedding(): output tensor must be contiguous"); + ASSERT(index->isContiguous(), "embedding(): index tensor must be contiguous"); + ASSERT(weight->isContiguous(), "embedding(): weight tensor must be contiguous"); + + auto device = out->deviceType(); + core::context().setDevice(device, out->deviceId()); + + auto embedding_dim = weight->shape().back(); + if (device == LLAISYS_DEVICE_CPU) { + llaisys::ops::cpu::embedding(out->data(), index->data(), weight->data(), index->numel(), embedding_dim, + out->dtype()); +#ifdef ENABLE_NVIDIA_API + } else if (device == LLAISYS_DEVICE_NVIDIA) { + llaisys::ops::nvidia::embedding(out->data(), index->data(), weight->data(), index->numel(), embedding_dim, + out->dtype()); +#endif + } else + EXCEPTION_UNSUPPORTED_DEVICE; } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..e05de6a7 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,91 @@ +#include "linear_cpu.hpp" +#include "../../../utils.hpp" +#include "llaisys.h" +#include +#include + +template +static void linear_impl(T *output, + const T *input, + const T *weight, + const T *bias, + size_t N, + size_t M, + size_t K) { +#pragma omp parallel for collapse(2) schedule(static) + for (size_t n = 0; n < N; n++) { + for (size_t k = 0; k < K; k++) { + double sum = 0.0; + +#pragma omp simd reduction(+ : sum) + for (size_t m = 0; m < M; m++) + sum += casting(double, input[n * M + m]) + * casting(double, weight[k * M + m]); + + if (bias != nullptr) + sum += casting(double, bias[k]); + output[n * K + k] = casting(T, static_cast(sum)); + } + } +} + +namespace linear::naive { + +template +void linear(T *output, + const T *input, + const T *weight, + const T *bias, + size_t N, + size_t M, + size_t K) { + for (size_t n = 0; n < N; n++) { + for (size_t k = 0; k < K; k++) { + long double sum = 0.0; + for (size_t m = 0; m < M; m++) + sum += casting(long double, 
input[n * M + m]) + * casting(long double, weight[k * M + m]); + if (bias != nullptr) + sum += casting(long double, bias[k]); + output[n * K + k] = casting(T, static_cast(sum)); + } + } +} + +} // namespace linear::naive + +namespace llaisys::ops::cpu { + +void linear(std::byte *output, + const std::byte *input, + const std::byte *weight, + const std::byte *bias, + size_t N, + size_t M, + size_t K, + llaisysDataType_t dtype) { + + using namespace llaisys; + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return linear_impl(recast(float *, output), + recast(const float *, input), + recast(const float *, weight), + recast(const float *, bias), N, M, K); + case LLAISYS_DTYPE_F16: + return linear_impl(recast(fp16_t *, output), + recast(const fp16_t *, input), + recast(const fp16_t *, weight), + recast(const fp16_t *, bias), N, M, K); + case LLAISYS_DTYPE_BF16: + return linear_impl(recast(bf16_t *, output), + recast(const bf16_t *, input), + recast(const bf16_t *, weight), + recast(const bf16_t *, bias), N, M, K); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..c5585c27 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +/** + * MatMul: input=[N, M] * weight=[K, M], bias=None or [N, 1] -> output=[N, K] + */ +void linear(std::byte *output, + const std::byte *input, + const std::byte *weight, + const std::byte *bias, + size_t N, + size_t M, + size_t K, + llaisysDataType_t dtype); + +} \ No newline at end of file diff --git a/src/ops/linear/nvidia/linear_cu.cu b/src/ops/linear/nvidia/linear_cu.cu new file mode 100644 index 00000000..25465421 --- /dev/null +++ b/src/ops/linear/nvidia/linear_cu.cu @@ -0,0 +1,88 @@ +#include "linear_cu.cuh" + +#include +#include +#include + +#include 
namespace linearops::nvidia {

__device__ inline float to_float(float x) { return x; }
__device__ inline float to_float(__half x) { return __half2float(x); }
__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); }

template <typename T>
__device__ inline T from_float(float x) { return static_cast<T>(x); }
template <>
__device__ inline __half from_float<__half>(float x) { return __float2half(x); }
template <>
__device__ inline __nv_bfloat16 from_float<__nv_bfloat16>(float x) { return __float2bfloat16(x); }

// output[N,K] = input[N,M] @ weight[K,M]^T + bias[K]
// Grid: blockIdx.y = batch row, blockIdx.x * blockDim.x + threadIdx.x = column.
template <typename T>
__global__ void
linear_kernel(T *output, const T *input, const T *weight, const T *bias,
              size_t N, size_t M, size_t K) {
    const size_t row = blockIdx.y;
    const size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (row >= N || col >= K) return;

    float acc = 0.0f;
    for (size_t m = 0; m < M; ++m)
        acc += to_float(input[row * M + m]) * to_float(weight[col * M + m]);
    if (bias != nullptr)
        acc += to_float(bias[col]);
    output[row * K + col] = from_float<T>(acc);
}

} // namespace linearops::nvidia

namespace llaisys::ops::nvidia {

// dtype dispatch for the naive dense layer kernel above.
void linear(std::byte *output,
            const std::byte *input,
            const std::byte *weight,
            const std::byte *bias,
            size_t N,
            size_t M,
            size_t K,
            llaisysDataType_t dtype) {
    constexpr int kBlockK = 128;
    const dim3 block(kBlockK);
    const dim3 grid(static_cast<unsigned>((K + kBlockK - 1) / kBlockK),
                    static_cast<unsigned>(N));

    switch (dtype) {
    case LLAISYS_DTYPE_F32:
        linearops::nvidia::linear_kernel<float><<<grid, block>>>(
            reinterpret_cast<float *>(output),
            reinterpret_cast<const float *>(input),
            reinterpret_cast<const float *>(weight),
            reinterpret_cast<const float *>(bias),
            N, M, K);
        break;
    case LLAISYS_DTYPE_F16:
        linearops::nvidia::linear_kernel<__half><<<grid, block>>>(
            reinterpret_cast<__half *>(output),
            reinterpret_cast<const __half *>(input),
            reinterpret_cast<const __half *>(weight),
            reinterpret_cast<const __half *>(bias),
            N, M, K);
        break;
    case LLAISYS_DTYPE_BF16:
        linearops::nvidia::linear_kernel<__nv_bfloat16><<<grid, block>>>(
            reinterpret_cast<__nv_bfloat16 *>(output),
            reinterpret_cast<const __nv_bfloat16 *>(input),
            reinterpret_cast<const __nv_bfloat16 *>(weight),
            reinterpret_cast<const __nv_bfloat16 *>(bias),
            N, M, K);
        break;
    default:
        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
    }
}

} // namespace llaisys::ops::nvidia
== LLAISYS_DEVICE_CPU) { + llaisys::ops::cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, N, M, K, + out->dtype()); +#ifdef ENABLE_NVIDIA_API + } else if (out->deviceType() == LLAISYS_DEVICE_NVIDIA) { + llaisys::ops::nvidia::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, N, M, K, + out->dtype()); +#endif + } else + EXCEPTION_UNSUPPORTED_DEVICE; } + } // namespace llaisys::ops diff --git a/src/ops/ops.hpp b/src/ops/ops.hpp new file mode 100644 index 00000000..54db7ebf --- /dev/null +++ b/src/ops/ops.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "add/op.hpp" +#include "argmax/op.hpp" +#include "embedding/op.hpp" +#include "linear/op.hpp" +// #include "rearrange/op.hpp" +#include "rms_norm/op.hpp" +#include "rope/op.hpp" +#include "sample/op.hpp" +#include "self_attention/op.hpp" +#include "swiglu/op.hpp" \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_cpu.cpp b/src/ops/rms_norm/cpu/rms_cpu.cpp new file mode 100644 index 00000000..02844852 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_cpu.cpp @@ -0,0 +1,56 @@ +#include "rms_cpu.hpp" +#include "../../../utils.hpp" +#include +#include + +template +static void rms_norm_impl( + T *output, const T *input, const T *weight, size_t N, size_t M, float eps) { +#pragma omp parallel for + for (size_t i = 0; i < N; i++) { + long double sum = 0; +#pragma omp simd reduction(+ : sum) + for (size_t j = 0; j < M; j++) { + long double val = casting(long double, input[i * M + j]); + sum += val * val; + } + long double rms = std::sqrt(sum / static_cast(M) + eps); +#pragma omp simd + for (size_t j = 0; j < M; j++) { + long double x = casting(long double, input[i * M + j]); + long double w = casting(long double, weight[j]); + output[i * M + j] = casting(T, static_cast(w * x / rms)); + } + } +} + +namespace llaisys::ops::cpu { + +void rms_norm(std::byte *output, + const std::byte *input, + const std::byte *weight, + size_t N, + size_t M, + float eps, + 
llaisysDataType_t dtype) { + using namespace llaisys; + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return rms_norm_impl(recast(float *, output), + recast(const float *, input), + recast(const float *, weight), N, M, eps); + case LLAISYS_DTYPE_F16: + return rms_norm_impl(recast(fp16_t *, output), + recast(const fp16_t *, input), + recast(const fp16_t *, weight), N, M, eps); + case LLAISYS_DTYPE_BF16: + return rms_norm_impl(recast(bf16_t *, output), + recast(const bf16_t *, input), + recast(const bf16_t *, weight), N, M, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_cpu.hpp b/src/ops/rms_norm/cpu/rms_cpu.hpp new file mode 100644 index 00000000..9d5b2d74 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_cpu.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void rms_norm(std::byte *output, + const std::byte *input, + const std::byte *weight, + size_t N, + size_t M, + float eps, + llaisysDataType_t dtype); + +} \ No newline at end of file diff --git a/src/ops/rms_norm/nvidia/rms_norm_cu.cu b/src/ops/rms_norm/nvidia/rms_norm_cu.cu new file mode 100644 index 00000000..6ce0d098 --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_cu.cu @@ -0,0 +1,106 @@ +#include "rms_norm_cu.cuh" + +#include +#include +#include + +#include "../../../utils/check.hpp" +#include "../../../utils/types.hpp" + +namespace rmsnormops::nvidia { + +__device__ inline float to_float(float x) { return x; } +__device__ inline float to_float(__half x) { return __half2float(x); } +__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } + +template +__device__ inline T from_float(float x) { return static_cast(x); } +template <> +__device__ inline __half from_float<__half>(float x) { return __float2half(x); } +template <> +__device__ inline __nv_bfloat16 from_float<__nv_bfloat16>(float x) { return 
// One block per row; shared memory holds one partial sum per thread.
// out[i,j] = in[i,j] * weight[j] * rsqrt(mean(in[i,:]^2) + eps)
template <typename T>
__global__ void
rms_norm_kernel(T *output, const T *input, const T *weight,
                size_t N, size_t M, float eps) {
    extern __shared__ float partial[]; // blockDim.x floats

    const size_t row = blockIdx.x;
    if (row >= N) return;

    const unsigned tid = threadIdx.x;
    const unsigned nthreads = blockDim.x;

    // Phase 1: per-thread partial sum of squares over a strided slice.
    float acc = 0.0f;
    for (size_t j = tid; j < M; j += nthreads) {
        const float v = to_float(input[row * M + j]);
        acc += v * v;
    }
    partial[tid] = acc;
    __syncthreads();

    // Tree reduction; blockDim.x must be a power of two.
    for (unsigned stride = nthreads / 2; stride > 0; stride >>= 1) {
        if (tid < stride) partial[tid] += partial[tid + stride];
        __syncthreads();
    }

    const float scale = rsqrtf(partial[0] / static_cast<float>(M) + eps);

    // Phase 2: normalize and apply the per-column weight.
    for (size_t j = tid; j < M; j += nthreads) {
        const float x = to_float(input[row * M + j]);
        const float w = to_float(weight[j]);
        output[row * M + j] = from_float<T>(x * w * scale);
    }
}

} // namespace rmsnormops::nvidia

namespace llaisys::ops::nvidia {

// dtype dispatch: one block per row, shared memory sized to the block.
void rms_norm(std::byte *output,
              const std::byte *input,
              const std::byte *weight,
              size_t N,
              size_t M,
              float eps,
              llaisysDataType_t dtype) {
    constexpr int kThreads = 256; // must be a power of two for the reduction
    const size_t smem = kThreads * sizeof(float);

    switch (dtype) {
    case LLAISYS_DTYPE_F32:
        rmsnormops::nvidia::rms_norm_kernel<float>
            <<<static_cast<unsigned>(N), kThreads, smem>>>(
                reinterpret_cast<float *>(output),
                reinterpret_cast<const float *>(input),
                reinterpret_cast<const float *>(weight),
                N, M, eps);
        break;
    case LLAISYS_DTYPE_F16:
        rmsnormops::nvidia::rms_norm_kernel<__half>
            <<<static_cast<unsigned>(N), kThreads, smem>>>(
                reinterpret_cast<__half *>(output),
                reinterpret_cast<const __half *>(input),
                reinterpret_cast<const __half *>(weight),
                N, M, eps);
        break;
    case LLAISYS_DTYPE_BF16:
        rmsnormops::nvidia::rms_norm_kernel<__nv_bfloat16>
            <<<static_cast<unsigned>(N), kThreads, smem>>>(
                reinterpret_cast<__nv_bfloat16 *>(output),
                reinterpret_cast<const __nv_bfloat16 *>(input),
                reinterpret_cast<const __nv_bfloat16 *>(weight),
                N, M, eps);
        break;
    default:
        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
    }
}

} // namespace llaisys::ops::nvidia
b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,64 @@ +#include "rope_cpu.hpp" +#include "../../../utils.hpp" +#include "llaisys.h" +#include +#include +#include + +template +void rope_impl( + T *output, const T *input, const int64_t *pos_ids, size_t seqlen, size_t num_head, size_t head_dim, float theta) { + + size_t dim_half = head_dim / 2; + +#pragma omp parallel for collapse(2) + for (size_t seq_idx = 0; seq_idx < seqlen; seq_idx++) { + float pos_id = static_cast(pos_ids[seq_idx]); + + for (size_t head_idx = 0; head_idx < num_head; head_idx++) { + for (size_t i = 0; i < dim_half; i++) { + float angle = pos_id / std::pow(theta, (2.0f * i) / head_dim); + float cos_angle = std::cos(angle); + float sin_angle = std::sin(angle); + + size_t base_idx = seq_idx * num_head * head_dim + head_idx * head_dim; + + float x1 = casting(float, input[base_idx + i]); + float x2 = casting(float, input[base_idx + i + dim_half]); + + output[base_idx + i] = casting(T, x1 * cos_angle - x2 * sin_angle); + output[base_idx + i + dim_half] = casting(T, x1 * sin_angle + x2 * cos_angle); + } + } + } +} + +namespace llaisys::ops::cpu { + +void rope(std::byte *output, + const std::byte *input, + const std::byte *pos_ids, + size_t seqlen, + size_t num_head, + size_t head_dim, + float theta, + llaisysDataType_t dtype) { + + using namespace llaisys; + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return rope_impl(recast(float *, output), recast(const float *, input), recast(const int64_t *, pos_ids), + seqlen, num_head, head_dim, theta); + case LLAISYS_DTYPE_F16: + return rope_impl(recast(fp16_t *, output), recast(const fp16_t *, input), recast(const int64_t *, pos_ids), + seqlen, num_head, head_dim, theta); + case LLAISYS_DTYPE_BF16: + return rope_impl(recast(bf16_t *, output), recast(const bf16_t *, input), recast(const int64_t *, pos_ids), + seqlen, num_head, head_dim, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file 
namespace ropeops::nvidia {

__device__ inline float to_float(float x) { return x; }
__device__ inline float to_float(__half x) { return __half2float(x); }
__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); }

template <typename T>
__device__ inline T from_float(float x) { return static_cast<T>(x); }
template <>
__device__ inline __half from_float<__half>(float x) { return __float2half(x); }
template <>
__device__ inline __nv_bfloat16 from_float<__nv_bfloat16>(float x) { return __float2bfloat16(x); }

// Half-split RoPE. Grid layout:
//   blockIdx.y                              = flattened (seq * num_head + head)
//   blockIdx.x * blockDim.x + threadIdx.x   = i in [0, head_dim / 2)
template <typename T>
__global__ void
rope_kernel(T *output, const T *input, const int64_t *pos_ids,
            size_t seqlen, size_t num_head, size_t head_dim, float theta) {
    const size_t flat = blockIdx.y;
    const size_t seq = flat / num_head;
    const size_t head = flat % num_head;
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t dim_half = head_dim / 2;

    if (seq >= seqlen || i >= dim_half) return;

    const float pos = static_cast<float>(pos_ids[seq]);
    const float angle
        = pos / powf(theta, (2.0f * static_cast<float>(i)) / static_cast<float>(head_dim));
    const float cos_a = cosf(angle);
    const float sin_a = sinf(angle);

    const size_t base = seq * num_head * head_dim + head * head_dim;
    const float x1 = to_float(input[base + i]);
    const float x2 = to_float(input[base + i + dim_half]);

    output[base + i] = from_float<T>(x1 * cos_a - x2 * sin_a);
    output[base + i + dim_half] = from_float<T>(x1 * sin_a + x2 * cos_a);
}

} // namespace ropeops::nvidia

namespace llaisys::ops::nvidia {

// dtype dispatch; one thread per rotated (x1, x2) pair.
void rope(std::byte *output,
          const std::byte *input,
          const std::byte *pos_ids,
          size_t seqlen,
          size_t num_head,
          size_t head_dim,
          float theta,
          llaisysDataType_t dtype) {
    const size_t dim_half = head_dim / 2;
    constexpr int kBlockX = 64;
    const dim3 block(kBlockX);
    const dim3 grid(static_cast<unsigned>((dim_half + kBlockX - 1) / kBlockX),
                    static_cast<unsigned>(seqlen * num_head));

    switch (dtype) {
    case LLAISYS_DTYPE_F32:
        ropeops::nvidia::rope_kernel<float><<<grid, block>>>(
            reinterpret_cast<float *>(output),
            reinterpret_cast<const float *>(input),
            reinterpret_cast<const int64_t *>(pos_ids),
            seqlen, num_head, head_dim, theta);
        break;
    case LLAISYS_DTYPE_F16:
        ropeops::nvidia::rope_kernel<__half><<<grid, block>>>(
            reinterpret_cast<__half *>(output),
            reinterpret_cast<const __half *>(input),
            reinterpret_cast<const int64_t *>(pos_ids),
            seqlen, num_head, head_dim, theta);
        break;
    case LLAISYS_DTYPE_BF16:
        ropeops::nvidia::rope_kernel<__nv_bfloat16><<<grid, block>>>(
            reinterpret_cast<__nv_bfloat16 *>(output),
            reinterpret_cast<const __nv_bfloat16 *>(input),
            reinterpret_cast<const int64_t *>(pos_ids),
            seqlen, num_head, head_dim, theta);
        break;
    default:
        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
    }
}

} // namespace llaisys::ops::nvidia
+ llaisysDataType_t dtype); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..a0df2e3e 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,32 @@ #include "op.hpp" +#include "../../utils.hpp" +#include "cpu/rope_cpu.hpp" +#include "llaisys.h" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rope_cu.cuh" +#endif namespace llaisys::ops { + void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, pos_ids); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE(): pos_ids must be int64."); + + size_t seqlen = in->shape()[0]; + size_t num_head = in->shape()[1]; + size_t head_dim = in->shape()[2]; + + ASSERT(pos_ids->numel() == seqlen, "RoPE(): pos_ids length must match input sequence length."); + + if (in->deviceType() == LLAISYS_DEVICE_CPU) { + ops::cpu::rope(out->data(), in->data(), pos_ids->data(), seqlen, num_head, head_dim, theta, in->dtype()); +#ifdef ENABLE_NVIDIA_API + } else if (in->deviceType() == LLAISYS_DEVICE_NVIDIA) { + ops::nvidia::rope(out->data(), in->data(), pos_ids->data(), seqlen, num_head, head_dim, theta, in->dtype()); +#endif + } else + EXCEPTION_UNSUPPORTED_DEVICE; } + } // namespace llaisys::ops diff --git a/src/ops/sample/cpu/sample_cpu.cpp b/src/ops/sample/cpu/sample_cpu.cpp new file mode 100644 index 00000000..aac38862 --- /dev/null +++ b/src/ops/sample/cpu/sample_cpu.cpp @@ -0,0 +1,113 @@ +#include "sample_cpu.hpp" +#include "../../../utils.hpp" + +#include +#include +#include +#include +#include +#include + +// Thread-local RNG so each thread has independent, seeding-capable state. 
+static thread_local std::mt19937_64 rng{std::random_device{}()}; + +namespace llaisys::ops::cpu { + +void sample_set_seed(uint64_t seed) { + rng.seed(seed); +} + +template +static int64_t sample_impl(const T *logits, size_t numel, + int top_k, float top_p, float temperature) { + // ── 1. Temperature scaling (in float) ────────────────────────────────── + std::vector scores(numel); + for (size_t i = 0; i < numel; ++i) + scores[i] = casting(float, logits[i]) / temperature; + + // ── 2. Sort indices by score descending ──────────────────────────────── + std::vector indices(numel); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&](int a, int b) { + if (scores[a] == scores[b]) { + // Deterministic tie-break: match argmax semantics (first index wins). + return a < b; + } + return scores[a] > scores[b]; + }); + + // Near-zero temperature approximates greedy decoding; return argmax deterministically. + if (temperature <= 1e-5f) { + return static_cast(indices[0]); + } + + // ── 3. Top-k truncation ──────────────────────────────────────────────── + int k = static_cast(numel); + if (top_k > 0 && top_k < k) + k = top_k; + + // ── 4. Softmax over the top-k candidates (numerically stable) ───────── + float max_score = scores[indices[0]]; + std::vector probs(k); + for (int i = 0; i < k; ++i) + probs[i] = std::exp(scores[indices[i]] - max_score); + + float sum = 0.0f; + for (int i = 0; i < k; ++i) + sum += probs[i]; + for (int i = 0; i < k; ++i) + probs[i] /= sum; + + // ── 5. Top-p (nucleus) truncation ───────────────────────────────────── + // Keep the minimal prefix whose cumulative probability >= top_p, then + // renormalise. + if (top_p < 1.0f) { + float cumsum = 0.0f; + int cutoff = k; + for (int i = 0; i < k; ++i) { + cumsum += probs[i]; + if (cumsum >= top_p) { + cutoff = i + 1; + break; + } + } + k = cutoff; + probs.resize(k); + // Renormalise after truncation. 
+ sum = 0.0f; + for (int i = 0; i < k; ++i) + sum += probs[i]; + for (int i = 0; i < k; ++i) + probs[i] /= sum; + } + + // ── 6. Sample from the distribution ─────────────────────────────────── + std::discrete_distribution dist(probs.begin(), probs.end()); + return static_cast(indices[dist(rng)]); +} + +void sample(std::byte *out, const std::byte *logits, size_t numel, + int top_k, float top_p, float temperature, + llaisysDataType_t dtype) { + int64_t result; + switch (dtype) { + case LLAISYS_DTYPE_F32: + result = sample_impl(reinterpret_cast(logits), + numel, top_k, top_p, temperature); + break; + case LLAISYS_DTYPE_F16: + result = sample_impl(reinterpret_cast(logits), + numel, top_k, top_p, temperature); + break; + case LLAISYS_DTYPE_BF16: + result = sample_impl(reinterpret_cast(logits), + numel, top_k, top_p, temperature); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } + *reinterpret_cast(out) = result; +} + +} // namespace llaisys::ops::cpu diff --git a/src/ops/sample/cpu/sample_cpu.hpp b/src/ops/sample/cpu/sample_cpu.hpp new file mode 100644 index 00000000..6fdb8c5d --- /dev/null +++ b/src/ops/sample/cpu/sample_cpu.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "llaisys.h" +#include +#include + +namespace llaisys::ops::cpu { + +void sample(std::byte *out, const std::byte *logits, size_t numel, + int top_k, float top_p, float temperature, + llaisysDataType_t dtype); + +void sample_set_seed(uint64_t seed); + +} // namespace llaisys::ops::cpu diff --git a/src/ops/sample/op.cpp b/src/ops/sample/op.cpp new file mode 100644 index 00000000..d551d920 --- /dev/null +++ b/src/ops/sample/op.cpp @@ -0,0 +1,37 @@ +#include "op.hpp" +#include "../../utils.hpp" +#include "cpu/sample_cpu.hpp" +#include "llaisys.h" +#ifdef ENABLE_NVIDIA_API +// #include "nvidia/sample_cu.cuh" // TODO: add CUDA implementation +#endif + +namespace llaisys::ops { + +void sample(tensor_t out, tensor_t logits, int top_k, float top_p, float temperature) { + ASSERT(out->numel() == 
1, "sample(): out must have exactly one element"); + ASSERT(out->dtype() == LLAISYS_DTYPE_I64, "sample(): out must be int64"); + ASSERT(logits->isContiguous(), "sample(): logits must be contiguous"); + ASSERT(temperature > 0.0f, "sample(): temperature must be > 0"); + ASSERT(top_p > 0.0f && top_p <= 1.0f, "sample(): top_p must be in (0, 1]"); + + switch (logits->deviceType()) { + case LLAISYS_DEVICE_CPU: + cpu::sample(out->data(), logits->data(), logits->numel(), + top_k, top_p, temperature, logits->dtype()); + break; +#ifdef ENABLE_NVIDIA_API + // case LLAISYS_DEVICE_NVIDIA: + // nvidia::sample(...); + // break; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } +} + +void sample_set_seed(uint64_t seed) { + cpu::sample_set_seed(seed); +} + +} // namespace llaisys::ops diff --git a/src/ops/sample/op.hpp b/src/ops/sample/op.hpp new file mode 100644 index 00000000..77b9793d --- /dev/null +++ b/src/ops/sample/op.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "../../tensor/tensor.hpp" + +namespace llaisys::ops { + +/** + * @brief Sample a token from logits with optional temperature, top-k, and top-p filtering. + * + * @param out Output tensor, shape {1}, dtype int64. Receives the sampled token index. + * @param logits Input logits tensor, shape {vocab_size}, any float dtype. + * @param top_k Keep only the top-k highest-logit tokens before sampling (0 = disabled). + * @param top_p Nucleus sampling: keep the smallest set of tokens whose cumulative probability + * exceeds top_p (1.0 = disabled). + * @param temperature Divide logits by this value before softmax (1.0 = no change). + */ +void sample(tensor_t out, tensor_t logits, int top_k, float top_p, float temperature); + +/** + * @brief Set the random seed for the sampling RNG (per-thread). + * Call before sampling for reproducible results. 
+ */ +void sample_set_seed(uint64_t seed); + +} // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/selfattn_cpu.cpp b/src/ops/self_attention/cpu/selfattn_cpu.cpp new file mode 100644 index 00000000..076ac9d5 --- /dev/null +++ b/src/ops/self_attention/cpu/selfattn_cpu.cpp @@ -0,0 +1,191 @@ +#include "selfattn_cpu.hpp" +#include "../../../utils.hpp" +#include "llaisys.h" +#include +#include +#include +#include +#include +#include + +template +static void self_attn_impl(T *attn_val, + const T *q, + const T *k, + const T *v, + size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale) { + +#pragma omp parallel + { + std::vector scores(kvlen); + +#pragma omp for collapse(2) schedule(static) + for (size_t h = 0; h < num_head; ++h) { + size_t num_groups = num_head / num_kv_head; + size_t head_k = h / num_groups; + + for (size_t s = 0; s < seqlen; ++s) { + float *sc = scores.data(); + + size_t bound = kvlen - seqlen + s; + size_t L = std::min(kvlen, bound + 1); + + // QK^T + max + float mx = -INFINITY; + size_t qbase = s * num_head * head_dim + h * head_dim; + + for (size_t t = 0; t < L; ++t) { + size_t kbase + = t * num_kv_head * head_dim + head_k * head_dim; + long double dot = 0.0L; +#pragma omp simd reduction(+ : dot) + for (size_t d = 0; d < head_dim; ++d) + dot += casting(long double, q[qbase + d]) + * casting(long double, k[kbase + d]); + float val = static_cast(dot * scale); + sc[t] = val; + mx = (val > mx) ? 
val : mx; + } + for (size_t t = L; t < kvlen; ++t) sc[t] = 0.f; + + // softmax + long double sum = 0.0L; + for (size_t t = 0; t < L; ++t) { + long double e = std::exp(static_cast(sc[t]) + - static_cast(mx)); + sc[t] = static_cast(e); + sum += e; + } + float inv = 1.0f / static_cast(sum); +#pragma omp simd + for (size_t t = 0; t < L; ++t) sc[t] *= inv; + + // PV (cache-friendly order) + size_t obase = s * num_head * vdim + h * vdim; + std::vector out(vdim, 0.0L); // 可换成分块栈数组 + + for (size_t i = 0; i < L; ++i) { + long double p = static_cast(sc[i]); + const T *vptr = v + i * num_kv_head * vdim + head_k * vdim; +#pragma omp simd + for (size_t t = 0; t < vdim; ++t) + out[t] += p * casting(long double, vptr[t]); + } + for (size_t t = 0; t < vdim; ++t) + attn_val[obase + t] + = casting(T, static_cast(out[t])); + } + } + } +} + +template +static void self_attn_impl_without_openmp(T *attn_val, + const T *q, + const T *k, + const T *v, + size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale) { + + using usize = std::size_t; + + usize ngroup = num_head / num_kv_head; + for (usize s = 0; s < seqlen; s++) { + for (usize h = 0; h < num_head; h++) { + usize head_k = h / ngroup; + + std::vector scores(kvlen, 0); + for (usize kl = 0; kl < kvlen; kl++) { + float sum = 0; + for (usize d = 0; d < head_dim; d++) { + usize qbase = s * num_head * head_dim + h * head_dim + d; + usize kbase + = kl * num_kv_head * head_dim + head_k * head_dim + d; + sum += casting(float, q[qbase]) * casting(float, k[kbase]); + } + scores[kl] = sum * scale; + } + + usize L = std::min(kvlen, kvlen - seqlen + s + 1); + // softmax + float max_score = -std::numeric_limits::infinity(); + for (usize t = 0; t < L; t++) + if (scores[t] > max_score) + max_score = scores[t]; + + float sum_exp = 0.0f; + for (usize t = 0; t < L; t++) { + scores[t] = std::exp(scores[t] - max_score); + sum_exp += scores[t]; + } + for (usize t = 0; t < L; t++) scores[t] /= 
sum_exp; + for (usize t = L; t < kvlen; t++) scores[t] = 0.0f; + + // PV + for (usize t = 0; t < vdim; t++) { + float acc = 0.0f; + for (usize i = 0; i < L; i++) { + usize vbase = i * num_kv_head * vdim + head_k * vdim + t; + acc += scores[i] * casting(float, v[vbase]); + } + usize obase = s * num_head * vdim + h * vdim + t; + attn_val[obase] = casting(T, acc); + } + } + } +} + +namespace llaisys::ops::cpu { + +void self_attn(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale, + llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return self_attn_impl( + recast(float *, attn_val), recast(const float *, q), + recast(const float *, k), recast(const float *, v), seqlen, + num_head, head_dim, kvlen, num_kv_head, vdim, scale); + + case LLAISYS_DTYPE_F16: + return self_attn_impl(recast(llaisys::fp16_t *, attn_val), + recast(const llaisys::fp16_t *, q), + recast(const llaisys::fp16_t *, k), + recast(const llaisys::fp16_t *, v), seqlen, + num_head, head_dim, kvlen, num_kv_head, vdim, + scale); + + case LLAISYS_DTYPE_BF16: + return self_attn_impl(recast(llaisys::bf16_t *, attn_val), + recast(const llaisys::bf16_t *, q), + recast(const llaisys::bf16_t *, k), + recast(const llaisys::bf16_t *, v), seqlen, + num_head, head_dim, kvlen, num_kv_head, vdim, + scale); + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/self_attention/cpu/selfattn_cpu.hpp b/src/ops/self_attention/cpu/selfattn_cpu.hpp new file mode 100644 index 00000000..405a8db2 --- /dev/null +++ b/src/ops/self_attention/cpu/selfattn_cpu.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void self_attn(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + 
size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale, + llaisysDataType_t dtype); + +} \ No newline at end of file diff --git a/src/ops/self_attention/nvidia/selfattn_cu.cu b/src/ops/self_attention/nvidia/selfattn_cu.cu new file mode 100644 index 00000000..afbaee90 --- /dev/null +++ b/src/ops/self_attention/nvidia/selfattn_cu.cu @@ -0,0 +1,165 @@ +#include "selfattn_cu.cuh" + +#include +#include +#include +#include + +#include "../../../utils/check.hpp" +#include "../../../utils/types.hpp" + +namespace selfattnops::nvidia { + +__device__ inline float to_float(float x) { return x; } +__device__ inline float to_float(__half x) { return __half2float(x); } +__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } + +template +__device__ inline T from_float(float x) { return static_cast(x); } +template <> +__device__ inline __half from_float<__half>(float x) { return __float2half(x); } +template <> +__device__ inline __nv_bfloat16 from_float<__nv_bfloat16>(float x) { return __float2bfloat16(x); } + +// One block per (seq, head) pair. +// Dynamic shared memory layout: +// float scores[kvlen] -- QK dot products +// float reduce[blockDim.x] -- block-reduction scratch +// +// NOTE: requires (kvlen + blockDim.x) * sizeof(float) <= shared memory limit +// (~49 KB on Ampere), i.e. kvlen up to ~12 000 tokens at blockDim.x=128. 
+template +__global__ void +self_attn_kernel(T *attn_val, + const T *q, const T *k, const T *v, + size_t seqlen, size_t num_head, size_t head_dim, + size_t kvlen, size_t num_kv_head, size_t vdim, + float scale) { + extern __shared__ float smem[]; + float *scores = smem; + float *reduce_buf = smem + kvlen; // [blockDim.x] for reductions + + size_t block_id = blockIdx.x; + size_t s = block_id / num_head; + size_t h = block_id % num_head; + unsigned tid = threadIdx.x; + unsigned nthreads = blockDim.x; + + size_t num_groups = num_head / num_kv_head; + size_t head_k = h / num_groups; + + // causal: attend to positions [0, L) + size_t L = kvlen - seqlen + s + 1; + if (L > kvlen) L = kvlen; + + size_t qbase = s * num_head * head_dim + h * head_dim; + + // ---- Phase 1: compute QK^T scores ---- + for (size_t t = tid; t < L; t += nthreads) { + size_t kbase = t * num_kv_head * head_dim + head_k * head_dim; + float dot = 0.0f; + for (size_t d = 0; d < head_dim; ++d) + dot += to_float(q[qbase + d]) * to_float(k[kbase + d]); + scores[t] = dot * scale; + } + __syncthreads(); + + // ---- Phase 2: find global max (numerically stable softmax) ---- + float local_max = -FLT_MAX; + for (size_t t = tid; t < L; t += nthreads) + local_max = fmaxf(local_max, scores[t]); + reduce_buf[tid] = local_max; + __syncthreads(); + for (unsigned s2 = nthreads / 2; s2 > 0; s2 >>= 1) { + if (tid < s2) reduce_buf[tid] = fmaxf(reduce_buf[tid], reduce_buf[tid + s2]); + __syncthreads(); + } + float gmax = reduce_buf[0]; + __syncthreads(); + + // ---- Phase 3: exp(score - max), compute sum ---- + float local_sum = 0.0f; + for (size_t t = tid; t < L; t += nthreads) { + scores[t] = expf(scores[t] - gmax); + local_sum += scores[t]; + } + reduce_buf[tid] = local_sum; + __syncthreads(); + for (unsigned s2 = nthreads / 2; s2 > 0; s2 >>= 1) { + if (tid < s2) reduce_buf[tid] += reduce_buf[tid + s2]; + __syncthreads(); + } + float inv_sum = 1.0f / reduce_buf[0]; + __syncthreads(); + + // ---- Phase 4: normalize 
scores in-place ---- + for (size_t t = tid; t < L; t += nthreads) + scores[t] *= inv_sum; + __syncthreads(); + + // ---- Phase 5: weighted sum of V ---- + size_t obase = s * num_head * vdim + h * vdim; + for (size_t d = tid; d < vdim; d += nthreads) { + float acc = 0.0f; + for (size_t t = 0; t < L; ++t) { + size_t vbase = t * num_kv_head * vdim + head_k * vdim; + acc += scores[t] * to_float(v[vbase + d]); + } + attn_val[obase + d] = from_float(acc); + } +} + +} // namespace selfattnops::nvidia + +namespace llaisys::ops::nvidia { + +void self_attn(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale, + llaisysDataType_t dtype) { + constexpr int threads = 128; // must be power of 2 + size_t smem_size = (kvlen + threads) * sizeof(float); + unsigned int blocks = static_cast(seqlen * num_head); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + selfattnops::nvidia::self_attn_kernel + <<>>( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, num_head, head_dim, kvlen, num_kv_head, vdim, scale); + break; + case LLAISYS_DTYPE_F16: + selfattnops::nvidia::self_attn_kernel<__half> + <<>>( + reinterpret_cast<__half *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, num_head, head_dim, kvlen, num_kv_head, vdim, scale); + break; + case LLAISYS_DTYPE_BF16: + selfattnops::nvidia::self_attn_kernel<__nv_bfloat16> + <<>>( + reinterpret_cast<__nv_bfloat16 *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, num_head, head_dim, kvlen, num_kv_head, vdim, scale); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/nvidia/selfattn_cu.cuh b/src/ops/self_attention/nvidia/selfattn_cu.cuh new file mode 100644 index 
00000000..802d45fa --- /dev/null +++ b/src/ops/self_attention/nvidia/selfattn_cu.cuh @@ -0,0 +1,21 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +void self_attn(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + size_t seqlen, + size_t num_head, + size_t head_dim, + size_t kvlen, + size_t num_kv_head, + size_t vdim, + float scale, + llaisysDataType_t dtype); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..1a05fd02 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,51 @@ #include "op.hpp" +#include "cpu/selfattn_cpu.hpp" +#include "llaisys.h" +#include +#ifdef ENABLE_NVIDIA_API +#include "nvidia/selfattn_cu.cuh" +#endif namespace llaisys::ops { -void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); +void self_attention( + tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { + CHECK_SAME_DEVICE(attn_val, q, k, v); + + auto seqlen = q->shape()[0]; + auto num_head = q->shape()[1]; + auto head_dim = q->shape()[2]; + auto kvlen = k->shape()[0]; + auto num_kv_head = k->shape()[1]; + auto vdim = v->shape()[2]; + + // Check dimensions + ASSERT(attn_val->shape()[0] == seqlen && attn_val->shape()[1] == num_head + && attn_val->shape()[2] == vdim, + "[self-attn] attn_val shape mismatch"); + ASSERT(k->shape()[0] == kvlen && k->shape()[1] == num_kv_head + && k->shape()[2] == head_dim, + "[self-attn] k shape mismatch"); + ASSERT(v->shape()[0] == kvlen && v->shape()[1] == num_kv_head + && v->shape()[2] == vdim, + "[self-attn] v shape mismatch"); + ASSERT(q->isContiguous() && k->isContiguous() && v->isContiguous() + && attn_val->isContiguous(), + "[self-attn] all tensors must be contiguous"); + + llaisys::core::context().setDevice(attn_val->deviceType(), + attn_val->deviceId()); + + if (attn_val->deviceType() 
== LLAISYS_DEVICE_CPU) { + cpu::self_attn(attn_val->data(), q->data(), k->data(), v->data(), + seqlen, num_head, head_dim, kvlen, num_kv_head, vdim, + scale, attn_val->dtype()); +#ifdef ENABLE_NVIDIA_API + } else if (attn_val->deviceType() == LLAISYS_DEVICE_NVIDIA) { + nvidia::self_attn(attn_val->data(), q->data(), k->data(), v->data(), + seqlen, num_head, head_dim, kvlen, num_kv_head, vdim, + scale, attn_val->dtype()); +#endif + } else + EXCEPTION_UNSUPPORTED_DEVICE; } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..39dafb00 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,36 @@ +#include "swiglu_cpu.hpp" +#include "../../../utils.hpp" +#include "llaisys.h" +#include +#include + +template static void swiglu_impl(T *out, const T *gate, const T *up, size_t numel) { + +#pragma omp parallel for + for (size_t i = 0; i < numel; ++i) { + float gate_val = casting(float, gate[i]); + float up_val = casting(float, up[i]); + // Swiglu activation: out = up * sigmoid(gate) + float sigmoid_gate = gate_val / (1.0f + std::exp(-gate_val)); + out[i] = casting(T, up_val * sigmoid_gate); + } +} + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, size_t numel, llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return swiglu_impl(recast(float *, out), recast(const float *, gate), recast(const float *, up), numel); + case LLAISYS_DTYPE_F16: + return swiglu_impl(recast(llaisys::fp16_t *, out), recast(const llaisys::fp16_t *, gate), + recast(const llaisys::fp16_t *, up), numel); + case LLAISYS_DTYPE_BF16: + return swiglu_impl(recast(llaisys::bf16_t *, out), recast(const llaisys::bf16_t *, gate), + recast(const llaisys::bf16_t *, up), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git 
a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..fe456bd1 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, size_t numel, llaisysDataType_t dtype); + +} \ No newline at end of file diff --git a/src/ops/swiglu/nvidia/swiglu_cu.cu b/src/ops/swiglu/nvidia/swiglu_cu.cu new file mode 100644 index 00000000..c5bc7276 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_cu.cu @@ -0,0 +1,75 @@ +#include "swiglu_cu.cuh" + +#include +#include +#include +#include + +#include "../../../utils/check.hpp" +#include "../../../utils/types.hpp" + +namespace swigluops::nvidia { + +__device__ inline float to_float(float x) { return x; } +__device__ inline float to_float(__half x) { return __half2float(x); } +__device__ inline float to_float(__nv_bfloat16 x) { return __bfloat162float(x); } + +template +__device__ inline T from_float(float x) { return static_cast(x); } +template <> +__device__ inline __half from_float<__half>(float x) { return __float2half(x); } +template <> +__device__ inline __nv_bfloat16 from_float<__nv_bfloat16>(float x) { return __float2bfloat16(x); } + +// out[i] = up[i] * SiLU(gate[i]) where SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)) +template +__global__ void +swiglu_kernel(T *out, const T *gate, const T *up, size_t numel) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= numel) return; + float g = to_float(gate[idx]); + float u = to_float(up[idx]); + float silu = g / (1.0f + expf(-g)); + out[idx] = from_float(u * silu); +} + +} // namespace swigluops::nvidia + +namespace llaisys::ops::nvidia { + +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + size_t numel, + llaisysDataType_t dtype) { + constexpr int threads = 256; + unsigned int blocks = static_cast((numel + threads - 
1) / threads); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + swigluops::nvidia::swiglu_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + case LLAISYS_DTYPE_F16: + swigluops::nvidia::swiglu_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + case LLAISYS_DTYPE_BF16: + swigluops::nvidia::swiglu_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/nvidia/swiglu_cu.cuh b/src/ops/swiglu/nvidia/swiglu_cu.cuh new file mode 100644 index 00000000..2534e5a5 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_cu.cuh @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + size_t numel, + llaisysDataType_t dtype); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..4e42a401 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,28 @@ #include "op.hpp" +#include "cpu/swiglu_cpu.hpp" +#include "llaisys.h" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/swiglu_cu.cuh" +#endif namespace llaisys::ops { + void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + ASSERT(out->shape() == gate->shape() && out->shape() == up->shape(), + "swiglu(): all tensors must have the same shape"); + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + cpu::swiglu(out->data(), gate->data(), up->data(), out->numel(), out->dtype()); +#ifdef ENABLE_NVIDIA_API + } else if 
(out->deviceType() == LLAISYS_DEVICE_NVIDIA) { + nvidia::swiglu(out->data(), gate->data(), up->data(), out->numel(), out->dtype()); +#endif + } else + EXCEPTION_UNSUPPORTED_DEVICE; } + } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..2d2c7ec4 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace llaisys { @@ -164,27 +165,69 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t stride = 1; + for (ptrdiff_t i = ndim() - 1; i >= 0; i--) { + if (_meta.shape[i] == 1) { + continue; + } + if (_meta.strides[i] != static_cast(stride)) { + return false; + } + stride *= _meta.shape[i]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + std::vector new_shape(ndim()); + std::vector new_strides(ndim(), 1); + for (size_t i = 0; i < ndim(); i++) { + new_shape[i] = _meta.shape[order[i]]; + new_strides[i] = _meta.strides[order[i]]; + } + TensorMeta new_meta{_meta.dtype, new_shape, new_strides}; + return std::shared_ptr(new Tensor(std::move(new_meta), _storage, _offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // 1. check if #num if ok + size_t numelems = 1; + for (auto s : shape) + numelems *= s; + if (numelems != this->numel()) + throw std::runtime_error("Reshape with different number of elements"); + + // 2. check if contiguous + if (!this->isContiguous()) + throw std::runtime_error("Reshape on non-contiguous tensor is not supported"); + + // 3. 
compute strides + std::vector strides(shape.size(), 1); + size_t stride = 1; + for (int i = int(shape.size()) - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } + + TensorMeta new_meta{this->dtype(), shape, strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t new_offset = _offset + _meta.strides[dim] * start * this->elementSize(); + std::vector new_shape = _meta.shape; + new_shape[dim] = end - start; + TensorMeta new_meta{_meta.dtype, new_shape, _meta.strides}; + return std::shared_ptr(new Tensor(std::move(new_meta), _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + core::context().setDevice(this->deviceType(), this->deviceId()); + core::context().runtime().api()->memcpy_sync( + this->data(), + src_, + this->numel() * this->elementSize(), + LLAISYS_MEMCPY_H2D); } tensor_t Tensor::contiguous() const { diff --git a/src/utils.hpp b/src/utils.hpp index f038edfb..0adf8229 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -1,3 +1,6 @@ #pragma once #include "utils/check.hpp" #include "utils/types.hpp" + +#define casting(T, v) llaisys::utils::cast(v) +#define recast(T, v) reinterpret_cast(v) \ No newline at end of file diff --git a/test/ops/sample.py b/test/ops/sample.py new file mode 100644 index 00000000..610d8991 --- /dev/null +++ b/test/ops/sample.py @@ -0,0 +1,227 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import llaisys +import torch +from test_utils import random_tensor, zero_tensor, check_equal, benchmark + + +# ─── helpers ──────────────────────────────────────────────────────────────── + +def _read_i64(tensor: llaisys.Tensor) -> int: + """Copy a single-element int64 tensor to host and return its value.""" + tmp = torch.zeros((1,), 
dtype=torch.int64) + api = llaisys.RuntimeAPI(tensor.device_type()) + api.memcpy_sync( + tmp.data_ptr(), tensor.data_ptr(), + tmp.numel() * tmp.element_size(), + llaisys.MemcpyKind.D2D, + ) + return int(tmp.item()) + + +def _make_logits(values: list[float], dtype_name: str, device_name: str): + """Build a 1-D logits tensor from explicit values.""" + t = torch.tensor(values, dtype=_torch_float_dtype(dtype_name)) + ls = llaisys.Tensor((len(values),), dtype=llaisys_dtype(dtype_name), + device=llaisys_device(device_name)) + api = llaisys.RuntimeAPI(llaisys_device(device_name)) + api.memcpy_sync(ls.data_ptr(), t.data_ptr(), + t.numel() * t.element_size(), llaisys.MemcpyKind.D2D) + return t, ls + + +# ─── dtype helpers ────────────────────────────────────────────────────────── + +def _torch_float_dtype(dtype_name: str): + return {"f32": torch.float32, "f16": torch.float16, "bf16": torch.bfloat16}[dtype_name] + + +def llaisys_dtype(dtype_name: str): + return {"f32": llaisys.DataType.F32, "f16": llaisys.DataType.F16, + "bf16": llaisys.DataType.BF16}[dtype_name] + + +def llaisys_device(device_name: str): + return {"cpu": llaisys.DeviceType.CPU, "nvidia": llaisys.DeviceType.NVIDIA}[device_name] + + +# ─── test cases ───────────────────────────────────────────────────────────── + +def test_argmax_mode(vocab_size: int, dtype_name: str, device_name: str): + """top_k=1 must return the argmax token deterministically.""" + print(f" [argmax mode] vocab={vocab_size} dtype=<{dtype_name}>") + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + # Run 10 times – result must always be the argmax. 
+ expected = int(logits_t.argmax().item()) + for _ in range(10): + llaisys.Ops.sample(out_ls, logits_ls, top_k=1, top_p=1.0, temperature=1.0) + assert _read_i64(out_ls) == expected, \ + f"top_k=1 should always return argmax {expected}, got {_read_i64(out_ls)}" + + +def test_topk_constraint(vocab_size: int, top_k: int, dtype_name: str, device_name: str): + """Sampled token must be within the top-k logit indices.""" + print(f" [top-k] vocab={vocab_size} top_k={top_k} dtype=<{dtype_name}>") + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + + # For low-precision dtypes (especially bf16), ties at the k-th boundary are common. + # Treat all tokens whose logit is >= the k-th largest value as valid top-k candidates. + logits_f = logits_t.float() + kth_value = torch.topk(logits_f, top_k).values[-1] + allowed = set((logits_f >= kth_value).nonzero(as_tuple=False).view(-1).tolist()) + + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + llaisys.Ops.sample_set_seed(42) + for _ in range(50): + llaisys.Ops.sample(out_ls, logits_ls, top_k=top_k, top_p=1.0, temperature=1.0) + token = _read_i64(out_ls) + assert token in allowed, \ + f"top_k={top_k}: sampled token {token} not in top-k set {allowed}" + + +def test_topp_constraint(vocab_size: int, top_p: float, dtype_name: str, device_name: str): + """Sampled token must lie within the nucleus (top-p) set.""" + print(f" [top-p] vocab={vocab_size} top_p={top_p} dtype=<{dtype_name}>") + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + + # Compute nucleus on the torch side. + probs = torch.softmax(logits_t.float(), dim=-1) + sorted_probs, sorted_idx = torch.sort(probs, descending=True) + cumsum = torch.cumsum(sorted_probs, dim=0) + # All tokens in the nucleus: those at or before the first cumsum >= top_p. 
+ cutoff = int((cumsum < top_p).sum().item()) + 1 + # Include ties on the cutoff probability to avoid false negatives in low precision. + cutoff_prob = float(sorted_probs[cutoff - 1].item()) + nucleus = set((probs >= cutoff_prob).nonzero(as_tuple=False).view(-1).tolist()) + + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + llaisys.Ops.sample_set_seed(0) + for _ in range(50): + llaisys.Ops.sample(out_ls, logits_ls, top_k=0, top_p=top_p, temperature=1.0) + token = _read_i64(out_ls) + assert token in nucleus, \ + f"top_p={top_p}: sampled token {token} not in nucleus {nucleus}" + + +def test_temperature_distribution(vocab_size: int, dtype_name: str, device_name: str): + """ + With high temperature and many draws, the empirical distribution should + be roughly uniform – no single token dominates. + With temperature=0.01 it should converge to near-argmax behaviour. + We do a light sanity check rather than a full statistical test. + """ + print(f" [temperature] vocab={vocab_size} dtype=<{dtype_name}>") + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + # Very low temperature → all draws should be same token (argmax). 
+ argmax_token = int(logits_t.float().argmax().item()) + llaisys.Ops.sample_set_seed(7) + for _ in range(20): + llaisys.Ops.sample(out_ls, logits_ls, top_k=0, top_p=1.0, temperature=1e-6) + assert _read_i64(out_ls) == argmax_token, \ + "temperature≈0 should always sample the argmax token" + + # Very high temperature → diversity check: over 200 draws at least 2 distinct tokens + if vocab_size >= 4: + llaisys.Ops.sample_set_seed(13) + seen = set() + for _ in range(200): + llaisys.Ops.sample(out_ls, logits_ls, top_k=0, top_p=1.0, temperature=1e6) + seen.add(_read_i64(out_ls)) + assert len(seen) >= 2, \ + "temperature=1e6 (near-uniform) should sample diverse tokens" + + +def test_seed_reproducibility(vocab_size: int, dtype_name: str, device_name: str): + """Same seed must produce the same sequence of sampled tokens.""" + print(f" [seed reproducibility] vocab={vocab_size} dtype=<{dtype_name}>") + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + llaisys.Ops.sample_set_seed(99) + run1 = [] + for _ in range(20): + llaisys.Ops.sample(out_ls, logits_ls, top_k=0, top_p=1.0, temperature=1.0) + run1.append(_read_i64(out_ls)) + + llaisys.Ops.sample_set_seed(99) + run2 = [] + for _ in range(20): + llaisys.Ops.sample(out_ls, logits_ls, top_k=0, top_p=1.0, temperature=1.0) + run2.append(_read_i64(out_ls)) + + assert run1 == run2, f"Same seed must give same sequence.\nrun1={run1}\nrun2={run2}" + + +def test_profile(vocab_size: int, dtype_name: str, device_name: str): + logits_t, logits_ls = random_tensor((vocab_size,), dtype_name, device_name) + out_ls = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, + device=llaisys_device(device_name)) + + def torch_sample(): + probs = torch.softmax(logits_t.float(), dim=-1) + torch.multinomial(probs, 1) + + benchmark( + torch_sample, + lambda: llaisys.Ops.sample(out_ls, logits_ls, top_k=50, top_p=0.9, 
temperature=0.8), + device_name, + ) + + +# ─── entry point ──────────────────────────────────────────────────────────── + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"]) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + + DTYPES = ["f32", "f16", "bf16"] + VOCAB_SIZES = [16, 256, 4096] + + print(f"Testing Ops.sample on {args.device}") + + for vocab in VOCAB_SIZES: + for dtype in DTYPES: + test_argmax_mode(vocab, dtype, args.device) + + for vocab in VOCAB_SIZES: + for dtype in DTYPES: + test_topk_constraint(vocab, top_k=min(5, vocab // 2), dtype_name=dtype, device_name=args.device) + + for vocab in VOCAB_SIZES: + for dtype in DTYPES: + test_topp_constraint(vocab, top_p=0.9, dtype_name=dtype, device_name=args.device) + + for vocab in VOCAB_SIZES: + for dtype in DTYPES: + test_temperature_distribution(vocab, dtype, args.device) + + for vocab in VOCAB_SIZES: + for dtype in DTYPES: + test_seed_reproducibility(vocab, dtype, args.device) + + if args.profile: + print("\nBenchmark (vocab=32000, f32):") + test_profile(32000, "f32", args.device) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51b..1634561b 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -72,6 +72,7 @@ def test_op_self_attention( # qlen, kvlen, nh, nkvh, hd (2, 2, 1, 1, 4), (5, 11, 4, 2, 8), + (9,9,12,2,128), ] testDtypePrec = [ # type, 
atol, rtol diff --git a/test/test_chat_server.py b/test/test_chat_server.py new file mode 100644 index 00000000..1fd2e6b2 --- /dev/null +++ b/test/test_chat_server.py @@ -0,0 +1,73 @@ +import json + +from fastapi.testclient import TestClient + +from llaisys.chat.server import ChatRuntime, create_app + + +class _DummyTokenizer: + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + return "".join([f"{m['role']}:{m['content']}\n" for m in messages]) + "assistant:" + + def encode(self, text): + return [ord(c) % 256 for c in text] + + def decode(self, tokens, skip_special_tokens=True): + return "".join(chr(t) for t in tokens) + + +class _DummyModel: + def generate(self, inputs, max_new_tokens, top_k, top_p, temperature): + # Return prompt + "ok". + return list(inputs) + [ord("o"), ord("k")] + + def generate_stream(self, inputs, max_new_tokens, top_k, top_p, temperature): + yield ord("o") + yield ord("k") + + +def test_chat_completion_response_shape(): + runtime = ChatRuntime(tokenizer=_DummyTokenizer(), model=_DummyModel()) + client = TestClient(create_app(runtime)) + + payload = { + "model": "qwen2", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 8, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 40, + "stream": False, + } + resp = client.post("/v1/chat/completions", json=payload) + + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "chat.completion" + assert body["choices"][0]["message"]["role"] == "assistant" + assert body["choices"][0]["message"]["content"] == "ok" + assert body["usage"]["completion_tokens"] == 2 + + +def test_chat_completion_stream_sse(): + runtime = ChatRuntime(tokenizer=_DummyTokenizer(), model=_DummyModel()) + client = TestClient(create_app(runtime)) + + payload = { + "model": "qwen2", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + } + with client.stream("POST", "/v1/chat/completions", json=payload) as resp: + assert 
resp.status_code == 200 + lines = [line for line in resp.iter_lines() if line] + + data_lines = [line for line in lines if line.startswith("data: ")] + assert data_lines[-1] == "data: [DONE]" + + first_event = json.loads(data_lines[0][6:]) + assert first_event["choices"][0]["delta"]["role"] == "assistant" + + chunk_events = [json.loads(line[6:]) for line in data_lines[1:-1]] + merged = "".join(e["choices"][0]["delta"].get("content", "") for e in chunk_events) + assert merged == "ok" diff --git a/test/test_hf.py b/test/test_hf.py new file mode 100644 index 00000000..fa3c5d31 --- /dev/null +++ b/test/test_hf.py @@ -0,0 +1,33 @@ +import os +from huggingface_hub import snapshot_download +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from test_utils import torch_device + + +SENTENCE = "Who are you?" + + +def load_hf_model(model_path=None, device_name="cpu"): + model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + if model_path and os.path.isdir(model_path): + print(f"Loading model from local path: {model_path}") + else: + print(f"Loading model from Hugging Face: {model_id}") + model_path = snapshot_download(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map=torch_device(device_name), + trust_remote_code=True, + ) + + return tokenizer, model, model_path + + +tokenizer, hf_model, hf_model_path = load_hf_model( + model_path="./data", + device_name="cpu", +) diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b87..b25fad69 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -34,7 +34,13 @@ def load_hf_model(model_path=None, device_name="cpu"): def hf_infer( - prompt, tokenizer, model, max_new_tokens=128, top_p=0.8, top_k=50, temperature=0.8 + prompt, + tokenizer, + model, + max_new_tokens=128, + top_p=0.8, + top_k=50, + temperature=0.8, ): input_content = 
tokenizer.apply_chat_template( conversation=[{"role": "user", "content": prompt}], @@ -68,6 +74,7 @@ def llaisys_infer( tokenize=False, ) inputs = tokenizer.encode(input_content) + print(f"Input tokens: {inputs}") outputs = model.generate( inputs, max_new_tokens=max_new_tokens, diff --git a/test/test_minimal.py b/test/test_minimal.py new file mode 100644 index 00000000..6f37ae37 --- /dev/null +++ b/test/test_minimal.py @@ -0,0 +1,72 @@ +import ctypes +import os +import sys +from tqdm import tqdm +import llaisys.models +from llaisys import LIB_LLAISYS +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from huggingface_hub import snapshot_download +from test_utils import torch_device + + +def load_hf_model(model_path=None, device_name="cpu"): + model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + if model_path and os.path.isdir(model_path): + print(f"Loading model from local path: {model_path}") + else: + print(f"Loading model from Hugging Face: {model_id}") + model_path = snapshot_download(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map=torch_device(device_name), + trust_remote_code=True, + ) + + return tokenizer, model, model_path + + +tokenizer, hf_model, hf_model_path = load_hf_model( + model_path="./data", + device_name="cpu", +) +sentence = "Who are you?" 
+MAX_TOKENS = 128 + +input_content = tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": sentence}], + add_generation_prompt=True, + tokenize=False, +) +inputs = tokenizer.encode(input_content, return_tensors="pt").to("cpu") + + + +hf_output = hf_model.generate( + inputs, + max_new_tokens=MAX_TOKENS, + top_k=1, + top_p=1, + temperature=1, +) + + +model = llaisys.models.Qwen2(model_path="./data") + +listinput = inputs[0].tolist() +output = model.generate(listinput, max_new_tokens=MAX_TOKENS) +# listinput.append(output) +print(output) +print(tokenizer.decode(output, skip_special_tokens=True)) +print(hf_output) +print(tokenizer.decode(hf_output, skip_special_tokens=True)) + + +""" +Answer: +91786 +[151646, 151644, 15191, 525, 498, 30, 151645, 151648, 198, 91786] +""" diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..b4bba821 100644 --- a/xmake.lua +++ b/xmake.lua @@ -1,7 +1,9 @@ +add_rules("plugin.compile_commands.autoupdate", {outputdir = "./", lsp = "clangd"}) add_rules("mode.debug", "mode.release") set_encodings("utf-8") add_includedirs("include") +add_runenvs("OMP_NUM_THREADS", 4) -- CPU -- includes("xmake/cpu.lua") @@ -37,6 +39,10 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + -- link in the Nvidia device implementation when requested + if has_config("nv-gpu") then + add_deps("llaisys-device-nvidia") + end set_languages("cxx17") set_warnings("all", "error") @@ -83,11 +89,16 @@ target_end() target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-ops-nvidia") + end + add_packages("openmp") set_languages("cxx17") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxflags("-fopenmp") end add_files("src/ops/*/*.cpp") @@ -95,6 +106,42 @@ target("llaisys-ops") on_install(function (target) end) target_end() +target("llaisys-kvcache") + set_kind("static") + 
add_deps("llaisys-utils") + add_deps("llaisys-core") + add_deps("llaisys-tensor") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/kvcache/*.cc") + + on_install(function (target) end) +target_end() + +-- target("llaisys-models") +-- set_kind("static") +-- add_deps("llaisys-utils") +-- add_deps("llaisys-core") +-- add_deps("llaisys-tensor") +-- add_deps("llaisys-ops") +-- add_deps("llaisys-kvcache") + +-- set_languages("cxx17") +-- set_warnings("all", "error") +-- if not is_plat("windows") then +-- add_cxflags("-fPIC", "-Wno-unknown-pragmas") +-- end + +-- add_files("src/llaisys/models/*.cc") + +-- on_install(function (target) end) +-- target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,6 +149,13 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") + add_deps("llaisys-kvcache") + -- include nvidia device library when enabled (llaisys-device already pulls it in) + -- add_deps("llaisys-models") + if has_config("nv-gpu") then + -- link against CUDA runtime so __cudaRegisterLinkedBinary symbols resolve + add_links("cudart") + end set_languages("cxx17") set_warnings("all", "error") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 00000000..aeb4d318 --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,35 @@ +target("llaisys-device-nvidia") + set_kind("static") + set_policy("build.cuda.devlink", true) + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + -- C++ flags for host compiler + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + -- CUDA compiler flags: pass PIC to host compiler and silence pragmas + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + -- CUDA device-link also needs PIC because this archive is linked into a shared library. 
+ add_culdflags("-Xcompiler=-fPIC") + end + + add_files("../src/device/nvidia/*.cu") + + on_install(function (target) end) +target_end() + +target("llaisys-ops-nvidia") + set_kind("static") + set_policy("build.cuda.devlink", true) + set_languages("cxx17") + set_warnings("all", "error") + add_includedirs("../include") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + add_culdflags("-Xcompiler=-fPIC") + end + + add_files("../src/ops/*/nvidia/*.cu") + + on_install(function (target) end) +target_end() \ No newline at end of file