diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d..2a767c75 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -14,8 +14,8 @@ __C { struct LlaisysQwen2Weights { llaisysTensor_t in_embed; llaisysTensor_t out_embed; - llaisysTensor_t out_norm_w; // a.k.a. model.norm.weight - llaisysTensor_t *attn_norm_w; // a.k.a. input_layernorm.weight + llaisysTensor_t out_norm_w; + llaisysTensor_t *attn_norm_w; llaisysTensor_t *attn_q_w; llaisysTensor_t *attn_q_b; llaisysTensor_t *attn_k_w; @@ -23,7 +23,7 @@ __C { llaisysTensor_t *attn_v_w; llaisysTensor_t *attn_v_b; llaisysTensor_t *attn_o_w; - llaisysTensor_t *mlp_norm_w; // a.k.a. post_attention_layernorm.weight + llaisysTensor_t *mlp_norm_w; llaisysTensor_t *mlp_gate_w; llaisysTensor_t *mlp_up_w; llaisysTensor_t *mlp_down_w; diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b467..af9998f0 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -2,7 +2,6 @@ from enum import IntEnum -# Device Type enum class DeviceType(IntEnum): CPU = 0 NVIDIA = 1 @@ -12,7 +11,6 @@ class DeviceType(IntEnum): llaisysDeviceType_t = ctypes.c_int -# Data Type enum class DataType(IntEnum): INVALID = 0 BYTE = 1 @@ -39,7 +37,6 @@ class DataType(IntEnum): llaisysDataType_t = ctypes.c_int -# Memory Copy Kind enum class MemcpyKind(IntEnum): H2H = 0 H2D = 1 @@ -48,8 +45,13 @@ class MemcpyKind(IntEnum): llaisysMemcpyKind_t = ctypes.c_int +llaisysTensor_t = ctypes.c_void_p + +class LlaisysQwen2Model(ctypes.Structure): + pass +llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model) +llaisysQwen2Weights_p = ctypes.c_void_p -# Stream type (opaque pointer) llaisysStream_t = ctypes.c_void_p __all__ = [ diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 00000000..b8a44911 --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,42 @@ +import ctypes +from .llaisys_types import llaisysDataType_t, llaisysTensor_t, llaisysDeviceType_t + +class LlaisysQwen2Meta(ctypes.Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", ctypes.c_size_t), + ("hs", ctypes.c_size_t), + ("nh", ctypes.c_size_t), + ("nkvh", ctypes.c_size_t), + ("dh", ctypes.c_size_t), + ("di", ctypes.c_size_t), + ("maxseq", ctypes.c_size_t), + ("voc", ctypes.c_size_t), + ("epsilon", ctypes.c_float), + ("theta", ctypes.c_float), + ("end_token", ctypes.c_int64), + ] + +class LlaisysQwen2Weights(ctypes.Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_o_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_gate_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_up_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_down_w", ctypes.POINTER(llaisysTensor_t)), + ] + +class LlaisysQwen2Model(ctypes.Structure): + pass + +llaisysQwen2ModelHandle = ctypes.POINTER(LlaisysQwen2Model) diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..ae851ffd 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,192 @@ -from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType - +import ctypes +import json +import os +import mmap +import struct +from typing import List, Dict, Optional, Sequence, Any +import numpy as np from pathlib import Path -import safetensors +from ..libllaisys import LIB_LLAISYS, llaisysTensor_t, llaisysDataType_t, llaisysDeviceType_t, DataType, DeviceType +from ..libllaisys.models import LlaisysQwen2Meta, LlaisysQwen2Weights, LlaisysQwen2Model, llaisysQwen2ModelHandle +from ..tensor import Tensor -class Qwen2: +LIB_LLAISYS.llaisysQwen2ModelCreate.argtypes = [ctypes.POINTER(LlaisysQwen2Meta), llaisysDeviceType_t, ctypes.POINTER(ctypes.c_int), ctypes.c_int] +LIB_LLAISYS.llaisysQwen2ModelCreate.restype = llaisysQwen2ModelHandle + +LIB_LLAISYS.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2ModelHandle] +LIB_LLAISYS.llaisysQwen2ModelDestroy.restype = None + +LIB_LLAISYS.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2ModelHandle] +LIB_LLAISYS.llaisysQwen2ModelWeights.restype = ctypes.POINTER(LlaisysQwen2Weights) - def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor +LIB_LLAISYS.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2ModelHandle, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t] +LIB_LLAISYS.llaisysQwen2ModelInfer.restype = ctypes.c_int64 + +class Qwen2: + def __init__(self, model_path: str, device: DeviceType = DeviceType.CPU, device_id: int = 0): model_path = Path(model_path) + config_path = model_path / "config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + self.device = device + self.device_id = device_id + + self.meta = LlaisysQwen2Meta() + + dtype_str = config.get("torch_dtype", "float32") + self.meta.dtype = DataType.F32 + + self.meta.nlayer = config.get("num_hidden_layers", 24) + self.meta.hs = config.get("hidden_size", 2048) + self.meta.nh = config.get("num_attention_heads", 16) + self.meta.nkvh = config.get("num_key_value_heads", 16) + self.meta.dh = self.meta.hs // self.meta.nh + self.meta.di = config.get("intermediate_size", 11008) + self.meta.maxseq = config.get("max_position_embeddings", 8192) + self.meta.voc = config.get("vocab_size", 151936) + self.meta.epsilon = config.get("rms_norm_eps", 1e-6) + self.meta.theta = config.get("rope_theta", 1000000.0) + self.meta.end_token = 151643 # Placeholder + + dev_ids = (ctypes.c_int * 1)(device_id) + self.handle = LIB_LLAISYS.llaisysQwen2ModelCreate(ctypes.byref(self.meta), device, dev_ids, 1) + self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.handle) + self.tensors_ref = [] for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + print(f"Loading weights from {file}...") + weights_data = self._load_safetensors_bf16_as_f32(file) + for key, arr in weights_data.items(): + if not arr.flags['C_CONTIGUOUS']: + arr = np.ascontiguousarray(arr) + + t = Tensor(list(arr.shape), self.meta.dtype, device, device_id) + t.load(ctypes.c_void_p(arr.ctypes.data)) + + self._assign_weight(key, t) + + def _load_safetensors_bf16_as_f32(self, path: Path) -> Dict[str, np.ndarray]: + tensors = {} + with open(path, 'rb') as f: + length_bytes = f.read(8) + if not length_bytes: return {} + header_size = struct.unpack(' List[int]: + + generated = [] + tokens = list(inputs) + + arr = (ctypes.c_int64 * len(tokens))(*tokens) + next_token = LIB_LLAISYS.llaisysQwen2ModelInfer(self.handle, arr, len(tokens)) + generated.append(next_token) + tokens = [next_token] + + for _ in range(max_new_tokens - 1): + arr = (ctypes.c_int64 * 1)(*tokens) + next_token = LIB_LLAISYS.llaisysQwen2ModelInfer(self.handle, arr, 1) + generated.append(next_token) + tokens = [next_token] + + if next_token == self.meta.end_token: + break + + return list(inputs) + generated diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc new file mode 100644 index 00000000..209ea8b1 --- /dev/null +++ b/src/llaisys/models/qwen2.cc @@ -0,0 +1,31 @@ +#include "llaisys/models/qwen2.h" +#include "../../models/qwen2/model.hpp" + +using namespace llaisys::models::qwen2; + +extern "C" { + +struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + int dev_id = (ndevice > 0 && device_ids != nullptr) ? device_ids[0] : 0; + Qwen2Model* model = new Qwen2Model(*meta, device, dev_id); + return reinterpret_cast(model); +} + +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) { + if (model) { + delete reinterpret_cast(model); + } +} + +struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) { + if (!model) return nullptr; + return reinterpret_cast(model)->getWeightsStruct(); +} + +int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) { + if (!model) return -1; + std::vector tokens(token_ids, token_ids + ntoken); + return reinterpret_cast(model)->infer(tokens); +} + +} diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 00000000..5bf03a66 --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,36 @@ +#include "llaisys/models/qwen2.h" +#include "../models/qwen2/qwen2.hpp" + +extern "C" { + +struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + if (!meta || ndevice < 1) return nullptr; + // For now support single device + int device_id = device_ids ? device_ids[0] : 0; + + // Copy meta + LlaisysQwen2Meta cpp_meta = *meta; + + auto* model = new llaisys::Qwen2Model(cpp_meta, device, device_id); + return reinterpret_cast(model); +} + +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) { + if (model) { + delete reinterpret_cast(model); + } +} + +struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) { + if (!model) return nullptr; + auto* cpp_model = reinterpret_cast(model); + return cpp_model->getWeights(); +} + +int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) { + if (!model) return -1; + auto* cpp_model = reinterpret_cast(model); + return cpp_model->infer(token_ids, ntoken); +} + +} diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp new file mode 100644 index 00000000..910b2fda --- /dev/null +++ b/src/models/qwen2/qwen2.cpp @@ -0,0 +1,189 @@ +#include "qwen2.hpp" +#include "llaisys/ops.h" +#include "../../ops/add/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rearrange/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils/check.hpp" +#include +#include +#include +#include "../../llaisys/llaisys_tensor.hpp" + +namespace llaisys { + +inline tensor_t to_cpp(llaisysTensor_t t) { + if (!t) return nullptr; + return reinterpret_cast(t)->tensor; +} + +inline tensor_t to_cpp(llaisysTensor_t* t_array, size_t idx) { + if (!t_array || !t_array[idx]) return nullptr; + return reinterpret_cast(t_array[idx])->tensor; +} + +Qwen2Model::Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id) + : _meta(meta), _device_type(device), _device_id(device_id) { + + _weights.attn_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_q_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_q_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_k_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_k_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_v_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_v_b = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.attn_o_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_norm_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_gate_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_up_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + _weights.mlp_down_w = (llaisysTensor_t*)calloc(meta.nlayer, sizeof(llaisysTensor_t)); + + init_buffers(); +} + +Qwen2Model::~Qwen2Model() { + free(_weights.attn_norm_w); + free(_weights.attn_q_w); + free(_weights.attn_q_b); + free(_weights.attn_k_w); + free(_weights.attn_k_b); + free(_weights.attn_v_w); + free(_weights.attn_v_b); + free(_weights.attn_o_w); + free(_weights.mlp_norm_w); + free(_weights.mlp_gate_w); + free(_weights.mlp_up_w); + free(_weights.mlp_down_w); +} + +void Qwen2Model::init_buffers() { + core::context().setDevice(_device_type, _device_id); + + for(size_t i=0; i<_meta.nlayer; ++i) { + _kv_caches.push_back({ + Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id), + Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device_type, _device_id) + }); + } + + _hidden_states = Tensor::create({1, 1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _residual = Tensor::create({1, 1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _ln_out = Tensor::create({1, 1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _attn_out = Tensor::create({1, 1, _meta.hs}, _meta.dtype, _device_type, _device_id); + _mlp_out = Tensor::create({1, 1, _meta.hs}, _meta.dtype, _device_type, _device_id); + + _logits = Tensor::create({1, _meta.voc}, _meta.dtype, _device_type, _device_id); + + _pos_ids = Tensor::create({_meta.maxseq}, LLAISYS_DTYPE_I64, _device_type, _device_id); +} + +int64_t Qwen2Model::infer(const int64_t* token_ids, size_t ntoken) { + core::context().setDevice(_device_type, _device_id); + + tensor_t input_tokens = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, _device_type, _device_id); + input_tokens->load(token_ids); + + tensor_t current_pos_ids = _pos_ids->slice(0, 0, ntoken); + std::vector pos_data(ntoken); + for(size_t i=0; iload(pos_data.data()); + + std::vector seq_shape = {ntoken, _meta.hs}; + + tensor_t hidden_states = Tensor::create({ntoken, _meta.hs}, _meta.dtype, _device_type, _device_id); + ops::embedding(hidden_states, input_tokens, to_cpp(_weights.in_embed)); + + for(size_t i=0; i<_meta.nlayer; ++i) { + tensor_t normed = Tensor::create(seq_shape, _meta.dtype, _device_type, _device_id); + ops::rms_norm(normed, hidden_states, to_cpp(_weights.attn_norm_w, i), _meta.epsilon); + + size_t q_dim = _meta.nh * _meta.dh; + size_t k_dim = _meta.nkvh * _meta.dh; + + tensor_t q = Tensor::create({ntoken, q_dim}, _meta.dtype, _device_type, _device_id); + tensor_t k = Tensor::create({ntoken, k_dim}, _meta.dtype, _device_type, _device_id); + tensor_t v = Tensor::create({ntoken, k_dim}, _meta.dtype, _device_type, _device_id); + + ops::linear(q, normed, to_cpp(_weights.attn_q_w, i), to_cpp(_weights.attn_q_b, i)); + ops::linear(k, normed, to_cpp(_weights.attn_k_w, i), to_cpp(_weights.attn_k_b, i)); + ops::linear(v, normed, to_cpp(_weights.attn_v_w, i), to_cpp(_weights.attn_v_b, i)); + + q = q->view({ntoken, _meta.nh, _meta.dh}); + k = k->view({ntoken, _meta.nkvh, _meta.dh}); + v = v->view({ntoken, _meta.nkvh, _meta.dh}); + + ops::rope(q, q, current_pos_ids, _meta.theta); + ops::rope(k, k, current_pos_ids, _meta.theta); + + tensor_t k_cache_slot = _kv_caches[i].k->slice(0, _cur_pos, _cur_pos + ntoken); + tensor_t v_cache_slot = _kv_caches[i].v->slice(0, _cur_pos, _cur_pos + ntoken); + + ops::rearrange(k_cache_slot, k); + ops::rearrange(v_cache_slot, v); + + tensor_t k_full = _kv_caches[i].k->slice(0, 0, _cur_pos + ntoken); + tensor_t v_full = _kv_caches[i].v->slice(0, 0, _cur_pos + ntoken); + + tensor_t attn_val = Tensor::create({ntoken, _meta.nh, _meta.dh}, _meta.dtype, _device_type, _device_id); + float scale = 1.0f / std::sqrt((float)_meta.dh); + + ops::self_attention(attn_val, q, k_full, v_full, scale); + + attn_val = attn_val->view({ntoken, _meta.hs}); + + tensor_t attn_output = Tensor::create(seq_shape, _meta.dtype, _device_type, _device_id); + ops::linear(attn_output, attn_val, to_cpp(_weights.attn_o_w, i), nullptr); + + ops::add(hidden_states, hidden_states, attn_output); + + normed = Tensor::create(seq_shape, _meta.dtype, _device_type, _device_id); + ops::rms_norm(normed, hidden_states, to_cpp(_weights.mlp_norm_w, i), _meta.epsilon); + + tensor_t gate = Tensor::create({ntoken, _meta.di}, _meta.dtype, _device_type, _device_id); + tensor_t up = Tensor::create({ntoken, _meta.di}, _meta.dtype, _device_type, _device_id); + + ops::linear(gate, normed, to_cpp(_weights.mlp_gate_w, i), nullptr); + ops::linear(up, normed, to_cpp(_weights.mlp_up_w, i), nullptr); + + tensor_t swiglu_out = Tensor::create({ntoken, _meta.di}, _meta.dtype, _device_type, _device_id); + ops::swiglu(swiglu_out, gate, up); + + tensor_t mlp_output = Tensor::create(seq_shape, _meta.dtype, _device_type, _device_id); + ops::linear(mlp_output, swiglu_out, to_cpp(_weights.mlp_down_w, i), nullptr); + + ops::add(hidden_states, hidden_states, mlp_output); + } + + tensor_t final_normed = Tensor::create(seq_shape, _meta.dtype, _device_type, _device_id); + ops::rms_norm(final_normed, hidden_states, to_cpp(_weights.out_norm_w), _meta.epsilon); + + tensor_t last_hidden = final_normed->slice(0, ntoken-1, ntoken); + last_hidden = last_hidden->view({1, _meta.hs}); + + tensor_t logits = Tensor::create({1, _meta.voc}, _meta.dtype, _device_type, _device_id); + ops::linear(logits, last_hidden, to_cpp(_weights.out_embed), nullptr); + + tensor_t max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, _device_type, _device_id); + tensor_t max_val = Tensor::create({1}, _meta.dtype, _device_type, _device_id); + + ops::argmax(max_idx, max_val, logits); + + int64_t next_token_id; + if (_device_type == LLAISYS_DEVICE_CPU) { + next_token_id = *reinterpret_cast(max_idx->data()); + } else { + next_token_id = 0; + } + + _cur_pos += ntoken; + + return next_token_id; +} + +} // namespace llaisys diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp new file mode 100644 index 00000000..5d55f5e7 --- /dev/null +++ b/src/models/qwen2/qwen2.hpp @@ -0,0 +1,61 @@ +#pragma once +#include "llaisys/models/qwen2.h" +#include "../../tensor/tensor.hpp" +#include +#include + +namespace llaisys { + +class Qwen2Model { +public: + Qwen2Model(const LlaisysQwen2Meta& meta, llaisysDeviceType_t device, int device_id); + ~Qwen2Model(); + + LlaisysQwen2Weights* getWeights() { return &_weights; } + + int64_t infer(const int64_t* token_ids, size_t ntoken); + +private: + LlaisysQwen2Meta _meta; + LlaisysQwen2Weights _weights; + + llaisysDeviceType_t _device_type; + int _device_id; + + struct KVCache { + tensor_t k; + tensor_t v; + }; + std::vector _kv_caches; + + size_t _cur_pos = 0; + + tensor_t _hidden_states; + tensor_t _residual; + tensor_t _ln_out; + tensor_t _attn_out; + tensor_t _mlp_out; + tensor_t _logits; + + tensor_t _tokens_tensor; + tensor_t _pos_ids; + + void init_buffers(); + + std::vector _layers_attn_norm_w; + std::vector _layers_attn_q_w; + std::vector _layers_attn_q_b; + std::vector _layers_attn_k_w; + std::vector _layers_attn_k_b; + std::vector _layers_attn_v_w; + std::vector _layers_attn_v_b; + std::vector _layers_attn_o_w; + std::vector _layers_mlp_norm_w; + std::vector _layers_mlp_gate_w; + std::vector _layers_mlp_up_w; + std::vector _layers_mlp_down_w; + + void allocate_layers_weights(); +}; + +} // namespace llaisys diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..2caae179 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,45 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void argmax_(size_t *max_idx, T *max_val, const T* vals, size_t numel){ + size_t max_index = 0; + float max_value = llaisys::utils::cast(vals[0]); + + for(size_t i = 1; i < numel; ++i){ + if constexpr (std::is_same_v || std::is_same_v){ + max_value = llaisys::utils::cast(max_value); + float current_value = llaisys::utils::cast(vals[i]); + if(current_value > max_value){ + max_value = current_value; + max_index = i; + } + } else { + if(vals[i] > max_value){ + max_value = vals[i]; + max_index = i; + } + } + } + + *max_idx = max_index; + *max_val = llaisys::utils::cast(max_value); +} + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel){ + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_BF16: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_F16: + return argmax_(reinterpret_cast(max_idx), reinterpret_cast(max_val), reinterpret_cast(vals), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..5f58c207 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { + void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..7019e27e 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,32 @@ #include "op.hpp" + +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype()); + ASSERT(vals->isContiguous(), "Argmax: vals tensor must be contiguous."); + + if(vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..f97594b4 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,33 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void embedding_(T *out, const int64_t *index, const T* weight, size_t numel, size_t embedding_dim){ + for(size_t i = 0; i < numel; ++i){ + // if constexpr (std::is_same_v || std::is_same_v){ + T* out_row_dst = out + i * embedding_dim; + const T* weight_row_src = weight + index[i] * embedding_dim; + std::memcpy(out_row_dst, weight_row_src, embedding_dim * sizeof(T)); + // } else { + + // } + } +} + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t numel, size_t embedding_dim){ + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), reinterpret_cast(index), reinterpret_cast(weight), numel, embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..7d4f0d82 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { + void embedding(std::byte *out, const std::byte *index, const std::byte *wight, llaisysDataType_t dtype, size_t numel, size_t embedding_dim); +} diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..066f4203 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + CHECK_SAME_DTYPE(out->dtype(), weight->dtype()); + ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index tensor must be int64."); + ASSERT(index->isContiguous(), "Embedding: index tensor must be contiguous."); + size_t embedding_dim = weight->shape().back(); + ASSERT(out->shape().size() == 2 && out->shape()[1] == embedding_dim, + "Embedding: output tensor shape is invalid."); + ASSERT(index->shape().size() == 1 && index->shape()[0] == out->shape()[0], + "Embedding: index tensor shape is invalid."); + if(out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), embedding_dim); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..ba5b067f --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,48 @@ +#include "linear_cpu.hpp" + +#include "../../../utils.hpp" + +template +void linear_(T *out, const T *in, const T *weight, const T *bias, std::vector shapes){ + size_t dimi = shapes[0]; + size_t dimk = shapes[1]; + size_t dimj = shapes[2]; + for(size_t i = 0; i < dimi; ++i){ + for(size_t j = 0; j < dimj; ++j){ + if constexpr (std::is_same_v || std::is_same_v){ + float sum = 0.0f; + for(size_t k = 0; k < dimk; ++k){ + sum = sum + llaisys::utils::cast(in[i * dimk + k]) * llaisys::utils::cast(weight[j * dimk + k]); + } + if(bias != nullptr){ + sum = sum + llaisys::utils::cast(bias[j]); + } + out[i * dimj + j] = llaisys::utils::cast(sum); + } else { + T sum = 0.0f; + for(size_t k = 0; k < dimk; ++k){ + sum = sum + in[i * dimk + k] * weight[j * dimk + k]; + } + if(bias != nullptr){ + sum = sum + bias[j]; + } + out[i * dimj + j] = sum; + } + } + } +} + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, std::vector shapes){ + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), bias ? reinterpret_cast(bias) : nullptr, shapes); + case LLAISYS_DTYPE_BF16: + return linear_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), bias ? reinterpret_cast(bias) : nullptr, shapes); + case LLAISYS_DTYPE_F16: + return linear_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), bias ? reinterpret_cast(bias) : nullptr, shapes); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..a9a98ed0 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::cpu { + void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t dtype, std::vector shapes); +} diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f865..d8edbaef 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,38 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/linear_cpu.hpp" + namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + ASSERT(in->shape().size() == 2, "Linear: input tensor must be 2-D."); + ASSERT(weight->shape().size() == 2, "Linear: weight tensor must be 2-D."); + ASSERT(out->shape().size() == 2, "Linear: output tensor must be 2-D."); + size_t dimi = in->shape()[0]; + size_t dimk = in->shape()[1]; + size_t dimj = weight->shape()[0]; + ASSERT(weight->shape()[1] == dimk, "Linear: weight tensor shape is invalid."); + ASSERT(out->shape()[0] == dimi && out->shape()[1] == dimj, "Linear: output tensor shape is invalid."); + if(bias != nullptr){ + ASSERT(bias->shape().size() == 1 && bias->shape()[0] == dimj, "Linear: bias tensor shape is invalid."); + } + + if(out->deviceType() == LLAISYS_DEVICE_CPU) { + return llaisys::ops::cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), {dimi, dimk, dimj}); + } + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, out->dtype(), {dimi, dimk, dimj}); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp new file mode 100644 index 00000000..c92a1504 --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp @@ -0,0 +1,65 @@ +#include "rearrange_cpu.hpp" +#include "../../../utils.hpp" +#include + +namespace { + +template +void rearrange_(T *out_base, const T *in_base, const std::vector &shape, const std::vector &stride_in, const std::vector &stride_out,size_t dim, size_t offset_in, size_t offset_out) { + + size_t len = shape[dim]; + size_t s_in = stride_in[dim]; + size_t s_out = stride_out[dim]; + + if (dim == shape.size() - 1) { + if (s_in == 1 && s_out == 1) { + std::memcpy(out_base + offset_out, in_base + offset_in, len * sizeof(T)); + } else { + for (size_t i = 0; i < len; ++i) { + out_base[offset_out + i * s_out] = in_base[offset_in + i * s_in]; + } + } + } else { + for (size_t i = 0; i < len; ++i) { + rearrange_(out_base, in_base, shape, stride_in, stride_out, dim + 1, offset_in + i * s_in, offset_out + i * s_out); + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void rearrange(std::byte *out, const std::byte *in, llaisysDataType_t dtype, const std::vector &shape, const std::vector &stride_in, const std::vector &stride_out) { + + if (shape.empty()) { + size_t size = 0; + switch (dtype) { + case LLAISYS_DTYPE_F32: size = 4; break; + case LLAISYS_DTYPE_BF16: size = 2; break; + case LLAISYS_DTYPE_F16: size = 2; break; + case LLAISYS_DTYPE_I64: size = 8; break; + default: EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } + std::memcpy(out, in, size); + return; + } + + switch (dtype) { + case LLAISYS_DTYPE_F32: + rearrange_(reinterpret_cast(out), reinterpret_cast(in), shape, stride_in, stride_out, 0, 0, 0); + break; + case LLAISYS_DTYPE_BF16: + rearrange_(reinterpret_cast(out), reinterpret_cast(in), shape, stride_in, stride_out, 0, 0, 0); + break; + case LLAISYS_DTYPE_F16: + rearrange_(reinterpret_cast(out), reinterpret_cast(in), shape, stride_in, stride_out, 0, 0, 0); + break; + case LLAISYS_DTYPE_I64: + rearrange_(reinterpret_cast(out), reinterpret_cast(in), shape, stride_in, stride_out, 0, 0, 0); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp new file mode 100644 index 00000000..f15927a4 --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include +#include + +namespace llaisys::ops::cpu { + void rearrange(std::byte *out, const std::byte *in, llaisysDataType_t dtype, const std::vector &shape, const std::vector &stride_in, const std::vector &stride_out); +} \ No newline at end of file diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae5..3a467f01 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -1,7 +1,31 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rearrange_cpu.hpp" namespace llaisys::ops { void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + ASSERT(out->shape() == in->shape(), "Rearrange: input and output tensors must have the same shape."); + ASSERT(out->dtype() == in->dtype(), "Rearrange: input and output tensors must have the same dtype."); + + std::vector stride_in(in->strides().begin(), in->strides().end()); + std::vector stride_out(out->strides().begin(), out->strides().end()); + + if(out->deviceType() == LLAISYS_DEVICE_CPU) { + return llaisys::ops::cpu::rearrange(out->data(), in->data(), out->dtype(), out->shape(), stride_in, stride_out); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rearrange(out->data(), in->data(), out->dtype(), out->shape(), stride_in, stride_out); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 00000000..ec0f785b --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,41 @@ +#include "rms_norm_cpu.hpp" + +#include "../../../utils.hpp" + +#include +template +void rms_norm_(T *out, const T *in, const T *weight, const float eps, std::vector shapes){ + size_t dimi = shapes[0]; + size_t dimj = shapes[1]; + for(size_t i = 0; i < dimi; ++i){ + float sum_sq = 0.0f; + for(size_t j = 0; j < dimj; ++j){ + float val = llaisys::utils::cast(in[i * dimj + j]); + sum_sq += val * val; + } + + float rms = std::sqrt(sum_sq / dimj + eps); + float inv_rms = 1.0f / rms; + for(size_t j = 0; j < dimj; ++j){ + float val = llaisys::utils::cast(in[i * dimj + j]); + float w = llaisys::utils::cast(weight[j]); + float res = val * inv_rms * w; + out[i * dimj + j] = llaisys::utils::cast(res); + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, llaisysDataType_t type, std::vector shapes){ + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), eps, shapes); + case LLAISYS_DTYPE_BF16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), eps, shapes); + case LLAISYS_DTYPE_F16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), reinterpret_cast(weight), eps, shapes); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 00000000..61afb8d7 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::cpu { + void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, const float eps, llaisysDataType_t dtype, std::vector shapes); +} \ No newline at end of file diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9..c7a51310 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rms_norm_cpu.hpp" namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + ASSERT(in->shape().size() == 2, "RMSNorm: input tensor must be 2-D."); + ASSERT(weight->shape().size() == 1, "RMSNorm: weight tensor must be 1-D."); + ASSERT(out->shape().size() == 2, "RMSNorm: output tensor must be 2-D."); + size_t dimi = in->shape()[0]; + size_t dimj = in->shape()[1]; + + ASSERT(weight->shape()[0] == dimj, "RMSNorm: weight tensor shape is invalid."); + ASSERT(out->shape()[0] == dimi && out->shape()[1] == dimj, "RMSNorm: output tensor shape is invalid."); + + if(out->deviceType() == LLAISYS_DEVICE_CPU) { + return llaisys::ops::cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), {dimi, dimj}); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), {dimi, dimj}); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 00000000..8232e5e7 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,73 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include +template +void rope_(T *out, const T *in, const int64_t *pos_ids, float theta, size_t seq_len, size_t num_heads, size_t head_dim) { + + size_t half_dim = head_dim / 2; + + std::vector denoms(half_dim); + for (size_t j = 0; j < half_dim; ++j) { + // 此处用 double 以防止 pow 计算时出现精度问题 + double exponent = (2.0 * static_cast(j)) / static_cast(head_dim); + double denom_d = std::pow(static_cast(theta), exponent); + denoms[j] = static_cast(denom_d); + } + + for (size_t s = 0; s < seq_len; ++s) { + int64_t pos = pos_ids[s]; + float pos_f = static_cast(pos); + + for (size_t h = 0; h < num_heads; ++h) { + size_t offset = s * (num_heads * head_dim) + h * head_dim; + + const T* src_vec = in + offset; + T* dst_vec = out + offset; + + for (size_t j = 0; j < half_dim; ++j) { + float angle = pos_f / denoms[j]; + + float cos_val = std::cos(angle); + float sin_val = std::sin(angle); + + float a = llaisys::utils::cast(src_vec[j]); + float b = llaisys::utils::cast(src_vec[j + half_dim]); + + float a_out = a * cos_val - b * sin_val; + float b_out = b * cos_val + a * sin_val; + + dst_vec[j] = llaisys::utils::cast(a_out); + dst_vec[j + half_dim] = llaisys::utils::cast(b_out); + } + } + } +} + +namespace llaisys::ops::cpu { + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, float theta, llaisysDataType_t dtype, const std::vector &shape) { + + size_t seq_len = shape[0]; + size_t num_heads = shape[1]; + size_t head_dim = shape[2]; + + const int64_t* pos_ptr = reinterpret_cast(pos_ids); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + rope_(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + case LLAISYS_DTYPE_BF16: + rope_(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + case LLAISYS_DTYPE_F16: + rope_(reinterpret_cast(out), reinterpret_cast(in), pos_ptr, theta, seq_len, num_heads, head_dim); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..8a0dae47 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "llaisys.h" +#include +#include +#include + +namespace llaisys::ops::cpu { + void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, float theta, llaisysDataType_t dtype, const std::vector &shape); +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..97e5f5e8 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rope_cpu.hpp" namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, pos_ids); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64."); + + ASSERT(in->shape().size() == 3, "RoPE: input tensor must be 3-D [seqlen, nhead, head_dim]."); + ASSERT(out->shape().size() == 3, "RoPE: output tensor must be 3-D."); + ASSERT(pos_ids->shape().size() == 1, "RoPE: pos_ids tensor must be 1-D [seqlen]."); + + size_t seq_len = in->shape()[0]; + size_t head_dim = in->shape()[2]; + + ASSERT(pos_ids->shape()[0] == seq_len, "RoPE: pos_ids length mismatch with input seqlen."); + ASSERT(out->shape() == in->shape(), "RoPE: output shape mismatch with input."); + ASSERT(head_dim % 2 == 0, "RoPE: head_dim must be even."); + + ASSERT(in->isContiguous() && out->isContiguous() && pos_ids->isContiguous(), "RoPE: inputs must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), in->shape()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), in->shape()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..822d5ed7 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,116 @@ +#include "self_attention_cpu.hpp" +#include "../../../utils.hpp" +#include +#include +#include +#include + +namespace { + +template +void self_attention_kernel(T *attn_val, const T *q, const T *k, const T *v, float scale, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + + size_t sq = q_shape[0]; + size_t nh = q_shape[1]; + size_t d = q_shape[2]; + + size_t sk = k_shape[0]; + size_t nh_kv = k_shape[1]; + + size_t dv = v_shape[2]; + + size_t n_rep = nh / nh_kv; + + for (size_t i = 0; i < sq; ++i) { + + size_t q_abs_pos = sk - sq + i; + + for (size_t h = 0; h < nh; ++h) { + size_t h_kv = h / n_rep; + + const T* q_vec = q + (i * nh * d) + (h * d); + + std::vector scores(sk); + float max_score = -std::numeric_limits::infinity(); + + for (size_t j = 0; j < sk; ++j) { + if (j > q_abs_pos) { + scores[j] = -std::numeric_limits::infinity(); + continue; + } + const T* k_vec = k + (j * nh_kv * d) + (h_kv * d); + float dot = 0.0f; + for (size_t l = 0; l < d; ++l) { + float val_q = llaisys::utils::cast(q_vec[l]); + float val_k = llaisys::utils::cast(k_vec[l]); + dot += val_q * val_k; + } + + float score = dot * scale; + scores[j] = score; + if (score > max_score) { + max_score = score; + } + } + + float sum_exp = 0.0f; + for (size_t j = 0; j < sk; ++j) { + if (scores[j] == -std::numeric_limits::infinity()) { + scores[j] = 0.0f; + } else { + float exp_val = std::exp(scores[j] - max_score); + scores[j] = exp_val; + sum_exp += exp_val; + } + } + + float inv_sum = 1.0f / (sum_exp + 1e-10f); + + std::vector out_accum(dv, 0.0f); + + for (size_t j = 0; j < sk; ++j) { + float weight = scores[j] * inv_sum; + + if (weight < 1e-10f) continue; + const T* v_vec = v + (j * nh_kv * dv) + (h_kv * dv); + + for (size_t l = 0; l < dv; ++l) { + float val_v = llaisys::utils::cast(v_vec[l]); + out_accum[l] += weight * val_v; + } + } + + T* out_ptr = attn_val + (i * nh * dv) + (h * dv); + for (size_t l = 0; l < dv; ++l) { + out_ptr[l] = llaisys::utils::cast(out_accum[l]); + } + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, float scale, llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + self_attention_kernel(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + break; + case LLAISYS_DTYPE_BF16: + self_attention_kernel(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + break; + case LLAISYS_DTYPE_F16: + self_attention_kernel(reinterpret_cast(attn_val), reinterpret_cast(q), reinterpret_cast(k), reinterpret_cast(v), scale, q_shape, k_shape, v_shape); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..e73a7df3 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" +#include +#include + +namespace llaisys::ops::cpu { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, float scale, llaisysDataType_t dtype, + const std::vector &q_shape, + const std::vector &k_shape, + const std::vector &v_shape); + +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..7ec2831b 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,49 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/self_attention_cpu.hpp" namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype()); + CHECK_SAME_DTYPE(q->dtype(), k->dtype()); + CHECK_SAME_DTYPE(k->dtype(), v->dtype()); + + ASSERT(q->shape().size() == 3, "SelfAttention: q must be 3-D [seqlen, nhead, d]."); + ASSERT(k->shape().size() == 3, "SelfAttention: k must be 3-D [total_len, nkvhead, d]."); + ASSERT(v->shape().size() == 3, "SelfAttention: v must be 3-D [total_len, nkvhead, dv]."); + ASSERT(attn_val->shape().size() == 3, "SelfAttention: attn_val must be 3-D [seqlen, nhead, dv]."); + + size_t nh = q->shape()[1]; + size_t nh_kv = k->shape()[1]; + size_t d = q->shape()[2]; + + // GQA Check + ASSERT(nh % nh_kv == 0, "SelfAttention: nhead must be divisible by nkvhead (GQA constraint)."); + ASSERT(k->shape()[2] == d, "SelfAttention: Q and K head_dim mismatch."); + ASSERT(attn_val->shape()[0] == q->shape()[0], "SelfAttention: Output seqlen mismatch."); + ASSERT(attn_val->shape()[1] == nh, "SelfAttention: Output nhead mismatch."); + ASSERT(attn_val->shape()[2] == v->shape()[2], "SelfAttention: Output head_dim mismatch with V."); + + ASSERT(q->isContiguous() && k->isContiguous() && v->isContiguous() && attn_val->isContiguous(), + "SelfAttention: Inputs must be contiguous."); + + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), scale, attn_val->dtype(), q->shape(), k->shape(), v->shape()); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), scale, attn_val->dtype(), q->shape(), k->shape(), v->shape()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..df12c5fd --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,40 @@ + +#include "swiglu_cpu.hpp" +#include "../../../utils.hpp" +#include + +namespace { + +template +void swiglu_(T *out, const T *gate, const T *up, size_t numel) { + for (size_t i = 0; i < numel; ++i) { + float g_val = llaisys::utils::cast(gate[i]); + float u_val = llaisys::utils::cast(up[i]); + + float swish_g = g_val / (1.0f + std::exp(-g_val)); + float res = u_val * swish_g; + out[i] = llaisys::utils::cast(res); + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t dtype, size_t numel) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + case LLAISYS_DTYPE_BF16: + swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + case LLAISYS_DTYPE_F16: + swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..85d84006 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t dtype, size_t numel); +} diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..c62122ed 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/swiglu_cpu.hpp" namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype()); + CHECK_SAME_DTYPE(gate->dtype(), up->dtype()); + + ASSERT(gate->shape() == up->shape(), "SwiGLU: gate and up tensor shapes must match."); + ASSERT(out->shape() == gate->shape(), "SwiGLU: output tensor shape must match input."); + ASSERT(gate->isContiguous() && up->isContiguous() && out->isContiguous(), "SwiGLU: Inputs/Output tensors must be contiguous."); + + size_t numel = out->numel(); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..0c68dc98 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,27 +164,91 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t ndim_ = this->ndim(); + ptrdiff_t stride = 1; + const auto &shape = this->shape(); + const auto &strides = this->strides(); + for (size_t i = 1; i <= ndim_; i++) { + if(strides[ndim_ - i] != stride) return false; + stride *= shape[ndim_ - i]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim_ = order.size(); + if(ndim_ != this->ndim()){ + throw std::runtime_error("permute: order size does not match tensor ndim"); + } + std::vector seen(ndim_, false); + std::vector shape(ndim_); + std::vector strides(ndim_); + const auto &old_shape = this->shape(); + const auto &old_strides = this->strides(); + for(size_t i = 0; i < ndim_; ++i){ + size_t idx = order[i]; + if(idx < 0 || idx >= ndim_){ + throw std::runtime_error("permute: order index out of range"); + } + if(seen[idx]){ + throw std::runtime_error("permute: duplicate indices in order"); + } + seen[idx] = true; + shape[i] = old_shape[idx]; + strides[i] = old_strides[idx]; + } + TensorMeta meta = {this->dtype(), shape, strides}; + return std::shared_ptr(new Tensor(meta, _storage, _offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t numel = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); + if(numel != this->numel()){ + throw std::runtime_error("view: new shape size does not match tensor size"); + } + if (!this->isContiguous()) { + throw std::runtime_error("view: input tensor must be contiguous. call .contiguous() first."); + } + size_t ndim_ = shape.size(); + std::vector strides(ndim_); + size_t stride = 1; + for (size_t i = 1; i <= ndim_; i++) { + strides[ndim_ - i] = stride; + stride *= shape[ndim_ - i]; + } + TensorMeta meta = {this->dtype(), shape, strides}; + return std::shared_ptr(new Tensor(meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim_ = this->ndim(); + if(dim >= ndim_){ + throw std::runtime_error("slice: dim out of range"); + } + const auto& old_shape = this->shape(); + if(start >= end || end > old_shape[dim]){ + throw std::runtime_error("slice: invalid start or end"); + } + std::vector shape = old_shape;; + shape[dim] = end - start; + TensorMeta meta = {this->dtype(), shape, this->strides()}; + size_t offset = _offset + start * this->strides()[dim] * this->elementSize(); + return std::shared_ptr(new Tensor(meta, _storage, offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + size_t total_size = this->numel() * this->elementSize(); + void *dst_ = this->data(); + if (this->deviceType() == LLAISYS_DEVICE_NVIDIA) { + core::context().setDevice(this->deviceType(), this->deviceId()); + core::context().runtime().api()->memcpy_sync( + dst_, + src_, + total_size, + LLAISYS_MEMCPY_H2D); + } else { + std::memcpy(dst_, src_, total_size); + } } tensor_t Tensor::contiguous() const { diff --git a/test/ops/rope_debug.py b/test/ops/rope_debug.py new file mode 100644 index 00000000..54b5352c --- /dev/null +++ b/test/ops/rope_debug.py @@ -0,0 +1,113 @@ +import sys +import os +import torch +import numpy as np + +# Adjust path to find llaisys package +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../python")) +sys.path.insert(0, parent_dir) +import llaisys + +def torch_rope(y: torch.Tensor, x: torch.Tensor, pos_ids: torch.Tensor, theta: float): + seq_len, n_heads, head_dim = y.shape + x_a, x_b = x[..., : head_dim // 2], x[..., head_dim // 2 :] + positions = pos_ids.to(torch.float32).unsqueeze(1) + i = torch.arange(0, head_dim // 2, dtype=torch.float32, device=y.device) + freqs = positions / (theta ** (2 * i / head_dim)) + sin, cos = freqs.sin(), freqs.cos() + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + y[..., : head_dim // 2] = x_a * cos - x_b * sin + y[..., head_dim // 2 :] = x_b * cos + x_a * sin + +def debug_rope(): + # Configuration matching the failing case + shape = (512, 4, 4096) + start_pos = 512 + end_pos = 1024 + dtype = torch.float32 + theta = 10000.0 + + print(f"Debugging RoPE with shape={shape}, range=[{start_pos}, {end_pos}), dtype={dtype}") + + # 1. Setup Data + torch.manual_seed(42) + x = torch.randn(shape, dtype=dtype) + pos_ids = torch.arange(start_pos, end_pos, dtype=torch.int64) + y_torch = torch.zeros_like(x) + + # 2. Run PyTorch + torch_rope(y_torch, x, pos_ids, theta) + + # 3. Setup LLAISYS + # Helpers + device_enum = llaisys.DeviceType.CPU + dt_enum = llaisys.DataType.F32 + api = llaisys.RuntimeAPI(device_enum) + + # Create LLAISYS tensors + x_ll = llaisys.Tensor(shape, dtype=dt_enum, device=device_enum) + y_ll = llaisys.Tensor(shape, dtype=dt_enum, device=device_enum) + pos_ll = llaisys.Tensor((len(pos_ids),), dtype=llaisys.DataType.I64, device=device_enum) + + # Copy Input Data (x, pos_ids) + # Using HostToHost since we are on CPU + kind = llaisys.MemcpyKind.HostToHost + + api.memcpy_sync(x_ll.data_ptr(), x.data_ptr(), x.numel() * x.element_size(), kind) + api.memcpy_sync(pos_ll.data_ptr(), pos_ids.data_ptr(), pos_ids.numel() * pos_ids.element_size(), kind) + + # Run Op + llaisys.Ops.rope(y_ll, x_ll, pos_ll, theta) + + # Copy Output Data back + y_llaisys = torch.zeros_like(x) + api.memcpy_sync(y_llaisys.data_ptr(), y_ll.data_ptr(), y_ll.numel() * y_ll.element_size(), kind) + + # 4. Analyze Error + diff = (y_torch - y_llaisys).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + print(f"Max Diff: {max_diff:.2e}") + print(f"Mean Diff: {mean_diff:.2e}") + + # 5. Detailed Breakdown + if max_diff > 1e-5: # Only show details if significant error + max_indices = torch.nonzero(diff == max_diff) + if len(max_indices) > 0: + idx = max_indices[0] + seq_idx, head_idx, dim_idx = idx.tolist() + print(f"Max error at index: seq={seq_idx}, head={head_idx}, dim={dim_idx}") + curr_pos = pos_ids[seq_idx].item() + print(f"Pos ID at failure: {curr_pos}") + + # Theoretical calc + head_dim = shape[2] + freq_idx = dim_idx if dim_idx < head_dim // 2 else dim_idx - head_dim // 2 + + freq_exponent_f = (2.0 * freq_idx) / head_dim + denom_f = theta ** freq_exponent_f + angle_f = curr_pos / denom_f + + # Double precision check + freq_exponent_d = (2.0 * freq_idx) / float(head_dim) + denom_d = theta ** freq_exponent_d + angle_d = curr_pos / denom_d + + print(f"Angle(float) approx: {angle_f}") + print(f"Angle(double) approx: {angle_d}") + + val_t = y_torch[seq_idx, head_idx, dim_idx].item() + val_l = y_llaisys[seq_idx, head_idx, dim_idx].item() + print(f"Values: Torch={val_t:.8f}, LLAISYS={val_l:.8f}") + print(f"Diff: {abs(val_t - val_l):.8f}") + + if max_diff > 5e-4: + print("\n\033[91mFAILED: Error exceeds 5e-4\033[0m") + sys.exit(1) + else: + print("\n\033[92mPASSED\033[0m") + +if __name__ == "__main__": + debug_rope() diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..47ee0fde 100644 --- a/xmake.lua +++ b/xmake.lua @@ -95,6 +95,23 @@ target("llaisys-ops") on_install(function (target) end) target_end() +target("llaisys-models") + set_kind("static") + add_deps("llaisys-ops") + add_deps("llaisys-tensor") + add_deps("llaisys-core") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/models/*/*.cpp") + + on_install(function (target) end) +target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,6 +119,7 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") + add_deps("llaisys-models") set_languages("cxx17") set_warnings("all", "error")