diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py
index f536fb52..c40ffc8e 100644
--- a/python/llaisys/libllaisys/__init__.py
+++ b/python/llaisys/libllaisys/__init__.py
@@ -12,6 +12,7 @@
 from .tensor import llaisysTensor_t
 from .tensor import load_tensor
 from .ops import load_ops
+from .models import load_models
 
 
 def load_shared_library():
@@ -38,6 +39,7 @@ def load_shared_library():
 load_runtime(LIB_LLAISYS)
 load_tensor(LIB_LLAISYS)
 load_ops(LIB_LLAISYS)
+load_models(LIB_LLAISYS)
 
 
 __all__ = [
diff --git a/python/llaisys/libllaisys/models/__init__.py b/python/llaisys/libllaisys/models/__init__.py
new file mode 100644
index 00000000..6361cf81
--- /dev/null
+++ b/python/llaisys/libllaisys/models/__init__.py
@@ -0,0 +1,12 @@
+from .qwen2 import load_models
+from .qwen2 import LlaisysQwen2Meta
+from .qwen2 import LlaisysQwen2Weights
+from .qwen2 import llaisysQwen2Model_t
+
+__all__ = [
+    "load_models",
+    "LlaisysQwen2Meta",
+    "LlaisysQwen2Weights",
+    "llaisysQwen2Model_t",
+]
diff --git a/python/llaisys/libllaisys/models/qwen2.py b/python/llaisys/libllaisys/models/qwen2.py
new file mode 100644
index 00000000..6be71e4b
--- /dev/null
+++ b/python/llaisys/libllaisys/models/qwen2.py
@@ -0,0 +1,70 @@
+import ctypes
+from ctypes import POINTER, c_float, c_int, c_int64, c_size_t, c_void_p
+
+from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t
+from ..tensor import llaisysTensor_t
+
+
+class LlaisysQwen2Meta(ctypes.Structure):
+    """ctypes mirror of the C `LlaisysQwen2Meta` struct.
+
+    Field order and types must match the C header exactly.
+    """
+
+    _fields_ = [
+        ("dtype", llaisysDataType_t),
+        ("nlayer", c_size_t),
+        ("hs", c_size_t),
+        ("nh", c_size_t),
+        ("nkvh", c_size_t),
+        ("dh", c_size_t),
+        ("di", c_size_t),
+        ("maxseq", c_size_t),
+        ("voc", c_size_t),
+        ("epsilon", c_float),
+        ("theta", c_float),
+        ("end_token", c_int64),
+    ]
+
+
+class LlaisysQwen2Weights(ctypes.Structure):
+    """ctypes mirror of the C `LlaisysQwen2Weights` struct.
+
+    Pointer fields are per-layer arrays owned by the backend model; Python
+    only writes tensor handles into them.
+    """
+
+    _fields_ = [
+        ("in_embed", llaisysTensor_t),
+        ("out_embed", llaisysTensor_t),
+        ("out_norm_w", llaisysTensor_t),
+        ("attn_norm_w", POINTER(llaisysTensor_t)),
+        ("attn_q_w", POINTER(llaisysTensor_t)),
+        ("attn_q_b", POINTER(llaisysTensor_t)),
+        ("attn_k_w", POINTER(llaisysTensor_t)),
+        ("attn_k_b", POINTER(llaisysTensor_t)),
+        ("attn_v_w", POINTER(llaisysTensor_t)),
+        ("attn_v_b", POINTER(llaisysTensor_t)),
+        ("attn_o_w", POINTER(llaisysTensor_t)),
+        ("mlp_norm_w", POINTER(llaisysTensor_t)),
+        ("mlp_gate_w", POINTER(llaisysTensor_t)),
+        ("mlp_up_w", POINTER(llaisysTensor_t)),
+        ("mlp_down_w", POINTER(llaisysTensor_t)),
+    ]
+
+
+# Opaque handle type for `struct LlaisysQwen2Model *`.
+llaisysQwen2Model_t = c_void_p
+
+
+def load_models(lib):
+    """Declare argtypes/restypes for the Qwen2 model C API on `lib`."""
+    # struct LlaisysQwen2Model *llaisysQwen2ModelCreate(...)
+    lib.llaisysQwen2ModelCreate.argtypes = [
+        POINTER(LlaisysQwen2Meta),
+        llaisysDeviceType_t,
+        POINTER(c_int),  # device_ids
+        c_int,  # ndevice
+    ]
+    lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t
+
+    # void llaisysQwen2ModelDestroy(...)
+    lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t]
+    lib.llaisysQwen2ModelDestroy.restype = None
+
+    # struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(...)
+    lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t]
+    lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights)
+
+    # int64_t llaisysQwen2ModelInfer(...)
+    lib.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2Model_t, POINTER(c_int64), c_size_t]
+    lib.llaisysQwen2ModelInfer.restype = c_int64
diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py
index 0d07b0b2..97f1a16d 100644
--- a/python/llaisys/models/qwen2.py
+++ b/python/llaisys/models/qwen2.py
@@ -1,33 +1,210 @@
-from typing import Sequence
+from typing import Sequence, Optional
 
 from ..libllaisys import LIB_LLAISYS
 from ..libllaisys import DeviceType
 from pathlib import Path
+import json
 import safetensors
+from ..libllaisys.llaisys_types import DataType
+from ..libllaisys.tensor import llaisysTensor_t
+from ..libllaisys.models.qwen2 import LlaisysQwen2Meta, LlaisysQwen2Weights
+import ctypes
+
+from ..tensor import Tensor
 
 
 class Qwen2:
     def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
-        # TODO: Implement model constructor
+        """Parse config.json -> meta, create the backend model, load weights.
+
+        - parse model config -> meta
+        - create backend model
+        - load safetensors -> create llaisys tensors -> fill weights struct
+        """
+        self._device = device
+        self._model = None
+        self._weights = None
+        self._meta = None
+        self._model_path = None
+        # Keep Python Tensor objects alive; the backend only stores raw handles.
+        self._owned_tensors = []
         model_path = Path(model_path)
+        self._model_path = model_path
+
+        meta = self._load_meta(model_path)
+        self._meta = meta
+        self._model, self._weights_ptr, self._weights = self._create_backend(meta, device)
 
         for file in sorted(model_path.glob("*.safetensors")):
-            data_ = safetensors.safe_open(file, framework="numpy", device="cpu")
+            # NOTE: weights are often bf16; NumPy has no native bf16, so load
+            # through torch ("pt") and treat the tensor purely as a byte container.
+            data_ = safetensors.safe_open(file, framework="pt", device="cpu")
             for name_ in data_.keys():
-                ## TODO: load the model weights
-                pass
+                t = data_.get_tensor(name_)
+                self._load_weight(name_, t)
 
     def generate(
         self,
         inputs: Sequence[int],
-        max_new_tokens: int = None,
+        max_new_tokens: Optional[int] = None,
         top_k: int = 1,
         top_p: float = 0.8,
         temperature: float = 0.8,
     ):
