diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d..a3e282c8 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -2,6 +2,7 @@ #define LLAISYS_MODELS_QWEN2_H #include "../tensor.h" +#include __C { struct LlaisysQwen2Meta { @@ -14,19 +15,20 @@ __C { struct LlaisysQwen2Weights { llaisysTensor_t in_embed; llaisysTensor_t out_embed; - llaisysTensor_t out_norm_w; // a.k.a. model.norm.weight - llaisysTensor_t *attn_norm_w; // a.k.a. input_layernorm.weight - llaisysTensor_t *attn_q_w; - llaisysTensor_t *attn_q_b; - llaisysTensor_t *attn_k_w; - llaisysTensor_t *attn_k_b; - llaisysTensor_t *attn_v_w; - llaisysTensor_t *attn_v_b; - llaisysTensor_t *attn_o_w; - llaisysTensor_t *mlp_norm_w; // a.k.a. post_attention_layernorm.weight - llaisysTensor_t *mlp_gate_w; - llaisysTensor_t *mlp_up_w; - llaisysTensor_t *mlp_down_w; + llaisysTensor_t out_norm_w; + // 改为 vector + std::vector attn_norm_w; + std::vector attn_q_w; + std::vector attn_q_b; + std::vector attn_k_w; + std::vector attn_k_b; + std::vector attn_v_w; + std::vector attn_v_b; + std::vector attn_o_w; + std::vector mlp_norm_w; + std::vector mlp_gate_w; + std::vector mlp_up_w; + std::vector mlp_down_w; }; struct LlaisysQwen2Model; @@ -37,6 +39,18 @@ __C { __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); - __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + __export void llaisysQwen2LoadWeight( + struct LlaisysQwen2Model * model, + const char * name, + const void * data, + size_t * shape, + size_t ndim, + llaisysDataType_t dtype); + + __export int64_t llaisysQwen2ModelInfer( + struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + size_t start_pos); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..d8877aae 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,23 +1,117 @@ from typing import Sequence from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType +from ..libllaisys import DeviceType, DataType + +from .qwen2_binding import register_qwen2_lib, Qwen2MetaCStruct from pathlib import Path -import safetensors +import safetensors.torch +import torch +import os +import json +import ctypes class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + self.lib = LIB_LLAISYS + register_qwen2_lib(self.lib) # Register C functions model_path = Path(model_path) + config_path = model_path / "config.json" + + if not config_path.exists(): + raise FileNotFoundError(f"Config not found at {config_path}") + + with open(config_path, "r") as f: + config = json.load(f) + + self.meta = Qwen2MetaCStruct() + + # Populate Meta (Default fallback values based on typical Qwen2 config) + self.meta.hs = config.get("hidden_size", 1536) + self.meta.nlayer = config.get("num_hidden_layers", 28) + self.meta.nh = config.get("num_attention_heads", 12) + self.meta.nkvh = config.get("num_key_value_heads", 2) + self.meta.voc = config.get("vocab_size", 151936) + self.meta.maxseq = config.get("max_position_embeddings", 32768) + self.meta.di = config.get("intermediate_size", 8960) + self.meta.epsilon = config.get("rms_norm_eps", 1e-6) + self.meta.theta = config.get("rope_theta", 10000.0) + self.meta.dh = self.meta.hs // self.meta.nh + + # Determine EOS token + eos_id = config.get("eos_token_id", 151643) # ID of <|endoftext|> + if isinstance(eos_id, list): + self.meta.end_token = eos_id[0] + else: + self.meta.end_token = eos_id + + # Set dtype for the model struct (match weight dtype when possible) + torch_dtype = str(config.get("torch_dtype", "float32")).lower() + if "bfloat16" in torch_dtype or "bf16" in torch_dtype: + self.meta.dtype = DataType.BF16 + elif "float16" in torch_dtype or "fp16" in torch_dtype: + self.meta.dtype = DataType.F16 + else: + self.meta.dtype = DataType.F32 + + # Create C Model + device_ids = (ctypes.c_int * 1)(0) + # Use F32 for KV cache for stability on CPU + self.model = self.lib.llaisysQwen2ModelCreate( + ctypes.byref(self.meta), + device, + device_ids, + 1 + ) + + if not self.model: + raise RuntimeError("Failed to create native Qwen2 model instance.") + # Load Weights for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + print(f"Loading weights from {file}...") + # Use safe_open from safetensors (torch backend to support BF16) + with safetensors.torch.safe_open(file, framework="pt", device="cpu") as f: + for name in f.keys(): + tensor = f.get_tensor(name) + + # Map torch dtype to Llaisys DataType + dt = DataType.F32 + if tensor.dtype == torch.float16: + dt = DataType.F16 + elif tensor.dtype == torch.float32: + dt = DataType.F32 + elif tensor.dtype == torch.bfloat16: + dt = DataType.BF16 + elif tensor.dtype == torch.int64: + dt = DataType.I64 + + # Ensure contiguous + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + shape = tensor.shape + c_shape = (ctypes.c_size_t * len(shape))(*shape) + + # Keep a reference to data pointer valid during the C call + data_ptr = ctypes.c_void_p(tensor.data_ptr()) + + self.lib.llaisysQwen2LoadWeight( + self.model, + name.encode('utf-8'), + data_ptr, + c_shape, + len(shape), + dt + ) + + def __del__(self): + if hasattr(self, 'model') and self.model: + self.lib.llaisysQwen2ModelDestroy(self.model) + self.model = None def generate( self, @@ -27,7 +121,33 @@ def generate( top_p: float = 0.8, temperature: float = 0.8, ): - - # TODO: Implement generate function - - return [] + if max_new_tokens is None: + max_new_tokens = 20 + + tokens = list(inputs) + start_pos = 0 + + for _ in range(max_new_tokens): + if start_pos == 0: + current_input = tokens + else: + current_input = tokens[-1:] # Next token generation: use only last token + + n_tokens = len(current_input) + c_inputs = (ctypes.c_int64 * n_tokens)(*current_input) + + # Infer (argmax inside backend) + next_token_id = self.lib.llaisysQwen2ModelInfer( + self.model, + c_inputs, + n_tokens, + start_pos + ) + + tokens.append(next_token_id) + start_pos += n_tokens + + if next_token_id == self.meta.end_token: + break + + return tokens diff --git a/python/llaisys/models/qwen2_binding.py b/python/llaisys/models/qwen2_binding.py new file mode 100644 index 00000000..009f8329 --- /dev/null +++ b/python/llaisys/models/qwen2_binding.py @@ -0,0 +1,57 @@ +import ctypes +from ctypes import c_size_t, c_int, c_float, c_void_p, c_int64, POINTER, Structure, c_char_p +from ..libllaisys.llaisys_types import DataType, llaisysDataType_t, llaisysDeviceType_t + +class Qwen2MetaCStruct(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + +# Opaque pointer handle +LlaisysQwen2ModelHandle = c_void_p + +def register_qwen2_lib(lib): + if hasattr(lib, "llaisysQwen2ModelCreate"): + # Create + lib.llaisysQwen2ModelCreate.restype = LlaisysQwen2ModelHandle + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(Qwen2MetaCStruct), + llaisysDeviceType_t, + POINTER(c_int), # device_ids + c_int # ndev + ] + + # Destroy + lib.llaisysQwen2ModelDestroy.restype = None + lib.llaisysQwen2ModelDestroy.argtypes = [LlaisysQwen2ModelHandle] + + # Load Weight + lib.llaisysQwen2LoadWeight.restype = None + lib.llaisysQwen2LoadWeight.argtypes = [ + LlaisysQwen2ModelHandle, + c_char_p, # name + c_void_p, # data + POINTER(c_size_t), # shape + c_size_t, # ndim + llaisysDataType_t # dtype + ] + + # Infer + lib.llaisysQwen2ModelInfer.restype = c_int64 + lib.llaisysQwen2ModelInfer.argtypes = [ + LlaisysQwen2ModelHandle, + POINTER(c_int64), # input_ids_ptr + c_size_t, # seq_len + c_size_t # start_pos + ] diff --git a/reports/original/hw.md b/reports/original/hw.md new file mode 100644 index 00000000..c86062b6 --- /dev/null +++ b/reports/original/hw.md @@ -0,0 +1,161 @@ +什么时候连续? + +1. shape 和 strides 什么形状? + // see line 21 + // strides' dim = shape' dim + // e.g. shape is of [3,2] + // strides should be [2,1] (when cont) + + // when a tensor is contiguous? + // 1. obv, there must be a stride = 1; + // 2. stides[i] = strides[i+1] * shapes[i+1]; + +2. 什么时候连续? + +连续是指内存的连续布局 + +举例而言 +内存布局 +位置: 0 1 2 3 4 5 +值: [0] [1] [2] [3] [4] [5] +逻辑视图: +[[0, 1, 2], // 行0 + [3, 4, 5]] // 行1 + +形状 (shape): (2, 3) +步长 (strides): (3, 1) + +数学化的表示 stride[i] = stride[i+1] * shape[i+1] 就连续了. +例如这里 3 = 3 * 1. + +现在, 我们对张量进行转置: + +[[0, 3], // 新行0(原来列0) + [1, 4], // 新行1(原来列1) + [2, 5]] // 新行2(原来列2) + +形状 (shape): (3, 2) – 原来2行3列,现在3行2列。 +步长 (strides): (1, 3) – 原来 (3, 1) 交换后变为 (1, 3),因为转置不改变内存布局,只改变访问方式。 +总元素数: 仍为6,内存布局不变。 + +!!!note "转置可以直接转置 shape 和 strides" + +!!!note "这里是连续张量吗?" + 理论上是, 是列优先的. + 但是例如 pytorch 的实现可能认为不是. + +现在, 我们不满足之前的规律了. 好在, 我们修正一下就好了 + +stride[i] = stride[i-1] * shape[i-1] +现在又变成 3 = 3 * 1 了. + +~~显然~~, 我们最内层的维度必须有stride = 1. (这不显然, 要想很久) 然后按照一定顺序, 满足 stride[第i内层] = stride[第i+1内层] * shape[第i+1内层], 就能判断是不是内存连续了. 为此, 我们可以定义一个 perm 数组, 用来记录维度的排序方式. + +!!!note 连续张量判定: + 1. stride[perm[-1]] == 1 + 2. stride[perm[i]] == stride[perm[i+1]] * shape[perm[i+1]] + +终于, 回到我们题目, 居然没有 perm? 大概就是用 pytorch 一样标准只考虑行布局? 对的. + + +permute, shape 直接 新 strides[i] = 旧 strides[order[i]], strides 呢? +也是! + +- [ ] 为啥呢? ... + + + + + + + + + +(.venv) root@node-212:~/other/llaisys# xmake --root +[ 39%]: cache compiling.release src/tensor/tensor.cpp +error: src/tensor/tensor.cpp:170:33: error: comparison of integer expressions of different signedness: 'const value_type' {aka 'const long int'} and 'size_t' {aka 'long unsigned int'} [-Werror=sign-compare] + 170 | if (strides()[ndim - i] != stride) +cc1plus: all warnings being treated as errors + > in src/tensor/tensor.cpp + + 为什么我这样不行 + bool Tensor::isContiguous() const { + size_t stride = 1; + size_t ndim = shape().size(); + for (size_t i = 1; i <= ndim; i++) { + if (strides()[ndim - i] != stride) + return false; + stride *= shape()[ndim - i]; + } + return true; + } + +他那样就行? +size_t ndim_ = shape.size(); + std::vector strides(ndim_); + size_t stride = 1; + for (size_t i = 1; i <= ndim_; i++) { + strides[ndim_ - i] = stride; + stride *= shape[ndim_ - i]; + } + TensorMeta meta{dtype, shape, strides}; +### 问题原因 +编译错误是因为在 `isContiguous()` 中,`strides()[ndim - i]` 返回 `ptrdiff_t`(有符号,`long int`),而 `stride` 是 `size_t`(无符号,`long unsigned int`)。比较有符号和无符号整数时,编译器发出警告(`-Werror=sign-compare` 将其视为错误)。 + +- `ptrdiff_t` 用于 strides,因为步长可能为负(但在连续张量中通常正)。 +- `size_t` 用于正整数如尺寸。 + +`create()` 中没有警告,因为赋值 `strides[ndim_ - i] = stride` 是隐式转换(`size_t` 到 `ptrdiff_t`),编译器允许,但比较时严格。 + +### 为什么 `create` “行”? +`create` 中是赋值,不是比较,所以没触发警告。但为了避免潜在问题,最好统一类型。 + +### 修复方法 +1. **强制转换比较**(推荐,保持 `stride` 为 `size_t`): + ```cpp + bool Tensor::isContiguous() const { + size_t stride = 1; + size_t ndim = shape().size(); + for (size_t i = 1; i <= ndim; i++) { + if (static_cast(strides()[ndim - i]) != stride) + return false; + stride *= shape()[ndim - i]; + } + return true; + } + ``` + +2. **将 `stride` 改为 `ptrdiff_t`**: + ```cpp + bool Tensor::isContiguous() const { + ptrdiff_t stride = 1; + size_t ndim = shape().size(); + for (size_t i = 1; i <= ndim; i++) { + if (strides()[ndim - i] != stride) + return false; + stride *= static_cast(shape()[ndim - i]); + } + return true; + } + ``` + +第一种更安全,因为 strides 在连续张量中是正数。用第一种修复后,代码应编译通过。 + + + + struct LlaisysRuntimeAPI { + get_device_count_api get_device_count; + set_device_api set_device; + device_synchronize_api device_synchronize; + create_stream_api create_stream; + destroy_stream_api destroy_stream; + stream_synchronize_api stream_synchronize; + malloc_device_api malloc_device; + free_device_api free_device; + malloc_host_api malloc_host; + free_host_api free_host; + memcpy_sync_api memcpy_sync; + memcpy_async_api memcpy_async; + }; + +谁知道呢, api 真的是 api. \ No newline at end of file diff --git a/reports/original/project-1.md b/reports/original/project-1.md new file mode 100644 index 00000000..4da724e2 --- /dev/null +++ b/reports/original/project-1.md @@ -0,0 +1,52 @@ +# Project 1: Linear OP Performance Optimization (CPU) + +## Benchmark +``` +(venv) scbz@dsw-607126-85f54bdf75-5lzlx:~/llaisys$ OPENBLAS_NUM_THREADS=32 OMP_NUM_THREADS=32 python test/ops/linear.py --profile +Testing Ops.linear on cpu + out (2, 3), x (2, 4), w (3, 4), bias True, dtype + Torch time: 0.00374 ms + LLAISYS time: 0.00158 ms + out (2, 3), x (2, 4), w (3, 4), bias True, dtype + Torch time: 0.01351 ms + LLAISYS time: 0.00374 ms + out (2, 3), x (2, 4), w (3, 4), bias True, dtype + Torch time: 0.01428 ms + LLAISYS time: 0.00366 ms + out (512, 4096), x (512, 4096), w (4096, 4096), bias True, dtype + Torch time: 11.43182 ms + LLAISYS time: 17.31473 ms + out (512, 4096), x (512, 4096), w (4096, 4096), bias True, dtype + Torch time: 72.20329 ms + LLAISYS time: 20.56061 ms + out (512, 4096), x (512, 4096), w (4096, 4096), bias True, dtype + Torch time: 38.64677 ms + LLAISYS time: 19.27763 ms +Test passed! +``` + +## Main changes (this branch) +- **Aligned allocation + memcpy reduction:** `src/device/cpu/cpu_runtime_api.cpp` now allocates aligned buffers and skips memcpy for `float`/`double` when inputs are already aligned. +- **Vectorized cast helpers:** `src/ops/utils.hpp` adds faster casting paths (SIMD-friendly) with alignment checks. +- **GEMM backend:** `linear_cpu.cpp` now prefers OpenBLAS for `fp32/fp16` when inputs are aligned, avoiding unnecessary copy stage. + +## Roadmap (next improvements) +- SIMD casting path for `bf16` / full `f16` pipeline. +- Reduce memory passes (aim for 1-2 passes instead of 3) by avoiding extra temporary buffers. +- Add runtime dispatch for AVX2/AVX-512 (and fallback for AMD) to maximize portability and performance. +- Standardize benchmarking on a dedicated machine (remove background work) to get stable numbers. + +## Known issues / caveats +1. Current implementation still uses multiple memory passes; performance isn't yet optimal. +2. `bf16` is still slower than `fp16`/`fp32` due to casting overhead. +3. Benchmark host is shared; results may vary with background load. +4. Recommended tuning: `OPENBLAS_NUM_THREADS=32 OMP_NUM_THREADS=32` (not fully validated). + +## CPU ISA (benchmark host) +``` +(venv) scbz@dsw-607126-85f54bdf75-5lzlx:~/llaisys$ cat /proc/cpuinfo | grep flags | head -1 +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd avx512vbmi umip pku avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear arch_capabilities +``` + +- Supports AVX2 + F16C. +- AVX-512 is available on this host but is avoided for broader AMD compatibility. diff --git a/src/device/cpu/cpu_runtime_api.cpp b/src/device/cpu/cpu_runtime_api.cpp index 8d57cc40..5f2aae04 100644 --- a/src/device/cpu/cpu_runtime_api.cpp +++ b/src/device/cpu/cpu_runtime_api.cpp @@ -30,7 +30,14 @@ void streamSynchronize(llaisysStream_t stream) { } void *mallocDevice(size_t size) { - return std::malloc(size); + if (size == 0) return nullptr; + + void *ptr = nullptr; + // posix_memalign requires alignment to be a power of two and multiple of sizeof(void*) + if (posix_memalign(&ptr, 32, size) != 0) { + return nullptr; + } + return ptr; } void freeDevice(void *ptr) { diff --git a/src/llaisys/models/qwen2.cpp b/src/llaisys/models/qwen2.cpp new file mode 100644 index 00000000..8c62617c --- /dev/null +++ b/src/llaisys/models/qwen2.cpp @@ -0,0 +1,331 @@ +#include "llaisys/models/qwen2.h" +#include "llaisys/tensor.h" +#include "llaisys.h" +#include "llaisys/runtime.h" +#include "llaisys/ops.h" + +#include +#include +#include + +struct LlaisysQwen2Model { +public: + LlaisysQwen2Meta meta; + LlaisysQwen2Weights weights; + std::vector k_cache; + std::vector v_cache; + const LlaisysRuntimeAPI *runtime_api; + llaisysDeviceType_t device; + int device_id; + + // 构造函数 + LlaisysQwen2Model(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int device_id) + : meta(*meta), device(device), device_id(device_id) { + this->runtime_api = llaisysGetRuntimeAPI(device); + llaisysSetContextRuntime(device, this->device_id); + createGlobalWeights(this->weights, meta, device, this->device_id); + createLayerWeights(this->weights, meta, device, this->device_id); + this->k_cache.resize(meta->nlayer); + this->v_cache.resize(meta->nlayer); + size_t shape_kv[3] = {meta->maxseq, meta->nkvh, meta->dh}; + for (size_t i = 0; i < meta->nlayer; ++i) { + this->k_cache[i] = tensorCreate(shape_kv, 3, meta->dtype, device, this->device_id); + this->v_cache[i] = tensorCreate(shape_kv, 3, meta->dtype, device, this->device_id); + } + } + + // 析构函数 + ~LlaisysQwen2Model() { + // 释放全局权重 + tensorDestroy(weights.in_embed); + tensorDestroy(weights.out_embed); + tensorDestroy(weights.out_norm_w); + + // 释放层权重 + for (auto &t : weights.attn_norm_w) tensorDestroy(t); + for (auto &t : weights.attn_q_w) tensorDestroy(t); + for (auto &t : weights.attn_q_b) tensorDestroy(t); + for (auto &t : weights.attn_k_w) tensorDestroy(t); + for (auto &t : weights.attn_k_b) tensorDestroy(t); + for (auto &t : weights.attn_v_w) tensorDestroy(t); + for (auto &t : weights.attn_v_b) tensorDestroy(t); + for (auto &t : weights.attn_o_w) tensorDestroy(t); + for (auto &t : weights.mlp_norm_w) tensorDestroy(t); + for (auto &t : weights.mlp_gate_w) tensorDestroy(t); + for (auto &t : weights.mlp_up_w) tensorDestroy(t); + for (auto &t : weights.mlp_down_w) tensorDestroy(t); + + // 释放 KV cache + for (auto &t : k_cache) tensorDestroy(t); + for (auto &t : v_cache) tensorDestroy(t); + } + + // 推理方法 + int64_t infer(int64_t *token_ids, size_t ntoken, size_t start_pos = 0) { + size_t hs = meta.hs; + size_t head_dim = hs / meta.nh; + size_t kv_dim = head_dim * meta.nkvh; + + // 创建输入张量 + size_t input_shape[1] = {ntoken}; + llaisysTensor_t input_tensor = tensorCreate(input_shape, 1, LLAISYS_DTYPE_I64, device, device_id); + + // 加载 token_ids + tensorLoad(input_tensor, token_ids); + + size_t hidden_shape[2] = {ntoken, hs}; + llaisysTensor_t hidden_states = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + + // Embedding + llaisysEmbedding(hidden_states, input_tensor, weights.in_embed); + + // 位置 IDs + size_t pos_shape[1] = {ntoken}; + llaisysTensor_t pos_ids = tensorCreate(pos_shape, 1, LLAISYS_DTYPE_I64, device, device_id); + std::vector pos_vec(ntoken); + for (size_t i = 0; i < ntoken; ++i) pos_vec[i] = start_pos + i; + tensorLoad(pos_ids, pos_vec.data()); + + for (size_t i = 0; i < meta.nlayer; ++i) { + // RMS Norm + llaisysTensor_t norm_out = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + llaisysRmsNorm(norm_out, hidden_states, weights.attn_norm_w[i], meta.epsilon); + + // Q, K, V + size_t q_shape_2d[2] = {ntoken, hs}; + size_t kv_shape_2d[2] = {ntoken, kv_dim}; + llaisysTensor_t q2d = tensorCreate(q_shape_2d, 2, meta.dtype, device, device_id); + llaisysTensor_t k2d = tensorCreate(kv_shape_2d, 2, meta.dtype, device, device_id); + llaisysTensor_t v2d = tensorCreate(kv_shape_2d, 2, meta.dtype, device, device_id); + + llaisysLinear(q2d, norm_out, weights.attn_q_w[i], weights.attn_q_b[i]); + llaisysLinear(k2d, norm_out, weights.attn_k_w[i], weights.attn_k_b[i]); + llaisysLinear(v2d, norm_out, weights.attn_v_w[i], weights.attn_v_b[i]); + + size_t q_shape[3] = {ntoken, meta.nh, head_dim}; + size_t kv_shape[3] = {ntoken, meta.nkvh, head_dim}; + llaisysTensor_t q = tensorView(q2d, q_shape, 3); + llaisysTensor_t k = tensorView(k2d, kv_shape, 3); + llaisysTensor_t v = tensorView(v2d, kv_shape, 3); + + // RoPE + llaisysROPE(q, q, pos_ids, meta.theta); + llaisysROPE(k, k, pos_ids, meta.theta); + + // KV Cache + llaisysTensor_t layer_k_cache = k_cache[i]; + llaisysTensor_t layer_v_cache = v_cache[i]; + + llaisysTensor_t k_slot = tensorSlice(layer_k_cache, 0, start_pos, start_pos + ntoken); + llaisysTensor_t v_slot = tensorSlice(layer_v_cache, 0, start_pos, start_pos + ntoken); + tensorLoad(k_slot, tensorGetData(k2d)); + tensorLoad(v_slot, tensorGetData(v2d)); + + llaisysTensor_t full_k = tensorSlice(layer_k_cache, 0, 0, start_pos + ntoken); + llaisysTensor_t full_v = tensorSlice(layer_v_cache, 0, 0, start_pos + ntoken); + + // Self Attention + size_t attn_shape[3] = {ntoken, meta.nh, head_dim}; + llaisysTensor_t attn_out = tensorCreate(attn_shape, 3, meta.dtype, device, device_id); + float scale = static_cast(1.0 / sqrt(head_dim)); + llaisysSelfAttention(attn_out, q, full_k, full_v, scale); + + // Proj + llaisysTensor_t attn_out_2d = tensorView(attn_out, hidden_shape, 2); + llaisysTensor_t proj_out = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + llaisysLinear(proj_out, attn_out_2d, weights.attn_o_w[i], nullptr); + + llaisysAdd(hidden_states, hidden_states, proj_out); + + // FFN + llaisysTensor_t ffn_norm_out = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + llaisysRmsNorm(ffn_norm_out, hidden_states, weights.mlp_norm_w[i], meta.epsilon); + + size_t inter_size = meta.di; + size_t inter_shape[2] = {ntoken, inter_size}; + llaisysTensor_t gate = tensorCreate(inter_shape, 2, meta.dtype, device, device_id); + llaisysTensor_t up = tensorCreate(inter_shape, 2, meta.dtype, device, device_id); + + llaisysLinear(gate, ffn_norm_out, weights.mlp_gate_w[i], nullptr); + llaisysLinear(up, ffn_norm_out, weights.mlp_up_w[i], nullptr); + + llaisysTensor_t act = tensorCreate(inter_shape, 2, meta.dtype, device, device_id); + llaisysSwiGLU(act, gate, up); + + llaisysTensor_t mlp_out = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + llaisysLinear(mlp_out, act, weights.mlp_down_w[i], nullptr); + + llaisysAdd(hidden_states, hidden_states, mlp_out); + + tensorDestroy(norm_out); + tensorDestroy(q2d); + tensorDestroy(k2d); + tensorDestroy(v2d); + tensorDestroy(q); + tensorDestroy(k); + tensorDestroy(v); + // layer_k_cache/layer_v_cache are owned by model + tensorDestroy(k_slot); + tensorDestroy(v_slot); + tensorDestroy(full_k); + tensorDestroy(full_v); + tensorDestroy(attn_out); + tensorDestroy(attn_out_2d); + tensorDestroy(proj_out); + tensorDestroy(ffn_norm_out); + tensorDestroy(gate); + tensorDestroy(up); + tensorDestroy(act); + tensorDestroy(mlp_out); + } + + // Final Norm + llaisysTensor_t final_norm = tensorCreate(hidden_shape, 2, meta.dtype, device, device_id); + llaisysRmsNorm(final_norm, hidden_states, weights.out_norm_w, meta.epsilon); + + // Logits + size_t logits_shape[2] = {ntoken, meta.voc}; + llaisysTensor_t logits = tensorCreate(logits_shape, 2, meta.dtype, device, device_id); + llaisysLinear(logits, final_norm, weights.out_embed, nullptr); + + // Argmax on last token + llaisysTensor_t last_token_logits = tensorSlice(logits, 0, ntoken - 1, ntoken); + llaisysTensor_t final_logits = last_token_logits; + size_t max_shape[1] = {1}; + llaisysTensor_t max_idx = tensorCreate(max_shape, 1, LLAISYS_DTYPE_I64, device, device_id); + llaisysTensor_t max_val = tensorCreate(max_shape, 1, meta.dtype, device, device_id); + llaisysArgmax(max_idx, max_val, final_logits); + + int64_t result = *reinterpret_cast(tensorGetData(max_idx)); + + // 释放临时张量 + tensorDestroy(input_tensor); + tensorDestroy(hidden_states); + tensorDestroy(pos_ids); + tensorDestroy(final_norm); + tensorDestroy(logits); + tensorDestroy(last_token_logits); + tensorDestroy(max_idx); + tensorDestroy(max_val); + + return result; + } + +private: + // 辅助函数:创建全局权重 + static void createGlobalWeights(LlaisysQwen2Weights &weights, const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int device_id) { + size_t shape_in_embed[2] = {meta->voc, meta->hs}; + weights.in_embed = tensorCreate(shape_in_embed, 2, meta->dtype, device, device_id); + + size_t shape_out_embed[2] = {meta->voc, meta->hs}; + weights.out_embed = tensorCreate(shape_out_embed, 2, meta->dtype, device, device_id); + + size_t shape_out_norm[1] = {meta->hs}; + weights.out_norm_w = tensorCreate(shape_out_norm, 1, meta->dtype, device, device_id); + } + + // 辅助函数:创建层权重 + static void createLayerWeights(LlaisysQwen2Weights &weights, const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int device_id) { + size_t n = meta->nlayer; + weights.attn_norm_w.resize(n); + weights.attn_q_w.resize(n); + weights.attn_q_b.resize(n); + weights.attn_k_w.resize(n); + weights.attn_k_b.resize(n); + weights.attn_v_w.resize(n); + weights.attn_v_b.resize(n); + weights.attn_o_w.resize(n); + weights.mlp_norm_w.resize(n); + weights.mlp_gate_w.resize(n); + weights.mlp_up_w.resize(n); + weights.mlp_down_w.resize(n); + + for (size_t i = 0; i < n; ++i) { + size_t shape_norm[1] = {meta->hs}; + weights.attn_norm_w[i] = tensorCreate(shape_norm, 1, meta->dtype, device, device_id); + size_t shape_q[2] = {meta->nh * meta->dh, meta->hs}; + weights.attn_q_w[i] = tensorCreate(shape_q, 2, meta->dtype, device, device_id); + size_t shape_qb[1] = {meta->nh * meta->dh}; + weights.attn_q_b[i] = tensorCreate(shape_qb, 1, meta->dtype, device, device_id); + size_t shape_k[2] = {meta->nkvh * meta->dh, meta->hs}; + weights.attn_k_w[i] = tensorCreate(shape_k, 2, meta->dtype, device, device_id); + size_t shape_kb[1] = {meta->nkvh * meta->dh}; + weights.attn_k_b[i] = tensorCreate(shape_kb, 1, meta->dtype, device, device_id); + weights.attn_v_w[i] = tensorCreate(shape_k, 2, meta->dtype, device, device_id); + weights.attn_v_b[i] = tensorCreate(shape_kb, 1, meta->dtype, device, device_id); + size_t shape_o[2] = {meta->hs, meta->hs}; + weights.attn_o_w[i] = tensorCreate(shape_o, 2, meta->dtype, device, device_id); + weights.mlp_norm_w[i] = tensorCreate(shape_norm, 1, meta->dtype, device, device_id); + size_t shape_gate[2] = {meta->di, meta->hs}; + weights.mlp_gate_w[i] = tensorCreate(shape_gate, 2, meta->dtype, device, device_id); + weights.mlp_up_w[i] = tensorCreate(shape_gate, 2, meta->dtype, device, device_id); + size_t shape_down[2] = {meta->hs, meta->di}; + weights.mlp_down_w[i] = tensorCreate(shape_down, 2, meta->dtype, device, device_id); + } + } +}; + +__export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + return new LlaisysQwen2Model(meta, device, device_ids[0]); +} + +__export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + delete model; +} + +__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) { + return &model->weights; +} + +static llaisysTensor_t resolve_weight_tensor(LlaisysQwen2Model *model, const std::string &name) { + if (name == "model.embed_tokens.weight") return model->weights.in_embed; + if (name == "lm_head.weight") return model->weights.out_embed; + if (name == "model.norm.weight") return model->weights.out_norm_w; + + const std::string prefix = "model.layers."; + if (name.rfind(prefix, 0) == 0) { + size_t idx_start = prefix.size(); + size_t idx_end = name.find('.', idx_start); + if (idx_end == std::string::npos) return nullptr; + size_t layer = static_cast(std::stoul(name.substr(idx_start, idx_end - idx_start))); + if (layer >= model->meta.nlayer) return nullptr; + std::string suffix = name.substr(idx_end + 1); + + if (suffix == "input_layernorm.weight") return model->weights.attn_norm_w[layer]; + if (suffix == "post_attention_layernorm.weight") return model->weights.mlp_norm_w[layer]; + + if (suffix == "self_attn.q_proj.weight") return model->weights.attn_q_w[layer]; + if (suffix == "self_attn.q_proj.bias") return model->weights.attn_q_b[layer]; + if (suffix == "self_attn.k_proj.weight") return model->weights.attn_k_w[layer]; + if (suffix == "self_attn.k_proj.bias") return model->weights.attn_k_b[layer]; + if (suffix == "self_attn.v_proj.weight") return model->weights.attn_v_w[layer]; + if (suffix == "self_attn.v_proj.bias") return model->weights.attn_v_b[layer]; + if (suffix == "self_attn.o_proj.weight") return model->weights.attn_o_w[layer]; + + if (suffix == "mlp.gate_proj.weight") return model->weights.mlp_gate_w[layer]; + if (suffix == "mlp.up_proj.weight") return model->weights.mlp_up_w[layer]; + if (suffix == "mlp.down_proj.weight") return model->weights.mlp_down_w[layer]; + } + + return nullptr; +} + +__export void llaisysQwen2LoadWeight( + struct LlaisysQwen2Model *model, + const char *name, + const void *data, + size_t *shape, + size_t ndim, + llaisysDataType_t dtype) { + if (!model || !name || !data) return; + llaisysTensor_t tensor = resolve_weight_tensor(model, std::string(name)); + if (!tensor) return; + if (tensorGetDataType(tensor) != dtype) { + return; + } + tensorLoad(tensor, data); +} + +__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken, size_t start_pos) { + return model->infer(token_ids, ntoken, start_pos); +} diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc index c99fbc32..3adfc510 100644 --- a/src/llaisys/ops.cc +++ b/src/llaisys/ops.cc @@ -23,6 +23,10 @@ __C { llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor); } void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) { + if (bias == nullptr) { + llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, nullptr); + return; + } llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor); } void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) { diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp index 47f6a3d4..a3cdc91e 100644 --- a/src/ops/add/cpu/add_cpu.cpp +++ b/src/ops/add/cpu/add_cpu.cpp @@ -4,28 +4,275 @@ #include +#ifdef __C + #define LLAISYS_EXTERN_C 1 + #pragma push_macro("__C") + #undef __C +#endif + +#include // AVX2 for vectorized conversions + +#ifdef LLAISYS_EXTERN_C + #undef LLAISYS_EXTERN_C + #pragma pop_macro("__C") +#endif + + template -void add_(T *c, const T *a, const T *b, size_t numel) { +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel); + +template<> +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) +{ + float* c_typed = reinterpret_cast(c); + const float* a_typed = reinterpret_cast(a); + const float* b_typed = reinterpret_cast(b); + +#ifdef __AVX2__ + + size_t last = numel - (numel % 8); + + for (size_t i = 0; i < last; i += 8) + { + __m256 va = _mm256_loadu_ps(a_typed + i); + __m256 vb = _mm256_loadu_ps(b_typed + i); + + __m256 vc = _mm256_add_ps(va, vb); + + _mm256_storeu_ps(c_typed + i, vc); + } + + for (size_t i = last; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#else + + for (size_t i = 0; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#endif +} + +template<> +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) +{ + uint16_t* c16 = reinterpret_cast(c); + const uint16_t* a16 = reinterpret_cast(a); + const uint16_t* b16 = reinterpret_cast(b); + +#if defined(__AVX2__) && defined(__F16C__) + + size_t last = numel - (numel % 16); + + for (size_t i = 0; i < last; i += 16) + { + __m256i va = _mm256_loadu_si256((__m256i*)(a16 + i)); + __m256i vb = _mm256_loadu_si256((__m256i*)(b16 + i)); + + __m128i alo = _mm256_extracti128_si256(va,0); + __m128i ahi = _mm256_extracti128_si256(va,1); + + __m128i blo = _mm256_extracti128_si256(vb,0); + __m128i bhi = _mm256_extracti128_si256(vb,1); + + __m256 fa0 = _mm256_cvtph_ps(alo); + __m256 fa1 = _mm256_cvtph_ps(ahi); + + __m256 fb0 = _mm256_cvtph_ps(blo); + __m256 fb1 = _mm256_cvtph_ps(bhi); + + __m256 fc0 = _mm256_add_ps(fa0,fb0); + __m256 fc1 = _mm256_add_ps(fa1,fb1); + + __m128i lo = _mm256_cvtps_ph(fc0,0); + __m128i hi = _mm256_cvtps_ph(fc1,0); + + __m256i packed = _mm256_set_m128i(hi,lo); + + _mm256_storeu_si256((__m256i*)(c16 + i), packed); + } + + for (size_t i = last; i < numel; i++) { + reinterpret_cast(c)[i] = llaisys::utils::cast( + llaisys::utils::cast(reinterpret_cast(a)[i]) + + llaisys::utils::cast(reinterpret_cast(b)[i])); + } + +#else + for (size_t i = 0; i < numel; i++) { - if constexpr (std::is_same_v || std::is_same_v) { - c[i] = llaisys::utils::cast(llaisys::utils::cast(a[i]) + llaisys::utils::cast(b[i])); - } else { - c[i] = a[i] + b[i]; - } + reinterpret_cast(c)[i] = llaisys::utils::cast( + llaisys::utils::cast(reinterpret_cast(a)[i]) + + llaisys::utils::cast(reinterpret_cast(b)[i])); } + +#endif } +template<> +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) +{ + const uint16_t* a16 = reinterpret_cast(a); + const uint16_t* b16 = reinterpret_cast(b); + uint16_t* c16 = reinterpret_cast(c); + +#ifdef __AVX2__ + + size_t last = numel - (numel % 16); + +#ifdef _OPENMP +#pragma omp parallel for if(numel > 16384) +#endif + for (size_t i = 0; i < last; i += 16) + { + __m256i va = _mm256_loadu_si256((const __m256i*)(a16 + i)); + __m256i vb = _mm256_loadu_si256((const __m256i*)(b16 + i)); + + __m128i alo = _mm256_extracti128_si256(va, 0); + __m128i ahi = _mm256_extracti128_si256(va, 1); + + __m128i blo = _mm256_extracti128_si256(vb, 0); + __m128i bhi = _mm256_extracti128_si256(vb, 1); + + __m256i alo32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(alo), 16); + __m256i ahi32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(ahi), 16); + + __m256i blo32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(blo), 16); + __m256i bhi32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(bhi), 16); + + __m256 fa_lo = _mm256_castsi256_ps(alo32); + __m256 fa_hi = _mm256_castsi256_ps(ahi32); + + __m256 fb_lo = _mm256_castsi256_ps(blo32); + __m256 fb_hi = _mm256_castsi256_ps(bhi32); + + __m256 fc_lo = _mm256_add_ps(fa_lo, fb_lo); + __m256 fc_hi = _mm256_add_ps(fa_hi, fb_hi); + + __m256i lo_i = _mm256_castps_si256(fc_lo); + __m256i hi_i = _mm256_castps_si256(fc_hi); + + __m256i bias = _mm256_set1_epi32(0x7FFF); + + __m256i lsb_lo = _mm256_and_si256(_mm256_srli_epi32(lo_i,16), _mm256_set1_epi32(1)); + __m256i lsb_hi = _mm256_and_si256(_mm256_srli_epi32(hi_i,16), _mm256_set1_epi32(1)); + + __m256i round_lo = _mm256_add_epi32(lo_i,_mm256_add_epi32(bias,lsb_lo)); + __m256i round_hi = _mm256_add_epi32(hi_i,_mm256_add_epi32(bias,lsb_hi)); + + __m256i shr_lo = _mm256_srli_epi32(round_lo,16); + __m256i shr_hi = _mm256_srli_epi32(round_hi,16); + + __m128i lo = _mm_packus_epi32(_mm256_extracti128_si256(shr_lo,0), + _mm256_extracti128_si256(shr_lo,1)); + + __m128i hi = _mm_packus_epi32(_mm256_extracti128_si256(shr_hi,0), + _mm256_extracti128_si256(shr_hi,1)); + + __m256i packed = _mm256_set_m128i(hi,lo); + + _mm256_storeu_si256((__m256i*)(c16 + i), packed); + } + + for (size_t i = last; i < numel; i++) { + reinterpret_cast(c)[i] = llaisys::utils::cast( + llaisys::utils::cast(reinterpret_cast(a)[i]) + + llaisys::utils::cast(reinterpret_cast(b)[i])); + } + +#else + for (size_t i = 0; i < numel; i++) { + reinterpret_cast(c)[i] = llaisys::utils::cast( + llaisys::utils::cast(reinterpret_cast(a)[i]) + + llaisys::utils::cast(reinterpret_cast(b)[i])); + } +#endif +} + +template<> +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) +{ + int32_t* c_typed = reinterpret_cast(c); + const int32_t* a_typed = reinterpret_cast(a); + const int32_t* b_typed = reinterpret_cast(b); + +#ifdef __AVX2__ + + size_t last = numel - (numel % 8); + + for (size_t i = 0; i < last; i += 8) + { + __m256i va = _mm256_loadu_si256((const __m256i*)(a_typed + i)); + __m256i vb = _mm256_loadu_si256((const __m256i*)(b_typed + i)); + + __m256i vc = _mm256_add_epi32(va, vb); + + _mm256_storeu_si256((__m256i*)(c_typed + i), vc); + } + + for (size_t i = last; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#else + + for (size_t i = 0; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#endif +} + +template<> +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) +{ + double* c_typed = reinterpret_cast(c); + const double* a_typed = reinterpret_cast(a); + const double* b_typed = reinterpret_cast(b); + +#ifdef __AVX2__ + + size_t last = numel - (numel % 4); + + for (size_t i = 0; i < last; i += 4) + { + __m256d va = _mm256_loadu_pd(a_typed + i); + __m256d vb = _mm256_loadu_pd(b_typed + i); + + __m256d vc = _mm256_add_pd(va, vb); + + _mm256_storeu_pd(c_typed + i, vc); + } + + for (size_t i = last; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#else + + for (size_t i = 0; i < numel; i++) + c_typed[i] = a_typed[i] + b_typed[i]; + +#endif +} + +template +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) { + T *c_typed = reinterpret_cast(c); + const T *a_typed = reinterpret_cast(a); + const T *b_typed = reinterpret_cast(b); + for (size_t i = 0; i < numel; i++) { + c_typed[i] = llaisys::utils::cast(llaisys::utils::cast(a_typed[i]) + llaisys::utils::cast(b_typed[i])); + } +} + +#define DISPATCH_SWIGLU(dtype, ctype) case dtype: add_(c, a, b, numel); break; + namespace llaisys::ops::cpu { void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { switch (type) { - case LLAISYS_DTYPE_F32: - return add_(reinterpret_cast(c), reinterpret_cast(a), reinterpret_cast(b), numel); - case LLAISYS_DTYPE_BF16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); - case LLAISYS_DTYPE_F16: - return add_(reinterpret_cast(c), reinterpret_cast(a), - reinterpret_cast(b), numel); + DISPATCH_SWIGLU(LLAISYS_DTYPE_F32, float) + DISPATCH_SWIGLU(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_F64, double) default: EXCEPTION_UNSUPPORTED_DATATYPE(type); } diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..ae6243cb --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,46 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void argmax_(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) { + int64_t *max_idx_typed = reinterpret_cast(max_idx); + T *max_val_typed = reinterpret_cast(max_val); + const T *vals_typed = reinterpret_cast(vals); + + *max_idx_typed = 0; + *max_val_typed = vals_typed[0]; + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + float temp = llaisys::utils::cast(vals_typed[i]); + if (temp > llaisys::utils::cast(*max_val_typed)) { + *max_idx_typed = static_cast(i); + *max_val_typed = llaisys::utils::cast(temp); + } + } else { + T temp = vals_typed[i]; + if (temp > *max_val_typed) { + *max_idx_typed = static_cast(i); + *max_val_typed = temp; + } + } + } +} + +#define DISPATCH_ARGMAX(dtype, ctype) case dtype: argmax_(max_idx, max_val, vals, numel); break; + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + switch (type) { + DISPATCH_ARGMAX(LLAISYS_DTYPE_F32, float) + DISPATCH_ARGMAX(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_ARGMAX(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_ARGMAX(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_ARGMAX(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..ea235143 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t vals_type, size_t size); +} \ No newline at end of file diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..4b89b92c 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,48 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + // Only support contiguous inputs for now. + ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(), "Argmax: all tensors must be contiguous."); + /* Deprecated + Tests use i64. + // // Data of max_idx should be of type size_t + // // size_t 应该是 U32 或者 U64 吧... + // ASSERT(max_idx->dtype() == LLAISYS_DTYPE_U32 || max_idx->dtype() == LLAISYS_DTYPE_U64, + // "Argmax: max_idx should in type U32 or U64"); + // Data of max_val and vals should be in same type + */ + + // 假设 idx_type 永远是 LLAISYS_DTYPE_I64 + ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax: idx_type must be LLAISYS_DTYPE_I64"); + + ASSERT(max_val->dtype() == vals->dtype(), "Argmax: max_val and vals should have be of same type"); + + ASSERT(vals->numel(), "Argmax: the tensor must be non-empty"); + + // always support cpu calculation + if (vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..97dcfa72 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,35 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void embedding_(std::byte *out_raw, const int64_t *index, const size_t i_size, + const std::byte *weight_raw, const size_t w_rows, const size_t w_cols) { + T * out = reinterpret_cast(out_raw); + const T * weight = reinterpret_cast(weight_raw); + for (size_t i = 0; i < i_size; i++) { + for (size_t j = 0; j < w_cols; j++) { + out[i * w_cols + j] = weight[index[i] * w_cols + j]; + } + } +} + +#define DISPATCH_EMBEDDING(dtype, ctype) case dtype: embedding_(out, reinterpret_cast(index), i_size, weight, w_rows, w_cols); break; + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const size_t i_size, + const std::byte *weight, const size_t w_rows, const size_t w_cols, + llaisysDataType_t type) { + switch (type) { + DISPATCH_EMBEDDING(LLAISYS_DTYPE_F32, float) + DISPATCH_EMBEDDING(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_EMBEDDING(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_EMBEDDING(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_EMBEDDING(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..65760bdd --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const size_t i_size, + const std::byte *weight, const size_t w_rows, const size_t w_cols, + llaisysDataType_t type); +} \ No newline at end of file diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..b9cd6c12 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,45 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" + namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + + CHECK_SAME_DTYPE(out->dtype(), weight->dtype()); + CHECK_SAME_DTYPE(index->dtype(), LLAISYS_DTYPE_I64); + + // 抱歉时间来不及了, 这个实现很不elegant. 维度完全没有拓展性了. + CHECK_SAME_SHAPE(out->shape()[0], index->shape()[0]); + CHECK_SAME_SHAPE(out->shape()[1], weight->shape()[1]); + + // only support contiguous for now + ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(), "Add: tensors out and index must be contiguous."); + + // always support cpu calculation + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), index->shape()[0], + weight->data(), weight->shape()[0], weight->shape()[1], + weight->dtype()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), index->shape()[0], + weight->data(), weight->shape()[0], weight->shape()[1], + weight->dtype()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..866e4926 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,99 @@ +#include "linear_cpu.hpp" +#include "openblas_capable_array.hpp" + +#include "../../../utils.hpp" + +#include + +#include + +template +void linear_(std::byte *out_raw, const std::byte *in_raw, const std::byte *weight_raw, + const size_t M, const size_t N, const size_t K, const std::byte *bias_raw, + llaisysDataType_t dtype) { + // C[:,M,N] = A[:,M,K] * B[:,K,N] + // out[:,M,N] = in[:,M,K] * weight[:,N,K]^T + bias[:,N] (broadcast) + T *out = reinterpret_cast(out_raw); + const T *in = reinterpret_cast(in_raw); + const T *weight = reinterpret_cast(weight_raw); + const T *bias = reinterpret_cast(bias_raw); + + // OpenBLAS only supports float/double. For other types we cast into float/double + // buffers, compute, then cast back. + const llaisysDataType_t storage_dtype = + (dtype == LLAISYS_DTYPE_F64 || dtype == LLAISYS_DTYPE_I32) ? LLAISYS_DTYPE_F64 : LLAISYS_DTYPE_F32; + + llaisys::ops::linear::cpu::OpenBlasCapableArray in_aligned(in, M * K, storage_dtype); + llaisys::ops::linear::cpu::OpenBlasCapableArray weight_aligned(weight, K * N, storage_dtype); + llaisys::ops::linear::cpu::OpenBlasCapableArray out_aligned(out, M * N, storage_dtype); + + if (out_aligned.dtype() == LLAISYS_DTYPE_F32) { + float *A = reinterpret_cast(in_aligned.data()); + float *B = reinterpret_cast(weight_aligned.data()); + float *C = reinterpret_cast(out_aligned.data()); + + if (bias != nullptr) { + out_aligned.broadcast_row(bias, M, N); + } else { + out_aligned.zeros(); + } + + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasTrans, + M, N, K, + 1.0f, + A, K, + B, K, + 1.0f, + C, N + ); + } else if (out_aligned.dtype() == LLAISYS_DTYPE_F64) { + double *A = reinterpret_cast(in_aligned.data()); + double *B = reinterpret_cast(weight_aligned.data()); + double *C = reinterpret_cast(out_aligned.data()); + + if (bias != nullptr) { + out_aligned.broadcast_row(bias, M, N); + } else { + out_aligned.zeros(); + } + + cblas_dgemm( + CblasRowMajor, + CblasNoTrans, + CblasTrans, + M, N, K, + 1.0, + A, K, + B, K, + 1.0, + C, N + ); + } else { + throw std::invalid_argument("Unsupported data type for linear_"); + } + + // Only cast back if we allocated an intermediate buffer. + if (out_aligned.owns_data()) { + out_aligned.cast_back(out); + } +} + +#define DISPATCH_LINEAR(dtype, ctype) case dtype: linear_(out, in, weight, M, N, K, bias, type); break; + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + const size_t M, const size_t N, const size_t K, llaisysDataType_t type) { + switch (type) { + DISPATCH_LINEAR(LLAISYS_DTYPE_F32, float) + DISPATCH_LINEAR(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_LINEAR(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_LINEAR(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_LINEAR(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..a53140c8 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + const size_t M, const size_t N, const size_t K, llaisysDataType_t type); +} \ No newline at end of file diff --git a/src/ops/linear/cpu/openblas_capable_array.hpp b/src/ops/linear/cpu/openblas_capable_array.hpp new file mode 100644 index 00000000..4c2a0114 --- /dev/null +++ b/src/ops/linear/cpu/openblas_capable_array.hpp @@ -0,0 +1,271 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +#include +#include +#include + +#ifdef __C + #define LLAISYS_EXTERN_C 1 + #pragma push_macro("__C") + #undef __C +#endif + +#include // 只在这个 cpp 内部使用 + +#ifdef LLAISYS_EXTERN_C + #undef LLAISYS_EXTERN_C + #pragma pop_macro("__C") +#endif + +#include + +#include "../../../utils.hpp" // for llaisys::utils::cast + fp16/bf16 helpers + +namespace llaisys::ops::linear::cpu { +// AVX2 + F16C vectorized conversions for OpenBLAS +class OpenBlasCapableArray { +public: + // Allocate an intermediate buffer suitable for OpenBLAS (F32/F64). + explicit OpenBlasCapableArray(size_t n, llaisysDataType_t dtype) + : numel_(n), data_(nullptr), dtype_(dtype), owns_data_(true) { + size_t elem_size = (dtype_ == LLAISYS_DTYPE_F64 ? sizeof(double) : sizeof(float)); + if (posix_memalign(&data_, 32, n * elem_size) != 0) { + throw std::bad_alloc(); + } + } + + // When the source is already float/double, we can use the caller's pointer directly + // (no allocation, no own/free). This avoids extra copies when OpenBLAS can work on + // the original buffer. + template + OpenBlasCapableArray(const T* src, size_t n, llaisysDataType_t dtype) + : numel_(n), data_(const_cast(static_cast(src))), dtype_(dtype), owns_data_(false) { + if constexpr (std::is_same_v) { + if (dtype_ != LLAISYS_DTYPE_F32) { + throw std::invalid_argument("dtype mismatch for float source"); + } + } else if constexpr (std::is_same_v) { + if (dtype_ != LLAISYS_DTYPE_F64) { + throw std::invalid_argument("dtype mismatch for double source"); + } + } else { + // For other types, we still need to allocate and cast. + // We do this by falling back to the “normal” construction path. + size_t elem_size = (dtype_ == LLAISYS_DTYPE_F64 ? sizeof(double) : sizeof(float)); + if (posix_memalign(&data_, 32, n * elem_size) != 0) { + throw std::bad_alloc(); + } + owns_data_ = true; + cast_from(src); + } + } + + // No copies / moves for simplicity + OpenBlasCapableArray(const OpenBlasCapableArray&) = delete; + OpenBlasCapableArray& operator=(const OpenBlasCapableArray&) = delete; + OpenBlasCapableArray(OpenBlasCapableArray&& other) = delete; + + ~OpenBlasCapableArray() { + if (owns_data_) { + std::free(data_); + } + } + + // Fill the internal buffer from user-provided data + template + void cast_from(const T* src) { + if (dtype_ == LLAISYS_DTYPE_F32) { + cast_helper(src, reinterpret_cast(data_), numel_); + } else if (dtype_ == LLAISYS_DTYPE_F64) { + cast_helper(src, reinterpret_cast(data_), numel_); + } else { + throw std::runtime_error("OpenBlasCapableArray only supports F32/F64 storage"); + } + } + + // Broadcast a single row `src` (length cols) into a `rows x cols` block. + template + void broadcast_row(const T* src, size_t rows, size_t cols) { + if (rows == 0 || cols == 0) return; + if (rows * cols != numel_) { + throw std::out_of_range("broadcast_row is not equal to buffer size"); + } + + if (dtype_ == LLAISYS_DTYPE_F32) { + float* base = reinterpret_cast(data_); + cast_helper(src, base, cols); +#ifdef _OPENMP + #pragma omp parallel for if(numel_ > 16384) +#endif + for (size_t r = 1; r < rows; ++r) { + std::memcpy(base + r * cols, base, cols * sizeof(float)); + } + } else { + double* base = reinterpret_cast(data_); + cast_helper(src, base, cols); +#ifdef _OPENMP + #pragma omp parallel for if(numel_ > 16384) +#endif + for (size_t r = 1; r < rows; ++r) { + std::memcpy(base + r * cols, base, cols * sizeof(double)); + } + } + } + + // Fill the internal buffer with zeros. + // Uses AVX stores + OpenMP for large buffers to maximize throughput. + void zeros() { + if (numel_ == 0) return; + + if (dtype_ == LLAISYS_DTYPE_F32) { + float* fdata = reinterpret_cast(data_); + const size_t last_block_start = (numel_ / 8) * 8; + __m256 zero = _mm256_setzero_ps(); + +#ifdef _OPENMP + #pragma omp parallel for if(numel_ > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 8) { + _mm256_store_ps(fdata + i, zero); + } + for (size_t i = last_block_start; i < numel_; ++i) { + fdata[i] = 0.0f; + } + + } else if (dtype_ == LLAISYS_DTYPE_F64) { + double* ddata = reinterpret_cast(data_); + const size_t last_block_start = (numel_ / 4) * 4; + __m256d zero = _mm256_setzero_pd(); + +#ifdef _OPENMP + #pragma omp parallel for if(numel_ > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 4) { + _mm256_store_pd(ddata + i, zero); + } + for (size_t i = last_block_start; i < numel_; ++i) { + ddata[i] = 0.0; + } + + } else { + // Fallback: memset works for all bitwise-zero types and is often optimized. + const size_t elem_size = (dtype_ == LLAISYS_DTYPE_F64 ? sizeof(double) : sizeof(float)); + std::memset(data_, 0, numel_ * elem_size); + } + } + + // Convert the internal buffer back to `T`. + template + void cast_back(T* dst) const { + if (dtype_ == LLAISYS_DTYPE_F32) { + const float* data_float = reinterpret_cast(data_); + if constexpr (std::is_same_v) { + uint16_t* dst_raw = reinterpret_cast(static_cast(dst)); + llaisys::utils::fp32_to_fp16_vec(data_float, dst_raw, numel_); + + } else if constexpr (std::is_same_v) { + uint16_t* dst_raw = reinterpret_cast(static_cast(dst)); + llaisys::utils::fp32_to_bf16_vec(data_float, dst_raw, numel_); + + } else { + // float or other types: fallback to scalar cast + for (size_t i = 0; i < numel_; i++) { + dst[i] = llaisys::utils::cast(data_float[i]); + } + } + + } else { + const double* data_double = reinterpret_cast(data_); + if constexpr (std::is_same_v) { + size_t last_block_start = numel_ - (numel_ % 8); + +#ifdef _OPENMP + #pragma omp parallel for if(numel_ > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 8) { + __m256d d1 = _mm256_load_pd(data_double + i); + __m256d d2 = _mm256_load_pd(data_double + i + 4); + __m128i i1 = _mm256_cvttpd_epi32(d1); + __m128i i2 = _mm256_cvttpd_epi32(d2); + __m256i combined = _mm256_set_m128i(i2, i1); + _mm256_storeu_si256((__m256i*)(dst + i), combined); + } + for (size_t i = last_block_start; i < numel_; i++) { + dst[i] = static_cast(data_double[i]); + } + } else { + for (size_t i = 0; i < numel_; i++) { + dst[i] = llaisys::utils::cast(data_double[i]); + } + } + } + } + + // Getters + size_t numel() const noexcept { return numel_; } + void* data() const noexcept { return data_; } + llaisysDataType_t dtype() const noexcept { return dtype_; } + bool owns_data() const noexcept { return owns_data_; } + +private: + template + static void cast_helper(const T* src, float* dst, size_t n) { + if constexpr (std::is_same_v) { + std::memcpy(dst, src, n * sizeof(float)); + } else if constexpr (std::is_same_v) { + const uint16_t* raw_src = reinterpret_cast(static_cast(src)); + llaisys::utils::fp16_to_fp32_vec(raw_src, dst, n); + + } else if constexpr (std::is_same_v) { + const uint16_t* raw_src = reinterpret_cast(static_cast(src)); + llaisys::utils::bf16_to_fp32_vec(raw_src, dst, n); + + } else { + for (size_t i = 0; i < n; i++) { + dst[i] = llaisys::utils::cast(src[i]); + } + } + } + + template + static void cast_helper(const T* src, double* dst, size_t n) { + if constexpr (std::is_same_v) { + std::memcpy(dst, src, n * sizeof(double)); + } else if constexpr (std::is_same_v) { + const size_t last_block_start = n - (n % 8); + +#ifdef _OPENMP + #pragma omp parallel for if(n > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 8) { + __m256i vi = _mm256_loadu_si256((__m256i*)(src + i)); + __m128i lo = _mm256_extracti128_si256(vi, 0); + __m128i hi = _mm256_extracti128_si256(vi, 1); + __m256d dlo = _mm256_cvtepi32_pd(lo); + __m256d dhi = _mm256_cvtepi32_pd(hi); + _mm256_store_pd(dst + i, dlo); + _mm256_store_pd(dst + i + 4, dhi); + } + for (size_t i = last_block_start; i < n; i++) { + dst[i] = llaisys::utils::cast(src[i]); + } + + } else { + for (size_t i = 0; i < n; i++) { + dst[i] = llaisys::utils::cast(src[i]); + } + } + } + +private: + size_t numel_; + void* data_; + llaisysDataType_t dtype_; + bool owns_data_ = true; +}; +} // namespace llaisys::ops::linear::cpu diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f865..104d8395 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,66 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/linear_cpu.hpp" + namespace llaisys::ops { -void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); +void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias = nullptr) { + // C[:,M,N] = A[:,M,K] * B[:,K,N] + // out[:,M,N] = in[:,M,K] * weight[:,N,K]^T + bias[:,M,N] + const size_t M = out->shape()[out->shape().size() - 2]; + const size_t N = out->shape()[out->shape().size() - 1]; + const size_t K = in->shape()[in->shape().size() - 1]; + + CHECK_SAME_DEVICE(out, in, weight); + + CHECK_SAME_SHAPE(M, in->shape()[in->shape().size() - 2]); + CHECK_SAME_SHAPE(N, weight->shape()[weight->shape().size() - 2]); + CHECK_SAME_SHAPE(K, weight->shape()[weight->shape().size() - 1]); + + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), + "Add: all tensors must be contiguous."); + + // bias 形状应该是 (N,), 不是(M, N,) + // if (bias != nullptr) { + // CHECK_SAME_DEVICE(out, bias); + // CHECK_SAME_SHAPE(M, bias->shape()[bias->shape().size() - 2]); + // CHECK_SAME_SHAPE(N, bias->shape()[bias->shape().size() - 1]); + // ASSERT(bias->isContiguous(), "Add: all tensors must be contiguous."); + // } + + if (bias != nullptr) { + CHECK_SAME_DEVICE(out, bias); + CHECK_SAME_SHAPE(N, bias->shape()[bias->shape().size() - 1]); + ASSERT(bias->isContiguous(), "Add: all tensors must be contiguous."); + } + + // always support cpu calculation + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + if (bias != nullptr) { + return cpu::linear(out->data(), in->data(), weight->data(), bias->data(), M, N, K, weight->dtype()); + } + return cpu::linear(out->data(), in->data(), weight->data(), nullptr, M, N, K, weight->dtype()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + if (bias != nullptr) { + return cpu::linear(out->data(), in->data(), weight->data(), bias->data(), M, N, K, weight->dtype()); + } + return cpu::linear(out->data(), in->data(), weight->data(), nullptr, M, N, K, weight->dtype()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } + + } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/cpu/rms_cpu.cpp b/src/ops/rms_norm/cpu/rms_cpu.cpp new file mode 100644 index 00000000..be291bcc --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_cpu.cpp @@ -0,0 +1,97 @@ +#include "rms_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void rms_norm_(std::byte *out_raw, const std::byte *in_raw, const std::byte *weight_raw, size_t w_size, size_t numel, float eps) { + T *out = reinterpret_cast(out_raw); + const T *in = reinterpret_cast(in_raw); + const T *weight = reinterpret_cast(weight_raw); + + /* 算子理解错了, 分母应该是行的平方和, 不是列的 + // 分母 + // naive 优化, 避免多次不连续列访问 + float *downs = (float*) calloc(w_size, sizeof(float)); + + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + downs[i % w_size] += powf(llaisys::utils::cast(in[numel]), 2); + } else { + downs[i % w_size] += powf(in[numel], 2); + } + } + + for (size_t i = 0; i < w_size; i++) { + downs[i] = sqrt(downs[i] + eps); + } + + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + out[i] = llaisys::utils::cast( + llaisys::utils::cast(weight[i % w_size]) + * llaisys::utils::cast(in[i]) + / downs[i % w_size] + ); + } else { + out[i] = weight[i % w_size] * in[i] / downs[i % w_size]; + } + } + */ + + // README 给的公式有问题!!!!! + // $Y_i = \frac{W_i \times X_i}{\sqrt{(\sum_{j=1}^n X_j^2) + \epsilon}}$ 不对 + // $Y_i = \frac{W_i \times X_i}{\sqrt{( \mathbf{\frac{1}{n}} \sum_{j=1}^n X_j^2) + \epsilon}}$ + // ^^^^^^^^^^^^^^^^^^ + size_t num_rows = numel / w_size; + for (size_t row = 0; row < num_rows; ++row) { + float sum_sq = 0.0f; + for (size_t col = 0; col < w_size; ++col) { + size_t idx = row * w_size + col; + float val; + if constexpr (std::is_same_v || std::is_same_v) { + val = llaisys::utils::cast(in[idx]); + } else { + val = static_cast(in[idx]); + } + sum_sq += val * val; + } + // 这里公式有问题 + float rms = sqrtf((sum_sq / static_cast(w_size)) + eps); + for (size_t col = 0; col < w_size; ++col) { + size_t idx = row * w_size + col; + float w_val; + float in_val; + if constexpr (std::is_same_v || std::is_same_v) { + w_val = llaisys::utils::cast(weight[col]); + in_val = llaisys::utils::cast(in[idx]); + } else { + w_val = static_cast(weight[col]); + in_val = static_cast(in[idx]); + } + float out_val = w_val * in_val / rms; + if constexpr (std::is_same_v || std::is_same_v) { + out[idx] = llaisys::utils::cast(out_val); + } else { + out[idx] = static_cast(out_val); + } + } + } +} + +#define DISPATCH_SELF_ATTENTION(dtype, ctype) case dtype: rms_norm_(out, in, weight, w_size, numel, eps); break; + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, size_t w_size, size_t numel, float eps) { + switch (type) { + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F32, float) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rms_norm/cpu/rms_cpu.hpp b/src/ops/rms_norm/cpu/rms_cpu.hpp new file mode 100644 index 00000000..8f2b4e9a --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, size_t w_size, size_t size, float eps); +} \ No newline at end of file diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9..7d72545c 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,36 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rms_cpu.hpp" + namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + // Only support contiguous inputs with same shape for now. + CHECK_SAME_SHAPE(out->shape(), in->shape()); + CHECK_SAME_SHAPE(out->shape()[1], weight->shape()[0]); + CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype()); + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "Add: all tensors must be contiguous."); + + // always support cpu calculation + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), weight->numel(), out->numel(), eps); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), weight->numel(), out->numel(), eps); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 00000000..d6b2dd04 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,55 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void rope_(std::byte *out_raw, const std::byte *in_raw, const std::byte *pos_ids_raw, size_t seqlen, size_t nhead, size_t d, float theta) { + T *out = reinterpret_cast(out_raw); + const T *in = reinterpret_cast(in_raw); + const int64_t *pos_ids = reinterpret_cast(pos_ids_raw); + + const size_t half_d = d / 2; + + for (size_t i = 0; i < seqlen; ++i) { + float p = static_cast(pos_ids[i]); + for (size_t h = 0; h < nhead; ++h) { + for (size_t j = 0; j < half_d; ++j) { + float phi = p / pow(theta, 2.0f * j / d); + float cos_phi = cos(phi); + float sin_phi = sin(phi); + + size_t idx_a = i * nhead * d + h * d + j; + size_t idx_b = i * nhead * d + h * d + j + half_d; + if constexpr (std::is_same_v || std::is_same_v) { + float a = llaisys::utils::cast(in[idx_a]); + float b = llaisys::utils::cast(in[idx_b]); + out[idx_a] = llaisys::utils::cast(a * cos_phi - b * sin_phi); + out[idx_b] = llaisys::utils::cast(b * cos_phi + a * sin_phi); + } else { + float a = static_cast(in[idx_a]); + float b = static_cast(in[idx_b]); + out[idx_a] = static_cast(a * cos_phi - b * sin_phi); + out[idx_b] = static_cast(b * cos_phi + a * sin_phi); + } + } + } + } +} + +#define DISPATCH_SELF_ATTENTION(dtype, ctype) case dtype: rope_(out, in, pos_ids, seqlen, nhead, d, theta); break; + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seqlen, size_t nhead, size_t d, float theta) { + switch (type) { + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F32, float) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..f97f6bbd --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seqlen, size_t nhead, size_t d, float theta); +} \ No newline at end of file diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..5d160d4f 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,44 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rope_cpu.hpp" + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(in, out, pos_ids); + // Only support contiguous inputs with same shape for now. + CHECK_SAME_SHAPE(in->shape(), out->shape()); + CHECK_SAME_SHAPE(in->shape()[0], pos_ids->shape()[0]); + // 假设 pos_ids 是 INT64 + CHECK_SAME_DTYPE(in->dtype(), out->dtype()); + CHECK_SAME_DTYPE(LLAISYS_DTYPE_I64, pos_ids->dtype()); + ASSERT(in->isContiguous() && out->isContiguous() && pos_ids->isContiguous(), "Rope: all tensors must be contiguous."); + + const size_t seqlen = in->shape()[0]; + const size_t nhead = in->shape()[1]; + const size_t d = in->shape()[2]; + + ASSERT(d%2==0, "Rope: tensor in is of size [seqlen, nhead, d], but d is not divided by 2"); + + // always support cpu calculation + if (in->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), in->dtype(), seqlen, nhead, d, theta); + } + + llaisys::core::context().setDevice(in->deviceType(), in->deviceId()); + + switch (in->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), in->dtype(), seqlen, nhead, d, theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..6f8ecc7a --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,133 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include +#include + +template +void self_attention_(std::byte *attn_val_raw, const std::byte *q_raw, const std::byte *k_raw, const std::byte *v_raw, const size_t seqlen, const size_t nhead, const size_t nkvhead, const size_t d, const size_t dv, const size_t token_len, float scale) { + T *out = reinterpret_cast(attn_val_raw); + const T *q = reinterpret_cast(q_raw); + const T *k = reinterpret_cast(k_raw); + const T *v = reinterpret_cast(v_raw); + + const size_t group_size = nhead / nkvhead; + const size_t total_len = token_len; + + // 1. 遍历序列中的每一个 Query Token + for (size_t i = 0; i < seqlen; ++i) { + + // 测试用例使用 tril(diagonal=S-L),其中 L=seqlen, S=token_len + // 允许的 key 位置满足 t <= i + (S - L) + int64_t mask_limit = static_cast(i) + + static_cast(total_len) + - static_cast(seqlen); + if (mask_limit >= static_cast(total_len)) { + mask_limit = static_cast(total_len) - 1; + } + + // 2. 遍历每一个 Attention Head + for (size_t h = 0; h < nhead; ++h) { + + size_t kv_h = h / group_size; + + // 临时存储 Scores + std::vector scores(total_len, 0.0f); + float max_score = -std::numeric_limits::infinity(); + + // --- Step 1: 计算 Q * K^T --- + for (size_t t = 0; t < total_len; ++t) { + // 因果掩码:只能看以前的 token + // 如果 t > mask_limit,说明 Key 的位置在 Query 之后,屏蔽掉 + if (static_cast(t) > mask_limit) { + scores[t] = -std::numeric_limits::infinity(); + continue; + } + + // 点积计算 + float dot = 0.0f; + const T* q_vec = q + (i * nhead * d) + (h * d); + const T* k_vec = k + (t * nkvhead * d) + (kv_h * d); + + for (size_t m = 0; m < d; ++m) { + float q_val, k_val; + if constexpr (std::is_same_v || std::is_same_v) { + q_val = llaisys::utils::cast(q_vec[m]); + k_val = llaisys::utils::cast(k_vec[m]); + } else { + q_val = static_cast(q_vec[m]); + k_val = static_cast(k_vec[m]); + } + dot += q_val * k_val; + } + + scores[t] = dot * scale; + if (scores[t] > max_score) { + max_score = scores[t]; + } + } + + // --- Step 2: Softmax --- + float exp_sum = 0.0f; + for (size_t t = 0; t < total_len; ++t) { + if (static_cast(t) > mask_limit) { + scores[t] = 0.0f; + } else { + float exp_val = std::exp(scores[t] - max_score); + scores[t] = exp_val; + exp_sum += exp_val; + } + } + float inv_exp_sum = 1.0f / (exp_sum + 1e-9f); + + // --- Step 3: 加权求和 (prob * V) --- + std::vector acc(dv, 0.0f); + for (size_t t = 0; t < total_len; ++t) { + if (scores[t] == 0.0f) continue; + + float prob = scores[t] * inv_exp_sum; + const T* v_vec = v + (t * nkvhead * dv) + (kv_h * dv); + + for (size_t m = 0; m < dv; ++m) { + float v_val; + if constexpr (std::is_same_v || std::is_same_v) { + v_val = llaisys::utils::cast(v_vec[m]); + } else { + v_val = static_cast(v_vec[m]); + } + acc[m] += prob * v_val; + } + } + + // --- Step 4: 写入输出 --- + T* out_vec = out + (i * nhead * dv) + (h * dv); + for (size_t m = 0; m < dv; ++m) { + if constexpr (std::is_same_v || std::is_same_v) { + out_vec[m] = llaisys::utils::cast(acc[m]); + } else { + out_vec[m] = static_cast(acc[m]); + } + } + } + } +} + +#define DISPATCH_SELF_ATTENTION(dtype, ctype) case dtype: self_attention_(attn_val, q, k, v, seqlen, nhead, nkvhead, d, dv, token_len, scale); break; + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t type, const size_t seqlen, const size_t nhead, const size_t nkvhead, const size_t d, const size_t dv, const size_t token_len, float scale){ + switch (type) { + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F32, float) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_SELF_ATTENTION(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..177b976a --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t type, const size_t seqlen, const size_t nhead, const size_t nkvhead, const size_t d, const size_t dv, const size_t token_len, float scale); +} \ No newline at end of file diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..64525951 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,48 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/self_attention_cpu.hpp" + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + // Only support contiguous inputs. + CHECK_SAME_SHAPE(attn_val->shape()[0], q->shape()[0]); + CHECK_SAME_SHAPE(attn_val->shape()[1], q->shape()[1]); + CHECK_SAME_SHAPE(attn_val->shape()[2], v->shape()[2]); + CHECK_SAME_SHAPE(k->shape()[0], v->shape()[0]); + CHECK_SAME_SHAPE(k->shape()[1], v->shape()[1]); + CHECK_SAME_SHAPE(q->shape()[2], k->shape()[2]); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), "Self-Attention: all tensors must be contiguous."); + + const size_t seqlen = attn_val->shape()[0]; + const size_t nhead = attn_val->shape()[1]; + // FIX: d should be the head dimension of Q/K, not the output (which is dv) + const size_t d = q->shape()[2]; + const size_t token_len = k->shape()[0]; + const size_t nkvhead = k->shape()[1]; + const size_t dv = v->shape()[2]; + + // always support cpu calculation + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), seqlen, nhead, nkvhead, d, dv, token_len, scale); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), seqlen, nhead, nkvhead, d, dv, token_len, scale); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..392225a7 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,35 @@ +#include "swiglu_cpu.hpp" + +#include "../../../utils.hpp" + +#include + +template +void swiglu_(std::byte *out_raw, const std::byte *gate_raw, const std::byte *up_raw, const size_t numel) { + T *out = reinterpret_cast(out_raw); + const T *gate = reinterpret_cast(gate_raw); + const T *up = reinterpret_cast(up_raw); + for (size_t i = 0; i < numel; i++) { + if constexpr (std::is_same_v || std::is_same_v) { + out[i] = llaisys::utils::cast(llaisys::utils::cast(up[i]) * llaisys::utils::cast(gate[i]) / (1 + std::exp(- llaisys::utils::cast(gate[i])))); + } else { + out[i] = llaisys::utils::cast(static_cast(up[i]) * static_cast(gate[i]) / (1 + std::exp(- static_cast(gate[i])))); + } + } +} + +#define DISPATCH_SWIGLU(dtype, ctype) case dtype: swiglu_(out, gate, up, numel); break; + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, const size_t numel) { + switch (type) { + DISPATCH_SWIGLU(LLAISYS_DTYPE_F32, float) + DISPATCH_SWIGLU(LLAISYS_DTYPE_BF16, llaisys::bf16_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_F16, llaisys::fp16_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_I32, int32_t) + DISPATCH_SWIGLU(LLAISYS_DTYPE_F64, double) + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..0ec69674 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, const size_t numel); +} \ No newline at end of file diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..092af106 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,35 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + // Only support contiguous inputs with same shape for now. + CHECK_SAME_SHAPE(out->shape(), gate->shape(), up->shape()); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), "Add: all tensors must be contiguous."); + + // always support cpu calculation + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..f5bcdbc7 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,27 +164,73 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t stride = 1; + size_t ndim = shape().size(); + for (size_t i = 1; i <= ndim; i++) { + if (static_cast(strides()[ndim - i]) != stride) + return false; + stride *= shape()[ndim - i]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim = _meta.shape.size(); + std::vector new_shape(ndim); + std::vector new_strides(ndim); + for (size_t i = 0; i < order.size(); i++){ + new_shape[i] = shape()[order[i]]; + new_strides[i] = strides()[order[i]]; + } + TensorMeta meta{_meta.dtype, new_shape, new_strides}; + return std::shared_ptr(new Tensor(meta, _storage)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + ASSERT(this->isContiguous(), "View: only contiguous"); + size_t ndim = shape.size(); + std::vector new_strides(ndim); + size_t stride = 1; + for (size_t i = 1; i <= ndim; i++) { + new_strides[ndim - i] = stride; + stride *= shape[ndim - i]; + } + ASSERT( + stride == std::accumulate( + shape.begin(), + shape.end(), + size_t(1), + std::multiplies()), + "view: invalid shape (number mismatched)" + ); + TensorMeta meta{_meta.dtype, shape, new_strides}; + return std::shared_ptr(new Tensor(meta, _storage)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if (dim >= ndim() || start > end || end > shape()[dim]) { + throw std::runtime_error("slice: invalid parameters"); + } + + std::vector new_shape = shape(); + std::vector new_strides = strides(); + new_shape[dim] = end - start; + + size_t new_offset = _offset + start * new_strides[dim] * this->elementSize(); + + TensorMeta new_meta{_meta.dtype, new_shape, new_strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + core::context().setDevice(this->deviceType(), this->deviceId()); + auto memcpy_type = (this->deviceType() == LLAISYS_DEVICE_CPU) ? LLAISYS_MEMCPY_H2H : LLAISYS_MEMCPY_H2D; + core::context().runtime().api()->memcpy_sync( + this->data(), + src_, + this->numel() * this->elementSize(), + memcpy_type + ); } tensor_t Tensor::contiguous() const { diff --git a/src/utils/types.cpp b/src/utils/types.cpp index 4163c214..c93146ee 100644 --- a/src/utils/types.cpp +++ b/src/utils/types.cpp @@ -1,8 +1,224 @@ #include "types.hpp" #include +#include + +#ifdef __C + #define LLAISYS_EXTERN_C 1 + #pragma push_macro("__C") + #undef __C +#endif + +#include // AVX2 for vectorized conversions + +#ifdef LLAISYS_EXTERN_C + #undef LLAISYS_EXTERN_C + #pragma pop_macro("__C") +#endif + +#include namespace llaisys::utils { + +size_t dsize(llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_BYTE: + return sizeof(char); + case LLAISYS_DTYPE_BOOL: + return sizeof(char); + case LLAISYS_DTYPE_I8: + return sizeof(int8_t); + case LLAISYS_DTYPE_I16: + return sizeof(int16_t); + case LLAISYS_DTYPE_I32: + return sizeof(int32_t); + case LLAISYS_DTYPE_I64: + return sizeof(int64_t); + case LLAISYS_DTYPE_U8: + return sizeof(uint8_t); + case LLAISYS_DTYPE_U16: + return sizeof(uint16_t); + case LLAISYS_DTYPE_U32: + return sizeof(uint32_t); + case LLAISYS_DTYPE_U64: + return sizeof(uint64_t); + case LLAISYS_DTYPE_F8: + return sizeof(f8_t); + case LLAISYS_DTYPE_F16: + return sizeof(fp16_t); + case LLAISYS_DTYPE_BF16: + return sizeof(bf16_t); + case LLAISYS_DTYPE_F32: + return sizeof(float); + case LLAISYS_DTYPE_F64: + return sizeof(double); + case LLAISYS_DTYPE_C16: + return sizeof(cp16_t); + case LLAISYS_DTYPE_C32: + return sizeof(cp32_t); + case LLAISYS_DTYPE_C64: + return sizeof(cp64_t); + case LLAISYS_DTYPE_C128: + return sizeof(cp128_t); + case LLAISYS_DTYPE_INVALID: + default: + throw std::invalid_argument("Unsupported or invalid data type."); + } +} + +const char *dtype_to_str(llaisysDataType_t dtype) { + switch (dtype) { + case LLAISYS_DTYPE_BYTE: + return "byte"; + case LLAISYS_DTYPE_BOOL: + return "bool"; + case LLAISYS_DTYPE_I8: + return "int8"; + case LLAISYS_DTYPE_I16: + return "int16"; + case LLAISYS_DTYPE_I32: + return "int32"; + case LLAISYS_DTYPE_I64: + return "int64"; + case LLAISYS_DTYPE_U8: + return "uint8"; + case LLAISYS_DTYPE_U16: + return "uint16"; + case LLAISYS_DTYPE_U32: + return "uint32"; + case LLAISYS_DTYPE_U64: + return "uint64"; + case LLAISYS_DTYPE_F8: + return "float8"; + case LLAISYS_DTYPE_F16: + return "float16"; + case LLAISYS_DTYPE_BF16: + return "bfloat16"; + case LLAISYS_DTYPE_F32: + return "float32"; + case LLAISYS_DTYPE_F64: + return "float64"; + case LLAISYS_DTYPE_C16: + return "complex16"; + case LLAISYS_DTYPE_C32: + return "complex32"; + case LLAISYS_DTYPE_C64: + return "complex64"; + case LLAISYS_DTYPE_C128: + return "complex128"; + case LLAISYS_DTYPE_INVALID: + default: + throw std::invalid_argument("Unsupported or invalid data type."); + } +} + +void fp16_to_fp32_vec(const uint16_t* src, float* dst, size_t n) { +#if defined(__AVX2__) && defined(__F16C__) + const size_t last_block_start = n - (n % 16); + +#ifdef _OPENMP + #pragma omp parallel for if(n > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 16) { + __m256i vfp16 = _mm256_loadu_si256((const __m256i*)(src + i)); + __m128i lo = _mm256_extracti128_si256(vfp16, 0); + __m128i hi = _mm256_extracti128_si256(vfp16, 1); + _mm256_store_ps(dst + i, _mm256_cvtph_ps(lo)); + _mm256_store_ps(dst + i + 8, _mm256_cvtph_ps(hi)); + } + for (size_t i = last_block_start; i < n; i++) { + dst[i] = _f16_to_f32(reinterpret_cast(src)[i]); + } +#else + for (size_t i = 0; i < n; i++) { + dst[i] = _f16_to_f32(reinterpret_cast(src)[i]); + } +#endif +} + +void bf16_to_fp32_vec(const uint16_t* src, float* dst, size_t n) { +#ifdef __AVX2__ + const size_t last_block_start = n - (n % 16); + +#ifdef _OPENMP + #pragma omp parallel for if(n > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 16) { + __m256i vbf16 = _mm256_loadu_si256((const __m256i*)(src + i)); + __m128i lo = _mm256_extracti128_si256(vbf16, 0); + __m128i hi = _mm256_extracti128_si256(vbf16, 1); + __m256i lo_32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(lo), 16); + __m256i hi_32 = _mm256_slli_epi32(_mm256_cvtepu16_epi32(hi), 16); + _mm256_store_ps(dst + i, _mm256_castsi256_ps(lo_32)); + _mm256_store_ps(dst + i + 8, _mm256_castsi256_ps(hi_32)); + } + for (size_t i = last_block_start; i < n; i++) { + dst[i] = _bf16_to_f32(reinterpret_cast(src)[i]); + } +#else + for (size_t i = 0; i < n; i++) { + dst[i] = _bf16_to_f32(reinterpret_cast(src)[i]); + } +#endif +} + +void fp32_to_fp16_vec(const float* src, uint16_t* dst, size_t n) { +#if defined(__AVX2__) && defined(__F16C__) + const size_t last_block_start = n - (n % 16); + +#ifdef _OPENMP + #pragma omp parallel for if(n > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 16) { + __m256 lo = _mm256_load_ps(src + i); + __m256 hi = _mm256_load_ps(src + i + 8); + __m128i lo16 = _mm256_cvtps_ph(lo, 0); + __m128i hi16 = _mm256_cvtps_ph(hi, 0); + __m256i combined = _mm256_set_m128i(hi16, lo16); + _mm256_storeu_si256((__m256i*)(dst + i), combined); + } + for (size_t i = last_block_start; i < n; i++) { + dst[i] = _f32_to_f16(src[i])._v; + } +#else + for (size_t i = 0; i < n; i++) { + dst[i] = _f32_to_f16(src[i])._v; + } +#endif +} + +void fp32_to_bf16_vec(const float* src, uint16_t* dst, size_t n) { +#ifdef __AVX__ + const size_t last_block_start = n - (n % 8); + +#ifdef _OPENMP + #pragma omp parallel for if(n > 16384) +#endif + for (size_t i = 0; i < last_block_start; i += 8) { + __m256 f = _mm256_load_ps(src + i); + __m256i as_int = _mm256_castps_si256(f); + __m256i bias = _mm256_set1_epi32(0x7FFF); + __m256i lsb = _mm256_and_si256(_mm256_srli_epi32(as_int, 16), _mm256_set1_epi32(1)); + __m256i rounding = _mm256_add_epi32(bias, lsb); + __m256i rounded = _mm256_add_epi32(as_int, rounding); + __m256i shifted = _mm256_srli_epi32(rounded, 16); + + __m128i lo = _mm256_extracti128_si256(shifted, 0); // 4 x uint32 + __m128i hi = _mm256_extracti128_si256(shifted, 1); // 4 x uint32 + + __m128i packed = _mm_packus_epi32(lo, hi); + _mm_storeu_si128((__m128i*)(dst + i), packed); + } + for (size_t i = last_block_start; i < n; i++) { + dst[i] = _f32_to_bf16(src[i])._v; + } +#else + for (size_t i = 0; i < n; i++) { + dst[i] = _f32_to_bf16(src[i])._v; + } +#endif +} + float _f16_to_f32(fp16_t val) { uint16_t h = val._v; uint32_t sign = (h & 0x8000) << 16; diff --git a/src/utils/types.hpp b/src/utils/types.hpp index e09619db..225bd3d2 100644 --- a/src/utils/types.hpp +++ b/src/utils/types.hpp @@ -1,7 +1,10 @@ +#pragma once + #include "llaisys.h" -#include -#include +#include +#include +#include namespace llaisys { struct CustomFloat16 { @@ -14,98 +17,39 @@ struct CustomBFloat16 { }; typedef struct CustomBFloat16 bf16_t; +struct CustomFloat8 { + uint8_t _v; +}; +typedef struct CustomFloat8 f8_t; + +struct CustomComplex16 { + fp16_t re; + fp16_t im; +}; +typedef struct CustomComplex16 cp16_t; + +struct CustomComplex32 { + fp16_t re; + fp16_t im; +}; +typedef struct CustomComplex32 cp32_t; + +struct CustomComplex64 { + float re; + float im; +}; +typedef struct CustomComplex64 cp64_t; + +struct CustomComplex128 { + double re; + double im; +}; +typedef struct CustomComplex128 cp128_t; + namespace utils { -inline size_t dsize(llaisysDataType_t dtype) { - switch (dtype) { - case LLAISYS_DTYPE_BYTE: - return sizeof(char); - case LLAISYS_DTYPE_BOOL: - return sizeof(char); - case LLAISYS_DTYPE_I8: - return sizeof(int8_t); - case LLAISYS_DTYPE_I16: - return sizeof(int16_t); - case LLAISYS_DTYPE_I32: - return sizeof(int32_t); - case LLAISYS_DTYPE_I64: - return sizeof(int64_t); - case LLAISYS_DTYPE_U8: - return sizeof(uint8_t); - case LLAISYS_DTYPE_U16: - return sizeof(uint16_t); - case LLAISYS_DTYPE_U32: - return sizeof(uint32_t); - case LLAISYS_DTYPE_U64: - return sizeof(uint64_t); - case LLAISYS_DTYPE_F8: - return 1; // usually 8-bit float (custom) - case LLAISYS_DTYPE_F16: - return 2; // 16-bit float - case LLAISYS_DTYPE_BF16: - return 2; // bfloat16 - case LLAISYS_DTYPE_F32: - return sizeof(float); - case LLAISYS_DTYPE_F64: - return sizeof(double); - case LLAISYS_DTYPE_C16: - return 2; // 2 bytes complex (not standard) - case LLAISYS_DTYPE_C32: - return 4; // 4 bytes complex - case LLAISYS_DTYPE_C64: - return 8; // 8 bytes complex - case LLAISYS_DTYPE_C128: - return 16; // 16 bytes complex - case LLAISYS_DTYPE_INVALID: - default: - throw std::invalid_argument("Unsupported or invalid data type."); - } -} -inline const char *dtype_to_str(llaisysDataType_t dtype) { - switch (dtype) { - case LLAISYS_DTYPE_BYTE: - return "byte"; - case LLAISYS_DTYPE_BOOL: - return "bool"; - case LLAISYS_DTYPE_I8: - return "int8"; - case LLAISYS_DTYPE_I16: - return "int16"; - case LLAISYS_DTYPE_I32: - return "int32"; - case LLAISYS_DTYPE_I64: - return "int64"; - case LLAISYS_DTYPE_U8: - return "uint8"; - case LLAISYS_DTYPE_U16: - return "uint16"; - case LLAISYS_DTYPE_U32: - return "uint32"; - case LLAISYS_DTYPE_U64: - return "uint64"; - case LLAISYS_DTYPE_F8: - return "float8"; - case LLAISYS_DTYPE_F16: - return "float16"; - case LLAISYS_DTYPE_BF16: - return "bfloat16"; - case LLAISYS_DTYPE_F32: - return "float32"; - case LLAISYS_DTYPE_F64: - return "float64"; - case LLAISYS_DTYPE_C16: - return "complex16"; - case LLAISYS_DTYPE_C32: - return "complex32"; - case LLAISYS_DTYPE_C64: - return "complex64"; - case LLAISYS_DTYPE_C128: - return "complex128"; - case LLAISYS_DTYPE_INVALID: - default: - throw std::invalid_argument("Unsupported or invalid data type."); - } -} +size_t dsize(llaisysDataType_t dtype); +const char *dtype_to_str(llaisysDataType_t dtype); float _f16_to_f32(fp16_t val); fp16_t _f32_to_f16(float val); @@ -113,6 +57,12 @@ fp16_t _f32_to_f16(float val); float _bf16_to_f32(bf16_t val); bf16_t _f32_to_bf16(float val); +// Vectorized conversions (AVX2 + F16C) for bulk casting +void fp16_to_fp32_vec(const uint16_t* src, float* dst, size_t n); +void bf16_to_fp32_vec(const uint16_t* src, float* dst, size_t n); +void fp32_to_fp16_vec(const float* src, uint16_t* dst, size_t n); +void fp32_to_bf16_vec(const float* src, uint16_t* dst, size_t n); + template TypeTo cast(TypeFrom val) { if constexpr (std::is_same::value) { diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..89f0edea 100644 --- a/xmake.lua +++ b/xmake.lua @@ -27,6 +27,15 @@ target("llaisys-utils") add_cxflags("-fPIC", "-Wno-unknown-pragmas") end + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxflags("-fopenmp") + add_cxflags("-mavx2", "-mf16c") + add_ldflags("-fopenmp") + else + add_cxxflags("/openmp") + end + add_files("src/utils/*.cpp") on_install(function (target) end) @@ -105,7 +114,9 @@ target("llaisys") set_languages("cxx17") set_warnings("all", "error") + add_files("src/llaisys/*.cc") + add_files("src/llaisys/models/*.cpp") set_installdir(".") diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 101d894e..e9c703c6 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -18,6 +18,14 @@ target("llaisys-ops-cpu") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxflags("-mavx2") -- AVX2 + add_cxflags("-mf16c") -- F16C + add_cxflags("-fopenmp") -- OpenMP + add_ldflags("-fopenmp") -- OpenMP runtime + add_shflags("-lgomp") + add_shflags("-lopenblas") -- OpenBlas + else + add_cxxflags("/openmp") end add_files("../src/ops/*/cpu/*.cpp")