-        # TODO: Implement generate function
+        """Greedy decoding loop.
+
+        Sampling knobs (top_k/top_p/temperature) are accepted for API
+        compatibility but intentionally unused: the backend argmax-decodes.
+        Returns the full token list (prompt + generated), stopping at EOS.
+        """
+        _ = (top_k, top_p, temperature)
+        if max_new_tokens is None:
+            max_new_tokens = 128
+
+        token_ids = list(inputs)
+        eos_id = int(self._meta.end_token) if self._meta is not None else -1
+        for _ in range(max_new_tokens):
+            arr = (ctypes.c_int64 * len(token_ids))(*token_ids)
+            next_id = int(
+                LIB_LLAISYS.llaisysQwen2ModelInfer(self._model, arr, len(token_ids))
+            )
+            token_ids.append(next_id)
+            if eos_id != -1 and next_id == eos_id:
+                break
+        return token_ids
+
+    def _create_backend(self, meta: LlaisysQwen2Meta, device: DeviceType):
+        """Create the C backend model and return (handle, weights_ptr, weights)."""
+        device_ids = (ctypes.c_int * 1)(0)
+        m = LIB_LLAISYS.llaisysQwen2ModelCreate(ctypes.byref(meta), int(device), device_ids, 1)
+        if not m:
+            raise RuntimeError("llaisysQwen2ModelCreate failed (returned null)")
+        w_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(m)
+        return m, w_ptr, w_ptr.contents
+
+    def _load_meta(self, model_path: Path) -> LlaisysQwen2Meta:
+        """Build the backend meta struct from the HF-style config.json."""
+        cfg_path = model_path / "config.json"
+        if not cfg_path.exists():
+            raise FileNotFoundError(f"Missing config.json under: {model_path}")
+
+        cfg = json.loads(cfg_path.read_text())
+
+        hs = int(cfg["hidden_size"])
+        nh = int(cfg["num_attention_heads"])
+        nkvh = int(cfg.get("num_key_value_heads", nh))
+        nlayer = int(cfg["num_hidden_layers"])
+        di = int(cfg["intermediate_size"])
+        voc = int(cfg["vocab_size"])
+        maxseq = int(cfg.get("max_position_embeddings", cfg.get("seq_length", 0)))
+        if maxseq <= 0:
+            maxseq = 4096
+
+        # Model-specific constants.
+        eps = float(cfg.get("rms_norm_eps", 1e-6))
+        theta = float(cfg.get("rope_theta", 10000.0))
+        # HF configs may store eos_token_id as an int OR a list of ints;
+        # take the first entry when it is a list.
+        eos = cfg.get("eos_token_id", -1)
+        if isinstance(eos, (list, tuple)):
+            eos = eos[0] if eos else -1
+        end_token = int(eos)
+        # DType: keep bf16 by default (matching test/test_infer.py).
+        dtype = DataType.BF16
+
+        dh = hs // nh
+        return LlaisysQwen2Meta(
+            dtype=int(dtype),
+            nlayer=nlayer,
+            hs=hs,
+            nh=nh,
+            nkvh=nkvh,
+            dh=dh,
+            di=di,
+            maxseq=maxseq,
+            voc=voc,
+            epsilon=eps,
+            theta=theta,
+            end_token=end_token,
+        )
+
+    def _load_weight(self, name: str, t):
+        """Copy one safetensors tensor into a llaisys tensor and route it.
+
+        `t` is a torch.Tensor when loaded via safetensors (framework="pt").
+        torch is only used here as a container for the raw bytes.
+        """
+        if not hasattr(t, "dtype"):
+            raise TypeError(f"Unexpected tensor type for {name}: {type(t)}")
+
+        # Map torch dtype -> llaisys dtype.
+        if str(t.dtype) == "torch.bfloat16":
+            dtype = DataType.BF16
+        elif str(t.dtype) == "torch.float16":
+            dtype = DataType.F16
+        elif str(t.dtype) == "torch.float32":
+            dtype = DataType.F32
+        else:
+            raise ValueError(f"Unsupported dtype for {name}: {t.dtype}")
+
+        # Always load from a contiguous CPU tensor for now.
+        t_contig = t.contiguous().cpu()
+        shape = tuple(int(s) for s in t_contig.shape)
+        w = Tensor(shape=shape, dtype=dtype, device=self._device)
+        w.load(ctypes.c_void_p(int(t_contig.data_ptr())))
+
+        # IMPORTANT: keep it alive to avoid freeing the underlying llaisysTensor_t.
+        self._owned_tensors.append(w)
+        self._route_weight(name, w.lib_tensor())
+
+    def _route_weight(self, name: str, handle: llaisysTensor_t):
+        """Write `handle` into the backend weights struct slot matching `name`."""
+        if name == "model.embed_tokens.weight":
+            self._weights.in_embed = handle
+            return
+        if name == "lm_head.weight":
+            self._weights.out_embed = handle
+            return
+        if name == "model.norm.weight":
+            self._weights.out_norm_w = handle
+            return
+
+        # Per-layer mappings.
+        if name.startswith("model.layers."):
+            parts = name.split(".")
+            # Expect: model.layers.{i}.<...>
+            try:
+                layer = int(parts[2])
+            except Exception:
+                print(f"TODO unmapped (bad layer parse): {name}")
+                return
+
+            if name.endswith("input_layernorm.weight"):
+                self._weights.attn_norm_w[layer] = handle
+                return
+            if name.endswith("post_attention_layernorm.weight"):
+                self._weights.mlp_norm_w[layer] = handle
+                return
+
+            # Attention projections
+            if "self_attn.q_proj.weight" in name:
+                self._weights.attn_q_w[layer] = handle
+                return
+            if "self_attn.q_proj.bias" in name:
+                self._weights.attn_q_b[layer] = handle
+                return
+            if "self_attn.k_proj.weight" in name:
+                self._weights.attn_k_w[layer] = handle
+                return
+            if "self_attn.k_proj.bias" in name:
+                self._weights.attn_k_b[layer] = handle
+                return
+            if "self_attn.v_proj.weight" in name:
+                self._weights.attn_v_w[layer] = handle
+                return
+            if "self_attn.v_proj.bias" in name:
+                self._weights.attn_v_b[layer] = handle
+                return
+            if "self_attn.o_proj.weight" in name:
+                self._weights.attn_o_w[layer] = handle
+                return
+
+            # MLP projections
+            if "mlp.gate_proj.weight" in name:
+                self._weights.mlp_gate_w[layer] = handle
+                return
+            if "mlp.up_proj.weight" in name:
+                self._weights.mlp_up_w[layer] = handle
+                return
+            if "mlp.down_proj.weight" in name:
+                self._weights.mlp_down_w[layer] = handle
+                return
 
-        return []
+        print(f"TODO unmapped: {name}")
diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc
new file mode 100644
index 00000000..a9ac6908
--- /dev/null
+++ b/src/llaisys/models/qwen2.cc
@@ -0,0 +1,52 @@
+#include "llaisys/models/qwen2.h"
+
+#include "../llaisys_tensor.hpp"
+
+#include "../../models/qwen2/model.hpp"
+
+#include <iostream>
+#include <memory>
+
+__C {
+
+struct LlaisysQwen2Model {
+    std::unique_ptr<llaisys::models::Qwen2Model> impl;
+};
+
+struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta,
+                                                 llaisysDeviceType_t device,
+                                                 int *device_ids,
+                                                 int ndevice) {
+    try {
+        CHECK_ARGUMENT(meta != nullptr, "llaisysQwen2ModelCreate: meta is null");
+        auto *m = new LlaisysQwen2Model;
+        m->impl = llaisys::models::Qwen2Model::create(*meta, device, device_ids, ndevice);
+        return m;
+    } catch (const std::exception &e) {
+        std::cerr << "[ERROR] llaisysQwen2ModelCreate failed: " << e.what() << std::endl;
+        return nullptr;
+    }
+}
+
+void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) {
+    delete model;
+}
+
+struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) {
+    try {
+        CHECK_ARGUMENT(model != nullptr, "llaisysQwen2ModelWeights: model is null");
+        return model->impl->weights();
+    } catch (const std::exception &e) {
+        std::cerr << "[ERROR] llaisysQwen2ModelWeights failed: " << e.what() << std::endl;
+        return nullptr;
+    }
+}
+
+int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
+    try {
+        CHECK_ARGUMENT(model != nullptr, "llaisysQwen2ModelInfer: model is null");
+        return model->impl->infer(token_ids, ntoken);
+    } catch (const std::exception &e) {
+        std::cerr << "[ERROR] llaisysQwen2ModelInfer failed: " << e.what() << std::endl;
+        // Be conservative: return EOS/end token so the caller can stop generation.
+        // Guard `model` as well: the CHECK above may have thrown with model == nullptr,
+        // in which case dereferencing model->impl here would itself be UB.
+        return (model && model->impl) ? model->impl->endToken() : -1;
+    }
+}
+}
diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc
index c99fbc32..f5722eb3 100644
--- a/src/llaisys/ops.cc
+++ b/src/llaisys/ops.cc
@@ -23,7 +23,8 @@ __C {
         llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor);
     }
     void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) {
-        llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor);
+        auto bias_ptr = bias ? bias->tensor : nullptr;
+        llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias_ptr);
     }
     void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) {
         llaisys::ops::rearrange(out->tensor, in->tensor);
diff --git a/src/models/qwen2/model.cpp b/src/models/qwen2/model.cpp
new file mode 100644
index 00000000..0b1dc1da
--- /dev/null
+++ b/src/models/qwen2/model.cpp
@@ -0,0 +1,516 @@
+#include "model.hpp"
+
+#include "../../llaisys/llaisys_tensor.hpp"
+#include "../../ops/add/op.hpp"
+#include "../../ops/argmax/op.hpp"
+#include "../../ops/embedding/op.hpp"
+#include "../../ops/linear/op.hpp"
+#include "../../ops/rms_norm/op.hpp"
+#include "../../ops/rope/op.hpp"
+#include "../../ops/self_attention/op.hpp"
+#include "../../ops/swiglu/op.hpp"
+#include "../../tensor/tensor.hpp"
+#include "../../utils.hpp"
+
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+namespace llaisys::models {
+
+// Unwrap a C tensor handle into the internal shared_ptr, failing with `what`
+// when the handle is null.
+static inline llaisys::tensor_t unwrap_(llaisysTensor_t h, const char *what) {
+    CHECK_ARGUMENT(h != nullptr, what);
+    return h->tensor;
+}
+
+// Lazily allocate per-layer K/V cache buffers of shape [maxseq, nkvh, dh].
+// Idempotent: does nothing when the caches are already allocated.
+static inline void ensure_kv_cache_allocated_(std::vector<llaisys::tensor_t> &k_cache,
+                                              std::vector<llaisys::tensor_t> &v_cache,
+                                              const LlaisysQwen2Meta &meta,
+                                              llaisysDeviceType_t device_type,
+                                              int device_id) {
+    const bool already_allocated = (k_cache.size() == meta.nlayer && v_cache.size() == meta.nlayer && meta.nlayer > 0
+                                    && k_cache[0] != nullptr && v_cache[0] != nullptr);
+    if (already_allocated) {
+        return;
+    }
+
+    k_cache.assign(meta.nlayer, nullptr);
+    v_cache.assign(meta.nlayer, nullptr);
+    for (size_t layer = 0; layer < meta.nlayer; layer++) {
+        k_cache[layer] = llaisys::Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device_type, device_id);
+        v_cache[layer] = llaisys::Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device_type, device_id);
+    }
+}
+
+// Copy one [1, nkvh, dh] row of new K or V data into `cache` at row `pos`.
+// Both tensors must be contiguous; this is a raw byte copy (CPU only).
+static inline void kv_cache_write_row_(llaisys::tensor_t cache,
+                                       llaisys::tensor_t kv_new,
+                                       size_t pos,
+                                       size_t nkvh,
+                                       size_t dh) {
+    CHECK_ARGUMENT(cache != nullptr && kv_new != nullptr, "kv_cache_write_row: null tensor");
+    CHECK_ARGUMENT(cache->isContiguous() && kv_new->isContiguous(), "kv_cache_write_row: tensors must be contiguous");
+    CHECK_ARGUMENT(cache->ndim() == 3 && kv_new->ndim() == 3, "kv_cache_write_row: expected 3D tensors");
+    CHECK_ARGUMENT(kv_new->shape()[0] == 1 && kv_new->shape()[1] == nkvh && kv_new->shape()[2] == dh,
+                   "kv_cache_write_row: kv_new must be [1, nkvh, dh]");
+    CHECK_ARGUMENT(cache->shape()[0] > pos && cache->shape()[1] == nkvh && cache->shape()[2] == dh,
+                   "kv_cache_write_row: cache must be [maxseq, nkvh, dh] with pos in range");
+
+    const size_t row_elems = nkvh * dh;
+    const size_t row_bytes = row_elems * cache->elementSize();
+    std::memcpy(cache->data() + pos * row_bytes, kv_new->data(), row_bytes);
+}
+
+std::unique_ptr<Qwen2Model> Qwen2Model::create(const LlaisysQwen2Meta &meta,
+                                               llaisysDeviceType_t device,
+                                               const int *device_ids,
+                                               int ndevice) {
+    CHECK_ARGUMENT(meta.nlayer > 0, "Qwen2Model: meta.nlayer must be > 0");
+    CHECK_ARGUMENT(meta.hs > 0 && meta.nh > 0 && meta.dh > 0, "Qwen2Model: invalid hidden/head sizes");
+    CHECK_ARGUMENT(meta.nh % meta.nkvh == 0, "Qwen2Model: require nh % nkvh == 0");
+    CHECK_ARGUMENT(meta.maxseq > 0 && meta.voc > 0, "Qwen2Model: invalid maxseq/vocab");
+
+    std::vector<int> ids;
+    if (device_ids && ndevice > 0) {
+        ids.assign(device_ids, device_ids + ndevice);
+    } else {
+        ids.push_back(0);
+    }
+
+    return std::unique_ptr<Qwen2Model>(new Qwen2Model(meta, device, std::move(ids)));
+}
+
+Qwen2Model::Qwen2Model(const LlaisysQwen2Meta &meta,
+                       llaisysDeviceType_t device,
+                       std::vector<int> device_ids)
+    : meta_(meta), device_(device), device_ids_(std::move(device_ids)) {
+    // Allocate per-layer handle arrays (to be populated from Python).
+    attn_norm_w_.assign(meta_.nlayer, nullptr);
+    attn_q_w_.assign(meta_.nlayer, nullptr);
+    attn_q_b_.assign(meta_.nlayer, nullptr);
+    attn_k_w_.assign(meta_.nlayer, nullptr);
+    attn_k_b_.assign(meta_.nlayer, nullptr);
+    attn_v_w_.assign(meta_.nlayer, nullptr);
+    attn_v_b_.assign(meta_.nlayer, nullptr);
+    attn_o_w_.assign(meta_.nlayer, nullptr);
+
+    mlp_norm_w_.assign(meta_.nlayer, nullptr);
+    mlp_gate_w_.assign(meta_.nlayer, nullptr);
+    mlp_up_w_.assign(meta_.nlayer, nullptr);
+    mlp_down_w_.assign(meta_.nlayer, nullptr);
+
+    k_cache_.assign(meta_.nlayer, nullptr);
+    v_cache_.assign(meta_.nlayer, nullptr);
+    cur_len_ = 0;
+    initWeightsView_();
+}
+
+Qwen2Model::~Qwen2Model() {
+}
+
+// Point the C-facing weights struct at the vectors owned by this object.
+// The vector storage is never resized after construction, so these pointers
+// stay valid for the model's lifetime.
+void Qwen2Model::initWeightsView_() {
+    std::memset(&weights_, 0, sizeof(weights_));
+    // Global weights (to be filled by Python)
+    weights_.in_embed = nullptr;
+    weights_.out_embed = nullptr;
+    weights_.out_norm_w = nullptr;
+
+    // Per-layer arrays (stable pointers to vector storage).
+    weights_.attn_norm_w = attn_norm_w_.data();
+    weights_.attn_q_w = attn_q_w_.data();
+    weights_.attn_q_b = attn_q_b_.data();
+    weights_.attn_k_w = attn_k_w_.data();
+    weights_.attn_k_b = attn_k_b_.data();
+    weights_.attn_v_w = attn_v_w_.data();
+    weights_.attn_v_b = attn_v_b_.data();
+    weights_.attn_o_w = attn_o_w_.data();
+
+    weights_.mlp_norm_w = mlp_norm_w_.data();
+    weights_.mlp_gate_w = mlp_gate_w_.data();
+    weights_.mlp_up_w = mlp_up_w_.data();
+    weights_.mlp_down_w = mlp_down_w_.data();
+}
+
+LlaisysQwen2Weights *Qwen2Model::weights() {
+    return &weights_;
+}
+
+// Run one forward pass over token_ids[0..ntoken) and return the argmax next
+// token. First call prefills the KV cache with the whole prompt; subsequent
+// calls must grow the sequence by exactly one token (decode step).
+// NOTE(review): cur_len_ is never reset, so a second independent `generate`
+// call on the same model instance is not supported yet.
+int64_t Qwen2Model::infer(const int64_t *token_ids, size_t ntoken) {
+    CHECK_ARGUMENT(token_ids != nullptr || ntoken == 0, "Qwen2ModelInfer: token_ids is null");
+
+    // Pipeline per layer:
+    //   rms_norm -> q/k/v linear -> RoPE(q,k) -> causal self-attention (KV cache)
+    //   -> o proj -> residual -> rms_norm -> gate/up linear -> SwiGLU -> down proj -> residual
+    // then: final rms_norm -> lm_head -> argmax.
+
+    // Fail fast with a precise message when a required weight was never set.
+    if (const char *hint = missingWeightsHint_()) {
+        CHECK_ARGUMENT(false, hint);
+    }
+
+    CHECK_ARGUMENT(device_ == LLAISYS_DEVICE_CPU, "Qwen2Model: only CPU infer is wired in the skeleton");
+
+    if (ntoken == 0) {
+        return meta_.end_token;
+    }
+    CHECK_ARGUMENT(ntoken <= meta_.maxseq, "Qwen2ModelInfer: ntoken exceeds maxseq");
+    if (0 == cur_len_) {
+        // ---- prefill ----
+        ensure_kv_cache_allocated_(k_cache_, v_cache_, meta_, device_, 0);
+
+        std::vector<int64_t> vec_id(ntoken);
+        for (size_t i = 0; i < ntoken; ++i) {
+            CHECK_ARGUMENT(token_ids[i] >= 0 && static_cast<size_t>(token_ids[i]) < meta_.voc, "Qwen2ModelInfer: token id out of range");
+            vec_id[i] = token_ids[i];
+        }
+
+        // index: [T] i64
+        auto index = llaisys::Tensor::create({ntoken}, LLAISYS_DTYPE_I64, device_, 0);
+        index->load(vec_id.data());
+
+        // x: [T, H]
+        auto x = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+        llaisys::ops::embedding(x, index, unwrap_(weights_.in_embed, "Qwen2Model: weights.in_embed is null"));
+
+        // pos_ids: [T] i64
+        std::vector<int64_t> vec_pos(ntoken);
+        for (size_t i = 0; i < ntoken; ++i) {
+            vec_pos[i] = static_cast<int64_t>(i);
+        }
+        auto pos_ids = llaisys::Tensor::create({ntoken}, LLAISYS_DTYPE_I64, device_, 0);
+        pos_ids->load(vec_pos.data());
+        for (size_t layer = 0; layer < meta_.nlayer; ++layer) {
+            CHECK_ARGUMENT(k_cache_[layer] != nullptr && v_cache_[layer] != nullptr, "Qwen2ModelInfer(prefill): kv cache is null");
+            CHECK_ARGUMENT(k_cache_[layer]->ndim() == 3 && v_cache_[layer]->ndim() == 3, "Qwen2ModelInfer(prefill): kv cache must be 3D");
+            CHECK_ARGUMENT(k_cache_[layer]->shape()[0] == meta_.maxseq && v_cache_[layer]->shape()[0] == meta_.maxseq,
+                           "Qwen2ModelInfer(prefill): kv cache maxseq mismatch");
+            CHECK_ARGUMENT(k_cache_[layer]->shape()[1] == meta_.nkvh && v_cache_[layer]->shape()[1] == meta_.nkvh,
+                           "Qwen2ModelInfer(prefill): kv cache nkvh mismatch");
+            CHECK_ARGUMENT(k_cache_[layer]->shape()[2] == meta_.dh && v_cache_[layer]->shape()[2] == meta_.dh,
+                           "Qwen2ModelInfer(prefill): kv cache dh mismatch");
+
+            // attn_norm: [T, H]
+            auto x_attn_in = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::rms_norm(x_attn_in, x, unwrap_(weights_.attn_norm_w[layer], "Qwen2Model: weights.attn_norm_w is null"), meta_.epsilon);
+
+            // q2d: [T, H]
+            auto q2d = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                q2d,
+                x_attn_in,
+                unwrap_(weights_.attn_q_w[layer], "Qwen2Model: weights.attn_q_w is null"),
+                unwrap_(weights_.attn_q_b[layer], "Qwen2Model: weights.attn_q_b is null"));
+
+            // k2d: [T, nkvh*dh]
+            auto k2d = llaisys::Tensor::create({ntoken, meta_.nkvh * meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                k2d,
+                x_attn_in,
+                unwrap_(weights_.attn_k_w[layer], "Qwen2Model: weights.attn_k_w is null"),
+                unwrap_(weights_.attn_k_b[layer], "Qwen2Model: weights.attn_k_b is null"));
+
+            // v2d: [T, nkvh*dh]
+            auto v2d = llaisys::Tensor::create({ntoken, meta_.nkvh * meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                v2d,
+                x_attn_in,
+                unwrap_(weights_.attn_v_w[layer], "Qwen2Model: weights.attn_v_w is null"),
+                unwrap_(weights_.attn_v_b[layer], "Qwen2Model: weights.attn_v_b is null"));
+
+            // q: [T, nh, dh]
+            auto q = q2d->view({ntoken, meta_.nh, meta_.dh});
+
+            // k: [T, nkvh, dh]
+            auto k = k2d->view({ntoken, meta_.nkvh, meta_.dh});
+
+            // v: [T, nkvh, dh] -- V is cached un-rotated.
+            auto v = v2d->view({ntoken, meta_.nkvh, meta_.dh});
+            for (size_t i = 0; i < ntoken; i++) {
+                auto slice = v->slice(0, i, i + 1);
+                kv_cache_write_row_(v_cache_[layer], slice, i, meta_.nkvh, meta_.dh);
+            }
+            // q_rope: [T, nh, dh]
+            auto q_rope = llaisys::Tensor::create({ntoken, meta_.nh, meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::rope(q_rope, q, pos_ids, meta_.theta);
+
+            // k_rope: [T, nkvh, dh] -- K is cached AFTER RoPE.
+            auto k_rope = llaisys::Tensor::create({ntoken, meta_.nkvh, meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::rope(k_rope, k, pos_ids, meta_.theta);
+            for (size_t i = 0; i < ntoken; i++) {
+                auto slice = k_rope->slice(0, i, i + 1);
+                kv_cache_write_row_(k_cache_[layer], slice, i, meta_.nkvh, meta_.dh);
+            }
+
+            // attn_val: [T, nh, dh]
+            auto attn_val = llaisys::Tensor::create({ntoken, meta_.nh, meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::self_attention(attn_val, q_rope, k_rope, v, float(1.0f / std::sqrt(meta_.dh)));
+
+            // attn2d: [T, H]
+            auto attn2d = attn_val->view({ntoken, meta_.hs});
+
+            // attn_out: [T, H]
+            auto attn_out = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(attn_out, attn2d, unwrap_(weights_.attn_o_w[layer], "Qwen2Model: weights.attn_o_w is null"), nullptr);
+
+            // x_after_attn: [T, H]
+            auto x_after_attn = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::add(x_after_attn, x, attn_out);
+
+            // x_mlp_in: [T, H]
+            auto x_mlp_in = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::rms_norm(x_mlp_in, x_after_attn, unwrap_(weights_.mlp_norm_w[layer], "Qwen2Model: weights.mlp_norm_w is null"), meta_.epsilon);
+
+            // gate: [T, di]
+            auto gate = llaisys::Tensor::create({ntoken, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(gate, x_mlp_in, unwrap_(weights_.mlp_gate_w[layer], "Qwen2Model: weights.mlp_gate_w is null"), nullptr);
+
+            // up: [T, di]
+            auto up = llaisys::Tensor::create({ntoken, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(up, x_mlp_in, unwrap_(weights_.mlp_up_w[layer], "Qwen2Model: weights.mlp_up_w is null"), nullptr);
+
+            // act: [T, di]
+            auto act = llaisys::Tensor::create({ntoken, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::swiglu(act, gate, up);
+
+            // mlp_out: [T, H]
+            auto mlp_out = llaisys::Tensor::create({ntoken, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(mlp_out, act, unwrap_(weights_.mlp_down_w[layer], "Qwen2Model: weights.mlp_down_w is null"), nullptr);
+
+            // x = x_after_attn + mlp_out (written back into x for the next layer)
+            llaisys::ops::add(x, x_after_attn, mlp_out);
+        }
+
+        // final norm over the LAST position only: [1, H]
+        auto last_x = x->slice(0, ntoken - 1, ntoken);
+        auto x_norm = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+        llaisys::ops::rms_norm(x_norm, last_x, unwrap_(weights_.out_norm_w, "Qwen2Model: weights.out_norm_w is null"), meta_.epsilon);
+
+        // logits2d: [1, vocab]
+        auto logits2d = llaisys::Tensor::create({1, meta_.voc}, meta_.dtype, device_, 0);
+        llaisys::ops::linear(
+            logits2d,
+            x_norm,
+            unwrap_(weights_.out_embed, "Qwen2Model: weights.out_embed is null"),
+            nullptr);
+
+        // argmax over vocab (treat logits as 1D [vocab])
+        auto logits = logits2d->view({meta_.voc});
+        auto max_idx = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, device_, 0);
+        auto max_val = llaisys::Tensor::create({1}, meta_.dtype, device_, 0);
+        llaisys::ops::argmax(max_idx, max_val, logits);
+
+        // Argmax kernel writes size_t; on 64-bit this fits in I64 storage.
+        const size_t idx = *reinterpret_cast<const size_t *>(max_idx->data());
+        cur_len_ = ntoken;
+        return static_cast<int64_t>(idx);
+    } else {
+        // ---- decode (one new token) ----
+        CHECK_ARGUMENT(ntoken == cur_len_ + 1, "Qwen2ModelInfer(decode): expected ntoken == cur_len + 1");
+        ensure_kv_cache_allocated_(k_cache_, v_cache_, meta_, device_, 0);
+
+        // index: [1] i64
+        int64_t last_token_id = token_ids[ntoken - 1];
+        CHECK_ARGUMENT(last_token_id >= 0 && static_cast<size_t>(last_token_id) < meta_.voc, "Qwen2ModelInfer(decode): token id out of range");
+        auto index = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, device_, 0);
+        index->load(&last_token_id);
+
+        // x: [1, H]
+        auto x = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+        llaisys::ops::embedding(x, index, unwrap_(weights_.in_embed, "Qwen2Model: weights.in_embed is null"));
+
+        // pos_ids: [1] i64
+        auto pos_ids = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, device_, 0);
+        int64_t pos = static_cast<int64_t>(ntoken) - 1;
+        CHECK_ARGUMENT(pos >= 0 && static_cast<size_t>(pos) < meta_.maxseq, "Qwen2ModelInfer(decode): pos exceeds maxseq");
+        pos_ids->load(&pos);
+        for (size_t layer = 0; layer < meta_.nlayer; ++layer) {
+            CHECK_ARGUMENT(k_cache_[layer] != nullptr && v_cache_[layer] != nullptr, "Qwen2ModelInfer(decode): kv cache is null");
+
+            // attn_norm: [1, H]
+            auto x_attn_in = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::rms_norm(x_attn_in, x, unwrap_(weights_.attn_norm_w[layer], "Qwen2Model: weights.attn_norm_w is null"), meta_.epsilon);
+
+            // q2d: [1, H]
+            auto q2d = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                q2d,
+                x_attn_in,
+                unwrap_(weights_.attn_q_w[layer], "Qwen2Model: weights.attn_q_w is null"),
+                unwrap_(weights_.attn_q_b[layer], "Qwen2Model: weights.attn_q_b is null"));
+
+            // k2d: [1, nkvh*dh]
+            auto k2d = llaisys::Tensor::create({1, meta_.nkvh * meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                k2d,
+                x_attn_in,
+                unwrap_(weights_.attn_k_w[layer], "Qwen2Model: weights.attn_k_w is null"),
+                unwrap_(weights_.attn_k_b[layer], "Qwen2Model: weights.attn_k_b is null"));
+
+            // v2d: [1, nkvh*dh]
+            auto v2d = llaisys::Tensor::create({1, meta_.nkvh * meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(
+                v2d,
+                x_attn_in,
+                unwrap_(weights_.attn_v_w[layer], "Qwen2Model: weights.attn_v_w is null"),
+                unwrap_(weights_.attn_v_b[layer], "Qwen2Model: weights.attn_v_b is null"));
+
+            // q: [1, nh, dh]
+            auto q = q2d->view({1, meta_.nh, meta_.dh});
+
+            // k: [1, nkvh, dh]
+            auto k = k2d->view({1, meta_.nkvh, meta_.dh});
+
+            // v: [1, nkvh, dh] -- cached un-rotated.
+            auto v = v2d->view({1, meta_.nkvh, meta_.dh});
+            kv_cache_write_row_(v_cache_[layer], v, pos, meta_.nkvh, meta_.dh);
+
+            // q_rope: [1, nh, dh]
+            auto q_rope = llaisys::Tensor::create({1, meta_.nh, meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::rope(q_rope, q, pos_ids, meta_.theta);
+
+            // k_rope: [1, nkvh, dh] -- cached after RoPE.
+            auto k_rope = llaisys::Tensor::create({1, meta_.nkvh, meta_.dh}, meta_.dtype, device_, 0);
+            llaisys::ops::rope(k_rope, k, pos_ids, meta_.theta);
+            kv_cache_write_row_(k_cache_[layer], k_rope, pos, meta_.nkvh, meta_.dh);
+
+            // attn_val: [1, nh, dh]
+            auto attn_val = llaisys::Tensor::create({1, meta_.nh, meta_.dh}, meta_.dtype, device_, 0);
+
+            size_t kvlen = static_cast<size_t>(pos) + 1;
+            CHECK_ARGUMENT(kvlen <= meta_.maxseq, "Qwen2ModelInfer(decode): kvlen exceeds maxseq");
+            // k_prefix: [kvlen, nkvh, dh]
+            auto k_prefix = k_cache_[layer]->slice(0, 0, kvlen);
+
+            // v_prefix: [kvlen, nkvh, dh]
+            auto v_prefix = v_cache_[layer]->slice(0, 0, kvlen);
+            llaisys::ops::self_attention(attn_val, q_rope, k_prefix, v_prefix, float(1.0f / std::sqrt(meta_.dh)));
+
+            // attn2d: [1, H]
+            auto attn2d = attn_val->view({1, meta_.hs});
+
+            // attn_out: [1, H]
+            auto attn_out = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(attn_out, attn2d, unwrap_(weights_.attn_o_w[layer], "Qwen2Model: weights.attn_o_w is null"), nullptr);
+
+            // x_after_attn: [1, H]
+            auto x_after_attn = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::add(x_after_attn, x, attn_out);
+
+            // x_mlp_in: [1, H]
+            auto x_mlp_in = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::rms_norm(x_mlp_in, x_after_attn, unwrap_(weights_.mlp_norm_w[layer], "Qwen2Model: weights.mlp_norm_w is null"), meta_.epsilon);
+
+            // gate: [1, di]
+            auto gate = llaisys::Tensor::create({1, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(gate, x_mlp_in, unwrap_(weights_.mlp_gate_w[layer], "Qwen2Model: weights.mlp_gate_w is null"), nullptr);
+
+            // up: [1, di]
+            auto up = llaisys::Tensor::create({1, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(up, x_mlp_in, unwrap_(weights_.mlp_up_w[layer], "Qwen2Model: weights.mlp_up_w is null"), nullptr);
+
+            // act: [1, di]
+            auto act = llaisys::Tensor::create({1, meta_.di}, meta_.dtype, device_, 0);
+            llaisys::ops::swiglu(act, gate, up);
+
+            // mlp_out: [1, H]
+            auto mlp_out = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+            llaisys::ops::linear(mlp_out, act, unwrap_(weights_.mlp_down_w[layer], "Qwen2Model: weights.mlp_down_w is null"), nullptr);
+
+            // x = x_after_attn + mlp_out
+            llaisys::ops::add(x, x_after_attn, mlp_out);
+        }
+
+        // final norm: [1, H]
+        auto last_x = x->slice(0, 0, 1);
+        auto x_norm = llaisys::Tensor::create({1, meta_.hs}, meta_.dtype, device_, 0);
+        llaisys::ops::rms_norm(x_norm, last_x, unwrap_(weights_.out_norm_w, "Qwen2Model: weights.out_norm_w is null"), meta_.epsilon);
+
+        // logits2d: [1, vocab]
+        auto logits2d = llaisys::Tensor::create({1, meta_.voc}, meta_.dtype, device_, 0);
+        llaisys::ops::linear(
+            logits2d,
+            x_norm,
+            unwrap_(weights_.out_embed, "Qwen2Model: weights.out_embed is null"),
+            nullptr);
+
+        // argmax over vocab (treat logits as 1D [vocab])
+        auto logits = logits2d->view({meta_.voc});
+        auto max_idx = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, device_, 0);
+        auto max_val = llaisys::Tensor::create({1}, meta_.dtype, device_, 0);
+        llaisys::ops::argmax(max_idx, max_val, logits);
+
+        // Argmax kernel writes size_t; on 64-bit this fits in I64 storage.
+        const size_t idx = *reinterpret_cast<const size_t *>(max_idx->data());
+        cur_len_ = ntoken;
+        return static_cast<int64_t>(idx);
+    }
+}
+
+// Return a human-readable message naming the first missing required weight,
+// or nullptr when the spot-checked weights are all present.
+const char *Qwen2Model::missingWeightsHint_() const {
+    if (weights_.in_embed == nullptr) {
+        return "Qwen2Model: missing weights.in_embed (model.embed_tokens.weight)";
+    }
+    if (weights_.out_embed == nullptr) {
+        return "Qwen2Model: missing weights.out_embed (lm_head.weight)";
+    }
+    if (weights_.out_norm_w == nullptr) {
+        return "Qwen2Model: missing weights.out_norm_w (model.norm.weight)";
+    }
+
+    // Spot-check layer 0 to catch mapping bugs early.
+    if (!weights_.attn_norm_w || weights_.attn_norm_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.attn_norm_w[0]";
+    }
+    if (!weights_.attn_q_w || weights_.attn_q_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.attn_q_w[0]";
+    }
+    if (!weights_.attn_k_w || weights_.attn_k_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.attn_k_w[0]";
+    }
+    if (!weights_.attn_v_w || weights_.attn_v_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.attn_v_w[0]";
+    }
+    if (!weights_.attn_o_w || weights_.attn_o_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.attn_o_w[0]";
+    }
+
+    if (!weights_.mlp_norm_w || weights_.mlp_norm_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.mlp_norm_w[0]";
+    }
+    if (!weights_.mlp_gate_w || weights_.mlp_gate_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.mlp_gate_w[0]";
+    }
+    if (!weights_.mlp_up_w || weights_.mlp_up_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.mlp_up_w[0]";
+    }
+    if (!weights_.mlp_down_w || weights_.mlp_down_w[0] == nullptr) {
+        return "Qwen2Model: missing weights.mlp_down_w[0]";
+    }
+
+    return nullptr;
+}
+
+} // namespace llaisys::models
diff --git a/src/models/qwen2/model.hpp b/src/models/qwen2/model.hpp
new file mode 100644
index 00000000..d512fda2
--- /dev/null
+++ b/src/models/qwen2/model.hpp
@@ -0,0 +1,80 @@
+#pragma once
+
+#include "llaisys/models/qwen2.h"
+#include "llaisys/tensor.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+namespace llaisys::models {
+
+// Forward-declare internal tensor handle type (defined in src/tensor/tensor.hpp).
+// We keep it opaque here to avoid pulling in the full tensor header.
+} // namespace llaisys::models
+
+namespace llaisys {
+class Tensor;
+using tensor_t = std::shared_ptr<Tensor>;
+} // namespace llaisys
+
+namespace llaisys::models {
+
+// Skeleton for Assignment #3.
+// Keep the model logic here (outside src/llaisys/), and expose it via C API in src/llaisys/models/.
+class Qwen2Model { +public: + static std::unique_ptr create(const LlaisysQwen2Meta &meta, + llaisysDeviceType_t device, + const int *device_ids, + int ndevice); + + ~Qwen2Model(); + + // The returned pointer is owned by the model and valid until destroy(). + LlaisysQwen2Weights *weights(); + + // Infer one next token given existing token_ids[0..ntoken). + // NOTE: This is intentionally a stub for the assignment skeleton. + int64_t infer(const int64_t *token_ids, size_t ntoken); + + int64_t endToken() const { return meta_.end_token; } + +private: + explicit Qwen2Model(const LlaisysQwen2Meta &meta, + llaisysDeviceType_t device, + std::vector device_ids); + + void initWeightsView_(); + const char *missingWeightsHint_() const; + +private: + LlaisysQwen2Meta meta_{}; + llaisysDeviceType_t device_{LLAISYS_DEVICE_CPU}; + std::vector device_ids_; + + // C-facing view of weights. Pointers inside refer to the vectors below. + LlaisysQwen2Weights weights_{}; + + // Per-layer weight handles (filled by Python via weights()). + std::vector attn_norm_w_; + std::vector attn_q_w_; + std::vector attn_q_b_; + std::vector attn_k_w_; + std::vector attn_k_b_; + std::vector attn_v_w_; + std::vector attn_v_b_; + std::vector attn_o_w_; + + std::vector mlp_norm_w_; + std::vector mlp_gate_w_; + std::vector mlp_up_w_; + std::vector mlp_down_w_; + + // TODO(assignment-3): add KV-cache storage here. 
+ std::vector<::llaisys::tensor_t> k_cache_; + std::vector<::llaisys::tensor_t> v_cache_; + size_t cur_len_{0}; +}; + +} // namespace llaisys::models diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..ddeb8970 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,43 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void argmax_(size_t *max_idx, T *max_vals, const T *vals, size_t numel) { + size_t max_idx_ = 0; + T max_val_ = vals[0]; + for (size_t i = 1; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + if (llaisys::utils::cast(vals[i]) > llaisys::utils::cast(max_val_)) { + max_idx_ = i; + max_val_ = llaisys::utils::cast(vals[i]); + } + } else { + if (vals[i] > max_val_) { + max_idx_ = i; + max_val_ = vals[i]; + } + } + } + *max_idx = max_idx_; + *max_vals = max_val_; +} + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_vals, const std::byte *vals, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_vals), reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_BF16: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_vals), + reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_F16: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_vals), + reinterpret_cast(vals), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..0c362ee4 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size); 
+} \ No newline at end of file diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..091e8229 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,31 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/argmax_cpu.hpp" namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + // Only support contiguous inputs with same shape for now. + ASSERT(vals->isContiguous(), "Argmax: tensor vals must be contiguous."); + + // always support cpu calculation + if (max_idx->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..ee018ce6 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,46 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +template +void embedding_(T *out, const size_t *index, const T *weight, + const size_t idx_numel, const size_t dim0, const size_t dim1) { + // 每行拷贝的字节数 + const size_t row_bytes = dim1 * sizeof(T); + + for (size_t i = 0; i < idx_numel; ++i) { + size_t idx = index[i]; + + // 边界检查:确保索引不越界 + ASSERT(idx < dim0, "Embedding: index out of vocabulary range"); + + // 地址计算 + T *dst_ptr = out + (i * dim1); + const T *src_ptr = weight + (idx * dim1); + + // 执行拷贝 + memcpy(dst_ptr, 
src_ptr, row_bytes); + } +} + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, const size_t idx_numel, const size_t dim0, const size_t dim1) { + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), + reinterpret_cast(weight), idx_numel, dim0, dim1); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), + reinterpret_cast(weight), idx_numel, dim0, dim1); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), + reinterpret_cast(weight), idx_numel, dim0, dim1); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..f0ffa1b4 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, const size_t idx_numel, const size_t dim0, const size_t dim1); +} \ No newline at end of file diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..5ea793bc 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,34 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" + namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + // Only support contiguous inputs with same shape for now. 
+ CHECK_SAME_DTYPE(out->dtype(), weight->dtype()); + ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(), "Embedding: all tensors must be contiguous."); + + // always support cpu calculation + if (weight->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(),weight->dtype(), index->numel(), weight->shape()[0],weight->shape()[1]); + } + + llaisys::core::context().setDevice(weight->deviceType(), weight->deviceId()); + + switch (weight->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(),weight->dtype(), index->numel(), weight->shape()[0],weight->shape()[1]); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..bc16a3a0 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,53 @@ +#include "linear_cpu.hpp" + +#include "../../../utils.hpp" +#include + +template +void linear_(T *out, const T *in, const T *weight, const size_t B, const size_t K, const size_t M, const T *bias = nullptr) { + for (size_t i = 0; i < B; ++i) { + for (size_t j = 0; j < M; ++j) { + if constexpr (std::is_same_v || std::is_same_v) { + float acc = bias ? llaisys::utils::cast(bias[j]) : llaisys::utils::cast(0); + for (size_t k = 0; k < K; ++k) { + acc += llaisys::utils::cast(in[k + i * K]) * llaisys::utils::cast(weight[k + j * K]); + } + out[i * M + j] = llaisys::utils::cast(acc); + } else { + T acc = bias ? 
bias[j] : 0; + for (size_t k = 0; k < K; ++k) { + acc += in[k + i * K] * weight[k + j * K]; + } + out[i * M + j] = acc; + } + } + } +} + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, const size_t B, const size_t K, const size_t M) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + B, K, M, + reinterpret_cast(bias)); + case LLAISYS_DTYPE_BF16: + return linear_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + B, K, M, + reinterpret_cast(bias)); + case LLAISYS_DTYPE_F16: + return linear_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + B, K, M, + reinterpret_cast(bias)); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..06457dfa --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, + const size_t B, const size_t K, const size_t M); +} \ No newline at end of file diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f865..40b87f93 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,47 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/linear_cpu.hpp" + namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype()); + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "linear: 
all tensors must be contiguous."); + + ASSERT(out->ndim() == 2 && in->ndim() == 2 && weight->ndim() == 2, "linear: out/in/weight must be 2D"); + const size_t B = in->shape()[0]; + const size_t K = in->shape()[1]; + const size_t M = weight->shape()[0]; + ASSERT(weight->shape()[1] == K, "linear: weight shape mismatch (expect [M, K])"); + ASSERT(out->shape()[0] == B && out->shape()[1] == M, "linear: out shape mismatch (expect [B, M])"); + + auto bias_ptr = bias ? bias->data() : nullptr; + if (bias_ptr) { + CHECK_SAME_DEVICE(out, bias); + CHECK_SAME_DTYPE(out->dtype(), bias->dtype()); + ASSERT(bias->isContiguous(), "linear: all tensors must be contiguous."); + ASSERT(bias->ndim() == 1 && bias->shape()[0] == M, "linear: bias shape mismatch (expect [M])"); + } + // always support cpu calculation + if (weight->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::linear(out->data(), in->data(), weight->data(), bias_ptr, weight->dtype(), B, K, M); + } + + llaisys::core::context().setDevice(weight->deviceType(), weight->deviceId()); + + switch (weight->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear(out->data(), in->data(), weight->data(), bias_ptr, weight->dtype(), B, K, M); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp new file mode 100644 index 00000000..bea2132d --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp @@ -0,0 +1,61 @@ +#include "rearrange_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +namespace { + +static inline size_t numel_(const size_t *shape, size_t ndim) { + size_t n = 1; + for (size_t i = 0; i < ndim; i++) n *= shape[i]; + return n; +} + +// Compute the offset in elements for a given linear index under `shape` and `strides`. 
+static inline size_t offset_from_linear_(size_t linear, + const size_t *shape, + const ptrdiff_t *strides, + size_t ndim) { + // Unravel `linear` into indices, then dot with strides. + // + // strides are in elements (not bytes) and may be non-contiguous. We assume + // non-negative strides for now (the framework doesn't support negative strides yet). + ptrdiff_t off = 0; + for (size_t d = 0; d < ndim; d++) { + const size_t dim = ndim - 1 - d; // last dim first + const size_t size_d = shape[dim]; + const size_t idx_d = linear % size_d; + linear /= size_d; + off += static_cast(idx_d) * strides[dim]; + } + CHECK_ARGUMENT(off >= 0, "rearrange_cpu: negative offset (negative strides not supported)"); + return static_cast(off); +} + +} // namespace + +namespace llaisys::ops::cpu { + +void rearrange(std::byte *out, + const std::byte *in, + llaisysDataType_t dtype, + const size_t *shape, + const ptrdiff_t *out_strides, + const ptrdiff_t *in_strides, + size_t ndim) { + CHECK_ARGUMENT(out != nullptr && in != nullptr, "rearrange_cpu: null data ptr"); + CHECK_ARGUMENT(shape != nullptr && out_strides != nullptr && in_strides != nullptr, "rearrange_cpu: null meta ptr"); + CHECK_ARGUMENT(ndim > 0, "rearrange_cpu: ndim must be > 0"); + + const size_t esize = llaisys::utils::dsize(dtype); + const size_t n = numel_(shape, ndim); + + for (size_t linear = 0; linear < n; linear++) { + const size_t in_off = offset_from_linear_(linear, shape, in_strides, ndim); + const size_t out_off = offset_from_linear_(linear, shape, out_strides, ndim); + std::memcpy(out + out_off * esize, in + in_off * esize, esize); + } +} + +} // namespace llaisys::ops::cpu diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp new file mode 100644 index 00000000..ec981b19 --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { + +// Copy elements from `in` to `out` with 
arbitrary strides. +// +// - `shape` is in elements, length = `ndim` +// - `*_strides` are in elements (not bytes), length = `ndim` +// - `out` and `in` must have the same dtype and shape +void rearrange(std::byte *out, + const std::byte *in, + llaisysDataType_t dtype, + const size_t *shape, + const ptrdiff_t *out_strides, + const ptrdiff_t *in_strides, + size_t ndim); + +} // namespace llaisys::ops::cpu + diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae5..e05d479b 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rearrange_cpu.hpp" + namespace llaisys::ops { void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + CHECK_SAME_SHAPE(out->shape(), in->shape()); + + // NOTE: unlike other ops, rearrange is intended to work with non-contiguous tensors. + // It copies elements from `in` to `out` respecting each tensor's strides. 
+ + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rearrange(out->data(), in->data(), out->dtype(), + out->shape().data(), + out->strides().data(), + in->strides().data(), + out->ndim()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rearrange(out->data(), in->data(), out->dtype(), + out->shape().data(), + out->strides().data(), + in->strides().data(), + out->ndim()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 00000000..3a42931c --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,57 @@ +#include "rms_norm_cpu.hpp" + +#include "../../../utils.hpp" +#include +#include +#include +#include + +template +void rms_norm_(T *out, const T *in, const T *weight, const float eps, + const size_t B, const size_t K) { + for (size_t i = 0; i < B; ++i) { + float sum_sq = 0.0f; + + // x 按行求平方和 + for (size_t j = 0; j < K; ++j) { + if constexpr (std::is_same_v || std::is_same_v) { + float val = llaisys::utils::cast(in[i * K + j]); + sum_sq += val * val; + } else { + float val = in[i * K + j]; + sum_sq += val * val; + } + } + + // 按行求平方根 + float scale = 1.0f / std::sqrt(sum_sq / K + eps); + + // x 和 w 相同位置元素相乘 除以 平方根 + for (size_t j = 0; j < K; ++j) { + if constexpr (std::is_same_v || std::is_same_v) { + out[i * K + j] = llaisys::utils::cast(llaisys::utils::cast(in[i * K + j]) * llaisys::utils::cast(weight[j]) * scale); + } else { + out[i * K + j] = in[i * K + j] * weight[j] * scale; + } + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, + llaisysDataType_t type, const size_t B, const size_t K) { + switch (type) { + case 
LLAISYS_DTYPE_F32: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), + eps, B, K); + case LLAISYS_DTYPE_BF16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), + eps, B, K); + case LLAISYS_DTYPE_F16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), + eps, B, K); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 00000000..93da9f70 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, + llaisysDataType_t type, const size_t B, const size_t K); +} \ No newline at end of file diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9..07e39f50 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,33 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rms_norm_cpu.hpp" namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + // Only support contiguous inputs with same shape for now. 
+ CHECK_SAME_DTYPE(in->dtype(), out->dtype()); + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "res_norm: all tensors must be contiguous."); + + // always support cpu calculation + if (weight->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, weight->dtype(), in->shape()[0], in->shape()[1]); + } + + llaisys::core::context().setDevice(weight->deviceType(), weight->deviceId()); + + switch (weight->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, weight->dtype(), in->shape()[0], in->shape()[1]); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 00000000..ef07b517 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,73 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include +#include + +template +void rope_(T *out, const T *in, const int64_t *pos_ids, + const size_t seqlen, const size_t nhead, const size_t d, float theta) { + size_t half = d / 2; + std::vector denom(half); + for (size_t j = 0; j < half; ++j) { + float exp = (2.0f * j) / d; // 2j/d + denom[j] = std::pow(theta, exp); // θ^(2j/d) + } + + std::vector sin_cache(seqlen * half); + std::vector cos_cache(seqlen * half); + for (size_t i = 0; i < seqlen; ++i) { + float pos = static_cast(pos_ids[i]); + for (size_t j = 0; j < half; ++j) { + float phi = pos / denom[j]; + sin_cache[i * half + j] = sinf(phi); + cos_cache[i * half + j] = cosf(phi); + } + } + + for (size_t i = 0; i < seqlen; ++i) { + for (size_t k = 0; k < nhead; ++k) { + size_t base = i * (nhead * d) + k * d; + for (size_t j = 0; j < half; ++j) { + float cos_v = cos_cache[i * half + j]; + float sin_v = sin_cache[i * half + j]; + 
if constexpr ((std::is_same_v || std::is_same_v)) { + out[base + j] + = llaisys::utils::cast(llaisys::utils::cast(in[base + j]) * cos_v + - llaisys::utils::cast(in[base + j + half]) * sin_v); + out[base + j + half] + = llaisys::utils::cast(llaisys::utils::cast(in[base + j + half]) * cos_v + + llaisys::utils::cast(in[base + j]) * sin_v); + } else { + out[base + j] = in[base + j] * cos_v + - in[base + j + half] * sin_v; + out[base + j + half] = in[base + j + half] * cos_v + + in[base + j] * sin_v; + } + } + } + } +} + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, const size_t seqlen, const size_t nhead, const size_t d, float theta) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(pos_ids), seqlen, nhead, d, theta); + case LLAISYS_DTYPE_BF16: + return rope_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(pos_ids), seqlen, nhead, d, theta); + case LLAISYS_DTYPE_F16: + return rope_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(pos_ids), seqlen, nhead, d, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..7484931f --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, const size_t seq_len, const size_t nhead, const size_t d, float theta); +} \ No newline at end of file diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..f3e67629 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,42 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include 
"cpu/rope_cpu.hpp" + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, pos_ids); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + // Only support contiguous inputs for now. + ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), "RoPE: all tensors must be contiguous."); + ASSERT(out->ndim() == 3, "RoPE: out dimension must be 3"); + ASSERT(pos_ids->ndim() == 1, "RoPE: pos_ids dimension must be 1"); + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64"); + size_t head_dim = in->shape()[2]; + ASSERT(head_dim % 2 == 0, "RoPE: head dimension must be even"); + ASSERT(pos_ids->numel() == in->shape()[0], "RoPE: pos_ids length must match seq_len"); + + // always support cpu calculation + if (in->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), in->dtype(), + in->shape()[0], in->shape()[1], in->shape()[2], theta); + } + + llaisys::core::context().setDevice(in->deviceType(), in->deviceId()); + + switch (in->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), in->dtype(), + in->shape()[0], in->shape()[1], in->shape()[2], theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..0f7ab40e --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,118 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +template +static void self_attention_(T *out, + const T *q, + const T *k, + const T *v, + size_t qlen, + size_t nh, + size_t kvlen, + size_t nkvh, + size_t hd, + float scale) { + // 
GQA mapping: repeat each kv head nh/nkvh times (same as repeat_interleave in the test). + const size_t repeat = nh / nkvh; + const ptrdiff_t offset = static_cast(kvlen) - static_cast(qlen); + + std::vector logits(kvlen); + std::vector probs(kvlen); + + for (size_t h = 0; h < nh; h++) { + const size_t hk = h / repeat; + + for (size_t qi = 0; qi < qlen; qi++) { + const ptrdiff_t max_j = std::min( + static_cast(kvlen) - 1, + static_cast(qi) + offset); + + float max_logit = -std::numeric_limits::infinity(); + + // 1) logits = (q_i · k_j) * scale, with causal mask. + for (size_t kj = 0; kj < kvlen; kj++) { + if (static_cast(kj) > max_j) { + logits[kj] = -std::numeric_limits::infinity(); + continue; + } + + const T *q_row = q + (qi * nh + h) * hd; + const T *k_row = k + (kj * nkvh + hk) * hd; + float dot = 0.0f; + for (size_t d = 0; d < hd; d++) { + dot += llaisys::utils::cast(q_row[d]) * llaisys::utils::cast(k_row[d]); + } + const float logit = dot * scale; + logits[kj] = logit; + max_logit = std::max(max_logit, logit); + } + + // 2) probs = softmax(logits) + float denom = 0.0f; + for (size_t kj = 0; kj < kvlen; kj++) { + float p = 0.0f; + if (!std::isinf(logits[kj])) { + p = std::exp(logits[kj] - max_logit); + } + probs[kj] = p; + denom += p; + } + const float inv_denom = 1.0f / denom; + + // 3) out = probs @ v + T *out_row = out + (qi * nh + h) * hd; + for (size_t d = 0; d < hd; d++) { + float acc = 0.0f; + for (size_t kj = 0; kj < kvlen; kj++) { + if (probs[kj] == 0.0f) continue; + const T *v_row = v + (kj * nkvh + hk) * hd; + acc += (probs[kj] * inv_denom) * llaisys::utils::cast(v_row[d]); + } + out_row[d] = llaisys::utils::cast(acc); + } + } + } +} + +namespace llaisys::ops::cpu { +void self_attention(std::byte *out, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t dtype, + size_t qlen, + size_t nh, + size_t kvlen, + size_t nkvh, + size_t hd, + float scale) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return 
self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + qlen, nh, kvlen, nkvh, hd, scale); + case LLAISYS_DTYPE_BF16: + return self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + qlen, nh, kvlen, nkvh, hd, scale); + case LLAISYS_DTYPE_F16: + return self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + qlen, nh, kvlen, nkvh, hd, scale); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..ad26cd0b --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +// Implements causal self-attention: +// out = softmax((q @ k^T) * scale + causal_mask) @ v +// +// Layout expectations (contiguous): +// q : [qlen, nh, hd] +// k/v : [kvlen, nkvh, hd] (GQA/MQA supported when nh % nkvh == 0) +// out : [qlen, nh, hd] +// +// causal_mask matches test/ops/self_attention.py: +// allow key index j <= i + (kvlen - qlen) +void self_attention(std::byte *out, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t dtype, + size_t qlen, + size_t nh, + size_t kvlen, + size_t nkvh, + size_t hd, + float scale); +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..05f66cca 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,53 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/self_attention_cpu.hpp" + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + 
CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "SelfAttention: all tensors must be contiguous."); + + ASSERT(q->ndim() == 3 && k->ndim() == 3 && v->ndim() == 3 && attn_val->ndim() == 3, + "SelfAttention: q/k/v/attn_val must be 3D [len, heads, head_dim]"); + + const size_t qlen = q->shape()[0]; + const size_t nh = q->shape()[1]; + const size_t hd = q->shape()[2]; + + const size_t kvlen = k->shape()[0]; + const size_t nkvh = k->shape()[1]; + + CHECK_SAME_SHAPE(v->shape(), k->shape()); + ASSERT(attn_val->shape()[0] == qlen && attn_val->shape()[1] == nh && attn_val->shape()[2] == hd, + "SelfAttention: attn_val shape must match q shape [qlen, nh, hd]"); + ASSERT(k->shape()[2] == hd, "SelfAttention: head_dim mismatch between q and k"); + ASSERT(nh % nkvh == 0, "SelfAttention: require nh % nkvh == 0 (GQA/MQA head mapping)"); + ASSERT(kvlen >= qlen, "SelfAttention: currently require kvlen >= qlen for causal masking"); + + // always support cpu calculation + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), qlen, nh, kvlen, nkvh, hd, scale); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), qlen, nh, kvlen, nkvh, hd, scale); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..9e00a5c4 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,49 @@ +#include 
"swiglu_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +static void swiglu_(T *out, const T *gate, const T *up, size_t numel) { + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + float u = llaisys::utils::cast(up[i]); + float g = llaisys::utils::cast(gate[i]); + float s = g / (1.0f + std::exp(-g)); + out[i] = llaisys::utils::cast(u * s); + } else { + float g = static_cast(gate[i]); + float s = g / (1.0f + std::exp(-g)); + out[i] = static_cast(static_cast(up[i]) * s); + } + } +} + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + llaisysDataType_t dtype, + size_t numel) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_BF16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_F16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..9a89273d --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +// out[i] = up[i] * (gate[i] / (1 + exp(-gate[i]))) +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + llaisysDataType_t dtype, + size_t numel); +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..78160ae9 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" + namespace llaisys::ops { void 
swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_SHAPE(out->shape(), gate->shape(), up->shape()); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), + "SwiGLU: all tensors must be contiguous."); + + // always support cpu calculation + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..f4b4139a 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -2,6 +2,7 @@ #include "../utils.hpp" +#include #include #include #include @@ -17,10 +18,10 @@ tensor_t Tensor::create(const std::vector &shape, int device) { size_t ndim_ = shape.size(); std::vector strides(ndim_); - size_t stride = 1; + ptrdiff_t stride = 1; for (size_t i = 1; i <= ndim_; i++) { strides[ndim_ - i] = stride; - stride *= shape[ndim_ - i]; + stride *= static_cast(shape[ndim_ - i]); } TensorMeta meta{dtype, shape, strides}; size_t total_elems = stride; @@ -164,27 +165,100 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + ptrdiff_t stride = 1; + auto shape = this->shape(); + auto strides = this->strides(); + size_t ndim = this->ndim(); + for (size_t i = 1; i <= ndim; i++) { + if (strides[ndim - i] != stride) { + return false; + } + stride *= static_cast(shape[ndim - i]); + } return true; } tensor_t Tensor::permute(const std::vector &order) 
tensor_t Tensor::view(const std::vector<size_t> &shape) const {
    // A view reinterprets the same storage with new strides, which is only
    // well-defined when the current layout is densely packed.
    if (!this->isContiguous()) {
        throw std::runtime_error("Tensor must be contiguous to view()");
    }

    // The new shape must describe exactly as many elements as the old one.
    size_t requested = 1;
    for (auto dim : shape) {
        requested *= dim;
    }
    if (requested != this->numel()) {
        throw std::runtime_error("Shape mismatch in view()");
    }

    // Rebuild row-major (C-contiguous) strides for the new shape, walking
    // from the innermost dimension outwards.
    const size_t rank = shape.size();
    std::vector<ptrdiff_t> new_strides(rank);
    ptrdiff_t step = 1;
    for (size_t i = rank; i > 0; --i) {
        new_strides[i - 1] = step;
        step *= static_cast<ptrdiff_t>(shape[i - 1]);
    }

    // Share the storage and byte offset; only the metadata changes.
    TensorMeta meta{this->dtype(), shape, new_strides};
    return std::shared_ptr<Tensor>(new Tensor(meta, _storage, _offset));
}
+ const size_t byte_offset = this->_offset + static_cast(start) * static_cast(this->strides()[dim]) * this->elementSize(); + + return std::shared_ptr(new Tensor(meta, _storage, byte_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + // 判断指针不为空 + if (!src_) { + throw std::runtime_error("Tensor::load: source pointer is null"); + } + // 获取设备类型 + core::context().setDevice(this->deviceType(), this->deviceId()); + core::context().runtime().api()->device_synchronize(); + size_t byte_size = this->numel() * this->elementSize(); + // 根据设备类型决定如何复制 + if (this->deviceType() == LLAISYS_DEVICE_CPU) { + std::memcpy(this->data(), src_, byte_size); + } else { + core::context().runtime().api()->memcpy_sync( + this->data(), + src_, + byte_size, + LLAISYS_MEMCPY_H2D); + } } tensor_t Tensor::contiguous() const { diff --git a/test/ops/linear_nobias.py b/test/ops/linear_nobias.py new file mode 100644 index 00000000..3d56e06d --- /dev/null +++ b/test/ops/linear_nobias.py @@ -0,0 +1,77 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) + +import torch +import llaisys + +from llaisys.libllaisys import LIB_LLAISYS +from test_utils import random_tensor, check_equal, benchmark + + +def torch_linear_nobias(out, x, w): + # out = x @ w.T + torch.nn.functional.linear(x, w, bias=None, out=out) + + +def llaisys_linear_nobias(out_, x_, w_): + # Call the C API directly so we can pass NULL bias. 
def test_op_linear_nobias(
    out_shape,
    x_shape,
    w_shape,
    dtype_name="f32",
    atol=1e-5,
    rtol=1e-5,
    device_name="cpu",
    profile=False,
):
    """Compare llaisys linear with a NULL bias against torch's reference.

    Builds random inputs, runs both implementations into pre-allocated
    outputs, and asserts element-wise closeness within (atol, rtol).
    """
    print(f" out {out_shape}, x {x_shape}, w {w_shape}, bias False, dtype <{dtype_name}>")

    x_torch, x_llaisys = random_tensor(x_shape, dtype_name, device_name, scale=0.1)
    w_torch, w_llaisys = random_tensor(w_shape, dtype_name, device_name, scale=0.01)
    out_torch, out_llaisys = random_tensor(out_shape, dtype_name, device_name)

    torch_linear_nobias(out_torch, x_torch, w_torch)
    llaisys_linear_nobias(out_llaisys, x_llaisys, w_llaisys)

    assert check_equal(out_llaisys, out_torch, atol=atol, rtol=rtol)

    if profile:
        benchmark(
            lambda: torch_linear_nobias(out_torch, x_torch, w_torch),
            lambda: llaisys_linear_nobias(out_llaisys, x_llaisys, w_llaisys),
            device_name,
        )
def test_op_rearrange(
    shape,
    perm,
    dtype_name="f32",
    atol=1e-5,
    rtol=1e-5,
    device_name="cpu",
    profile=False,
):
    """Check Ops.rearrange against torch permute+contiguous.

    A non-contiguous permuted view of a random tensor is rearranged into a
    freshly allocated contiguous output on both backends, then compared.
    """
    print(f" shape {shape} perm {perm} dtype <{dtype_name}>")

    src_torch, src_llaisys = random_tensor(shape, dtype_name, device_name, scale=0.1)

    permuted_shape = tuple(shape[axis] for axis in perm)
    dst_torch = torch.empty(permuted_shape, dtype=src_torch.dtype, device=src_torch.device)
    dst_llaisys = llaisys.Tensor(
        permuted_shape,
        dtype=llaisys_dtype(dtype_name),
        device=llaisys_device(device_name),
    )

    torch_rearrange(dst_torch, src_torch, perm)
    llaisys_rearrange(dst_llaisys, src_llaisys, perm)

    assert check_equal(dst_llaisys, dst_torch, atol=atol, rtol=rtol)

    if profile:
        benchmark(
            lambda: torch_rearrange(dst_torch, src_torch, perm),
            lambda: llaisys_rearrange(dst_llaisys, src_llaisys, perm),
            device_name,
        )
test_op_rearrange(shape, perm, dtype_name, atol, rtol, args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") + diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b87..29a4b19d 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -82,7 +82,7 @@ def llaisys_infer( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) - parser.add_argument("--model", default=None, type=str) + parser.add_argument("--model", default="/root/lxl/models/llaisys/DeepSeek-R1-Distill-Qwen-1.5B/", type=str) parser.add_argument("--prompt", default="Who are you?", type=str) parser.add_argument("--max_steps", default=128, type=int) parser.add_argument("--top_p", default=0.8, type=float) diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..ec4c2600 100644 --- a/xmake.lua +++ b/xmake.lua @@ -95,6 +95,22 @@ target("llaisys-ops") on_install(function (target) end) target_end() +target("llaisys-models") + set_kind("static") + add_deps("llaisys-ops") + add_deps("llaisys-tensor") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/models/*/*.cpp") + + on_install(function (target) end) +target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,10 +118,12 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") + add_deps("llaisys-models") set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") + add_files("src/llaisys/models/*.cc") set_installdir(".") @@ -119,4 +137,4 @@ target("llaisys") os.cp("lib/*.so", "python/llaisys/libllaisys/") end end) -target_end() \ No newline at end of file +target_end()