diff --git a/.gitignore b/.gitignore index e38cf574..fe63887d 100644 --- a/.gitignore +++ b/.gitignore @@ -87,4 +87,8 @@ htmlcov/ # Windows Thumbs.db ehthumbs.db -desktop.ini \ No newline at end of file +desktop.ini + +# Models +/models/ +/checkpoints/ \ No newline at end of file diff --git a/include/llaisys.h b/include/llaisys.h index 73ca7eea..ca9f0318 100644 --- a/include/llaisys.h +++ b/include/llaisys.h @@ -24,6 +24,7 @@ typedef enum { LLAISYS_DEVICE_CPU = 0, //// TODO: Add more device types here. Numbers need to be consecutive. LLAISYS_DEVICE_NVIDIA = 1, + LLAISYS_DEVICE_METAX = 2, LLAISYS_DEVICE_TYPE_COUNT } llaisysDeviceType_t; diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d..c4dd10d9 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -38,5 +38,7 @@ __C { __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + __export void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model * model); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/include/llaisys/models/qwen2_tp.h b/include/llaisys/models/qwen2_tp.h new file mode 100644 index 00000000..ac1839a6 --- /dev/null +++ b/include/llaisys/models/qwen2_tp.h @@ -0,0 +1,37 @@ +#ifndef LLAISYS_MODELS_QWEN2_TP_H +#define LLAISYS_MODELS_QWEN2_TP_H + +#include "qwen2.h" + +__C { + // Tensor Parallel Qwen2 Model + struct LlaisysQwen2ModelTP; + + // Create a TP model with multiple devices + // device_ids: array of device IDs (e.g., [0, 1, 2, 3] for 4-GPU TP) + // ndevice: number of devices (TP world size) + __export struct LlaisysQwen2ModelTP *llaisysQwen2ModelTPCreate( + const struct LlaisysQwen2Meta *meta, + const int *device_ids, + int world_size); + + __export void llaisysQwen2ModelTPDestroy(struct LlaisysQwen2ModelTP *model); + + // Get weights for each rank + // Returns array of weight pointers, one for each rank + __export struct LlaisysQwen2Weights *llaisysQwen2ModelTPWeights( + struct LlaisysQwen2ModelTP *model, + int rank); + + __export int64_t llaisysQwen2ModelTPInfer( + struct LlaisysQwen2ModelTP *model, + const int64_t *token_ids, + size_t ntoken); + + __export void llaisysQwen2ModelTPResetCache(struct LlaisysQwen2ModelTP *model); + + // Get the number of ranks in the TP group + __export int llaisysQwen2ModelTPGetWorldSize(struct LlaisysQwen2ModelTP *model); +} + +#endif // LLAISYS_MODELS_QWEN2_TP_H diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb52..7738c75c 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,10 +12,12 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops +from .models.qwen2 import load_qwen2, LlaisysQwen2Meta, LlaisysQwen2Weights +from .models.qwen2_tp import load_qwen2_tp def load_shared_library(): - lib_dir = Path(__file__).parent + lib_dir = Path(__file__).parent.resolve() if sys.platform.startswith("linux"): libname = "libllaisys.so" @@ -26,9 +28,9 @@ def load_shared_library(): else: raise RuntimeError("Unsupported platform") - lib_path = os.path.join(lib_dir, libname) + lib_path = lib_dir / libname - if not os.path.isfile(lib_path): + if not lib_path.is_file(): raise FileNotFoundError(f"Shared library not found: {lib_path}") return ctypes.CDLL(str(lib_path)) @@ -38,6 +40,8 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_qwen2(LIB_LLAISYS) +load_qwen2_tp(LIB_LLAISYS) __all__ = [ @@ -52,4 +56,6 @@ def load_shared_library(): "llaisysMemcpyKind_t", "MemcpyKind", "llaisysStream_t", + "LlaisysQwen2Meta", + "LlaisysQwen2Weights", ] diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b467..cbe92132 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -6,7 +6,8 @@ class DeviceType(IntEnum): CPU = 0 NVIDIA = 1 - COUNT = 2 + METAX = 2 + COUNT = 3 llaisysDeviceType_t = ctypes.c_int diff --git a/python/llaisys/libllaisys/models/__init__.py b/python/llaisys/libllaisys/models/__init__.py new file mode 100644 index 00000000..777fde1e --- /dev/null +++ b/python/llaisys/libllaisys/models/__init__.py @@ -0,0 +1,3 @@ +from .qwen2 import load_qwen2, LlaisysQwen2Meta, LlaisysQwen2Weights + +__all__ = ["load_qwen2", "LlaisysQwen2Meta", "LlaisysQwen2Weights"] diff --git a/python/llaisys/libllaisys/models/qwen2.py b/python/llaisys/libllaisys/models/qwen2.py new file mode 100644 index 00000000..60c264a1 --- /dev/null +++ b/python/llaisys/libllaisys/models/qwen2.py @@ -0,0 +1,83 @@ +from ctypes import ( + Structure, + POINTER, + c_void_p, + c_size_t, + c_int, + c_int64, + c_float, +) +from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t +from ..tensor import llaisysTensor_t + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), # hidden_size + ("nh", c_size_t), # num_attention_heads + ("nkvh", c_size_t), # num_key_value_heads + ("dh", c_size_t), # head_dim + ("di", c_size_t), # intermediate_size + ("maxseq", c_size_t), # max_position_embeddings + ("voc", c_size_t), # vocab_size + ("epsilon", c_float), # rms_norm_eps + ("theta", c_float), # rope_theta + ("end_token", c_int64), # eos_token_id + ] + + +class LlaisysQwen2Weights(Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + + +# Model handle type +llaisysQwen2Model_t = c_void_p + + +def load_qwen2(lib): + # llaisysQwen2ModelCreate + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), # meta + llaisysDeviceType_t, # device + POINTER(c_int), # device_ids + c_int, # ndevice + ] + lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t + + # llaisysQwen2ModelDestroy + lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelDestroy.restype = None + + # llaisysQwen2ModelWeights + lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + # llaisysQwen2ModelInfer + lib.llaisysQwen2ModelInfer.argtypes = [ + llaisysQwen2Model_t, # model + POINTER(c_int64), # token_ids + c_size_t, # ntoken + ] + lib.llaisysQwen2ModelInfer.restype = c_int64 + + # llaisysQwen2ModelResetCache + lib.llaisysQwen2ModelResetCache.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelResetCache.restype = None diff --git a/python/llaisys/libllaisys/models/qwen2_tp.py b/python/llaisys/libllaisys/models/qwen2_tp.py new file mode 100644 index 00000000..2e239e39 --- /dev/null +++ b/python/llaisys/libllaisys/models/qwen2_tp.py @@ -0,0 +1,52 @@ +from ctypes import ( + Structure, + POINTER, + c_void_p, + c_size_t, + c_int, + c_int64, +) +from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t +from ..tensor import llaisysTensor_t +from .qwen2 import LlaisysQwen2Meta, LlaisysQwen2Weights + + +# TP Model handle type +llaisysQwen2ModelTP_t = c_void_p + + +def load_qwen2_tp(lib): + # llaisysQwen2ModelTPCreate + lib.llaisysQwen2ModelTPCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), # meta + POINTER(c_int), # device_ids + c_int, # world_size + ] + lib.llaisysQwen2ModelTPCreate.restype = llaisysQwen2ModelTP_t + + # llaisysQwen2ModelTPDestroy + lib.llaisysQwen2ModelTPDestroy.argtypes = [llaisysQwen2ModelTP_t] + lib.llaisysQwen2ModelTPDestroy.restype = None + + # llaisysQwen2ModelTPWeights + lib.llaisysQwen2ModelTPWeights.argtypes = [ + llaisysQwen2ModelTP_t, # model + c_int, # rank + ] + lib.llaisysQwen2ModelTPWeights.restype = POINTER(LlaisysQwen2Weights) + + # llaisysQwen2ModelTPInfer + lib.llaisysQwen2ModelTPInfer.argtypes = [ + llaisysQwen2ModelTP_t, # model + POINTER(c_int64), # token_ids + c_size_t, # ntoken + ] + lib.llaisysQwen2ModelTPInfer.restype = c_int64 + + # llaisysQwen2ModelTPResetCache + lib.llaisysQwen2ModelTPResetCache.argtypes = [llaisysQwen2ModelTP_t] + lib.llaisysQwen2ModelTPResetCache.restype = None + + # llaisysQwen2ModelTPGetWorldSize + lib.llaisysQwen2ModelTPGetWorldSize.argtypes = [llaisysQwen2ModelTP_t] + lib.llaisysQwen2ModelTPGetWorldSize.restype = c_int diff --git a/python/llaisys/models/__init__.py b/python/llaisys/models/__init__.py index af9918b0..d516b608 100644 --- a/python/llaisys/models/__init__.py +++ b/python/llaisys/models/__init__.py @@ -1 +1,4 @@ from .qwen2 import Qwen2 +from .qwen2_tp import Qwen2TP + +__all__ = ["Qwen2", "Qwen2TP"] diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..95ea936b 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,164 @@ from typing import Sequence from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType +from ..libllaisys import DeviceType, DataType +from ..libllaisys import LlaisysQwen2Meta, LlaisysQwen2Weights from pathlib import Path +from ctypes import c_int, c_int64, c_size_t, POINTER, byref +import json import safetensors class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor - model_path = Path(model_path) + # Load config + config_path = model_path / "config.json" + with open(config_path, "r") as f: + config = json.load(f) + + # Extract model parameters + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_key_value_heads = config["num_key_value_heads"] + head_dim = hidden_size // num_attention_heads + + # Create meta structure + # Use FP32 for better precision on MetaX (avoid bfloat16 accumulation errors) + self._meta = LlaisysQwen2Meta() + self._meta.dtype = DataType.F32.value if device == DeviceType.METAX else DataType.BF16.value + self._meta.nlayer = config["num_hidden_layers"] + self._meta.hs = hidden_size + self._meta.nh = num_attention_heads + self._meta.nkvh = num_key_value_heads + self._meta.dh = head_dim + self._meta.di = config["intermediate_size"] + self._meta.maxseq = min(config.get("max_position_embeddings", 131072), 4096) + self._meta.voc = config["vocab_size"] + self._meta.epsilon = config.get("rms_norm_eps", 1e-6) + self._meta.theta = config.get("rope_theta", 10000.0) + self._meta.end_token = config.get("eos_token_id", 151643) + + self._device = device + self._nlayer = self._meta.nlayer + + # Create model + device_ids = (c_int * 1)(0) + self._model = LIB_LLAISYS.llaisysQwen2ModelCreate( + byref(self._meta), + device.value, + device_ids, + 1 + ) + + # Get weights pointer + weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model) + self._weights = weights_ptr.contents + + # Load weights from safetensors + self._load_weights(model_path) + + def __del__(self): + if hasattr(self, "_model") and self._model: + LIB_LLAISYS.llaisysQwen2ModelDestroy(self._model) + self._model = None + + def _load_weights(self, model_path: Path): + """Load weights from safetensors files.""" + import torch + + # Collect all tensors from safetensors files + all_tensors = {} for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + data = safetensors.safe_open(file, framework="pt", device="cpu") + for name in data.keys(): + all_tensors[name] = data.get_tensor(name) + + # Helper to load a tensor + def load_tensor(tensor_handle, tensor_data): + # Convert to the target dtype + if self._device == DeviceType.METAX: + # Use FP32 for MetaX to avoid precision issues + data = tensor_data.to(torch.float32).contiguous() + else: + data = tensor_data.to(torch.bfloat16).contiguous() + data_ptr = data.data_ptr() + LIB_LLAISYS.tensorLoad(tensor_handle, data_ptr) + + # Load embedding weights + load_tensor(self._weights.in_embed, all_tensors["model.embed_tokens.weight"]) + load_tensor(self._weights.out_embed, all_tensors["lm_head.weight"]) + load_tensor(self._weights.out_norm_w, all_tensors["model.norm.weight"]) + + # Load per-layer weights + for i in range(self._nlayer): + prefix = f"model.layers.{i}." + + load_tensor(self._weights.attn_norm_w[i], + all_tensors[prefix + "input_layernorm.weight"]) + load_tensor(self._weights.attn_q_w[i], + all_tensors[prefix + "self_attn.q_proj.weight"]) + load_tensor(self._weights.attn_q_b[i], + all_tensors[prefix + "self_attn.q_proj.bias"]) + load_tensor(self._weights.attn_k_w[i], + all_tensors[prefix + "self_attn.k_proj.weight"]) + load_tensor(self._weights.attn_k_b[i], + all_tensors[prefix + "self_attn.k_proj.bias"]) + load_tensor(self._weights.attn_v_w[i], + all_tensors[prefix + "self_attn.v_proj.weight"]) + load_tensor(self._weights.attn_v_b[i], + all_tensors[prefix + "self_attn.v_proj.bias"]) + load_tensor(self._weights.attn_o_w[i], + all_tensors[prefix + "self_attn.o_proj.weight"]) + load_tensor(self._weights.mlp_norm_w[i], + all_tensors[prefix + "post_attention_layernorm.weight"]) + load_tensor(self._weights.mlp_gate_w[i], + all_tensors[prefix + "mlp.gate_proj.weight"]) + load_tensor(self._weights.mlp_up_w[i], + all_tensors[prefix + "mlp.up_proj.weight"]) + load_tensor(self._weights.mlp_down_w[i], + all_tensors[prefix + "mlp.down_proj.weight"]) def generate( self, inputs: Sequence[int], - max_new_tokens: int = None, + max_new_tokens: int = 128, top_k: int = 1, top_p: float = 0.8, temperature: float = 0.8, ): + # Reset KV cache for new generation + LIB_LLAISYS.llaisysQwen2ModelResetCache(self._model) + + # Convert input to ctypes array + input_len = len(inputs) + input_arr = (c_int64 * input_len)(*inputs) + + # Output tokens list (starts with input) + output_tokens = list(inputs) + + # First forward pass with all input tokens + next_token = LIB_LLAISYS.llaisysQwen2ModelInfer( + self._model, + input_arr, + c_size_t(input_len) + ) + output_tokens.append(next_token) + + # Generate remaining tokens one by one + for _ in range(max_new_tokens - 1): + if next_token == self._meta.end_token: + break - # TODO: Implement generate function + # Single token input + single_token = (c_int64 * 1)(next_token) + next_token = LIB_LLAISYS.llaisysQwen2ModelInfer( + self._model, + single_token, + c_size_t(1) + ) + output_tokens.append(next_token) - return [] + return output_tokens diff --git a/python/llaisys/models/qwen2_tp.py b/python/llaisys/models/qwen2_tp.py new file mode 100644 index 00000000..edac5cf4 --- /dev/null +++ b/python/llaisys/models/qwen2_tp.py @@ -0,0 +1,309 @@ +from typing import Sequence, List +from ..libllaisys import LIB_LLAISYS +from ..libllaisys import DeviceType, DataType +from ..libllaisys import LlaisysQwen2Meta + +from pathlib import Path +from ctypes import c_int, c_int64, c_size_t, POINTER, byref +import json +import safetensors +import torch +import numpy as np + + +class Qwen2TP: + """Tensor Parallel Qwen2 Model. + + This class distributes the model across multiple GPUs using tensor parallelism. + """ + + def __init__(self, model_path, device_ids: List[int], device: DeviceType = DeviceType.NVIDIA): + """ + Args: + model_path: Path to the model directory + device_ids: List of GPU device IDs to use for tensor parallelism + device: Device type (must be NVIDIA for tensor parallelism) + """ + if device != DeviceType.NVIDIA: + raise ValueError("Tensor parallelism is only supported on NVIDIA GPUs") + + if len(device_ids) < 2: + raise ValueError("Tensor parallelism requires at least 2 GPUs") + + model_path = Path(model_path) + self._device_ids = device_ids + self._world_size = len(device_ids) + self._device = device + + # Load config + config_path = model_path / "config.json" + with open(config_path, "r") as f: + config = json.load(f) + + # Extract model parameters + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_key_value_heads = config["num_key_value_heads"] + head_dim = hidden_size // num_attention_heads + intermediate_size = config["intermediate_size"] + + # Validate divisibility + if num_attention_heads % self._world_size != 0: + raise ValueError(f"num_attention_heads ({num_attention_heads}) must be divisible by world_size ({self._world_size})") + if num_key_value_heads % self._world_size != 0: + raise ValueError(f"num_key_value_heads ({num_key_value_heads}) must be divisible by world_size ({self._world_size})") + if intermediate_size % self._world_size != 0: + raise ValueError(f"intermediate_size ({intermediate_size}) must be divisible by world_size ({self._world_size})") + + # Create meta structure + self._meta = LlaisysQwen2Meta() + self._meta.dtype = DataType.BF16.value + self._meta.nlayer = config["num_hidden_layers"] + self._meta.hs = hidden_size + self._meta.nh = num_attention_heads + self._meta.nkvh = num_key_value_heads + self._meta.dh = head_dim + self._meta.di = intermediate_size + self._meta.maxseq = min(config.get("max_position_embeddings", 131072), 4096) + self._meta.voc = config["vocab_size"] + self._meta.epsilon = config.get("rms_norm_eps", 1e-6) + self._meta.theta = config.get("rope_theta", 10000.0) + self._meta.end_token = config.get("eos_token_id", 151643) + + self._nlayer = self._meta.nlayer + + # Create TP model + device_ids_arr = (c_int * self._world_size)(*device_ids) + self._model = LIB_LLAISYS.llaisysQwen2ModelTPCreate( + byref(self._meta), + device_ids_arr, + self._world_size + ) + + # Get weights for each rank + self._weights_per_rank = [] + for i in range(self._world_size): + weights_ptr = LIB_LLAISYS.llaisysQwen2ModelTPWeights(self._model, i) + self._weights_per_rank.append(weights_ptr.contents) + + # Load weights from safetensors + self._load_weights(model_path) + + # Flag for automatic warm-up on first generate call + self._needs_warmup = True + + def _warmup(self): + """Run a 1-token warm-up inference to initialize NCCL internal state. + + This prevents 'unhandled cuda error' on the first real inference. + The warm-up runs a minimal forward pass that triggers all NCCL collectives. + """ + # Run 1-token inference (token 1 is a safe choice) + warmup_arr = (c_int64 * 1)(1) + LIB_LLAISYS.llaisysQwen2ModelTPInfer( + self._model, + warmup_arr, + c_size_t(1) + ) + # Reset cache so warm-up doesn't affect subsequent inference + LIB_LLAISYS.llaisysQwen2ModelTPResetCache(self._model) + + def __del__(self): + if hasattr(self, "_model") and self._model: + LIB_LLAISYS.llaisysQwen2ModelTPDestroy(self._model) + self._model = None + + def _load_weights(self, model_path: Path): + """Load and shard weights from safetensors files.""" + + # Collect all tensors from safetensors files + all_tensors = {} + for file in sorted(model_path.glob("*.safetensors")): + data = safetensors.safe_open(file, framework="pt", device="cpu") + for name in data.keys(): + all_tensors[name] = data.get_tensor(name) + + nh = self._meta.nh + nkvh = self._meta.nkvh + dh = self._meta.dh + di = self._meta.di + hs = self._meta.hs + world_size = self._world_size + + # Sharded dimensions + nh_shard = nh // world_size + nkvh_shard = nkvh // world_size + di_shard = di // world_size + + # Helper to load shared tensor (replicated on all ranks) + def load_shared(tensor_handle, tensor_data, dtype=torch.bfloat16): + data = tensor_data.to(dtype).contiguous() + data_ptr = data.data_ptr() + LIB_LLAISYS.tensorLoad(tensor_handle, data_ptr) + + # Helper to column-shard a weight tensor + def load_column_sharded(weight_handles, tensor_data, output_dim_size, dtype=torch.bfloat16): + """Column-shard a weight tensor along the output dimension.""" + for rank in range(world_size): + start = rank * output_dim_size + end = start + output_dim_size + shard = tensor_data[start:end].to(dtype).contiguous() + LIB_LLAISYS.tensorLoad(weight_handles[rank], shard.data_ptr()) + + # Helper to row-shard a weight tensor + def load_row_sharded(weight_handles, tensor_data, input_dim_size, dtype=torch.bfloat16): + """Row-shard a weight tensor along the input dimension.""" + for rank in range(world_size): + start = rank * input_dim_size + end = start + input_dim_size + shard = tensor_data[:, start:end].to(dtype).contiguous() + LIB_LLAISYS.tensorLoad(weight_handles[rank], shard.data_ptr()) + + # Load embedding weights (shared) + for rank in range(world_size): + load_shared(self._weights_per_rank[rank].in_embed, + all_tensors["model.embed_tokens.weight"]) + load_shared(self._weights_per_rank[rank].out_embed, + all_tensors["lm_head.weight"]) + load_shared(self._weights_per_rank[rank].out_norm_w, + all_tensors["model.norm.weight"]) + + # Load per-layer weights + for i in range(self._nlayer): + prefix = f"model.layers.{i}." + + for rank in range(world_size): + w = self._weights_per_rank[rank] + + # Layer norm weights (shared) + load_shared(w.attn_norm_w[i], + all_tensors[prefix + "input_layernorm.weight"]) + load_shared(w.mlp_norm_w[i], + all_tensors[prefix + "post_attention_layernorm.weight"]) + + # Column-shard Q, K, V weights and biases + # attn_q_w: [nh * dh, hs] -> [nh_shard * dh, hs] per rank + load_column_sharded( + [w.attn_q_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.q_proj.weight"], + nh_shard * dh + ) + load_column_sharded( + [w.attn_q_b[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.q_proj.bias"], + nh_shard * dh + ) + + # attn_k_w: [nkvh * dh, hs] -> [nkvh_shard * dh, hs] per rank + load_column_sharded( + [w.attn_k_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.k_proj.weight"], + nkvh_shard * dh + ) + load_column_sharded( + [w.attn_k_b[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.k_proj.bias"], + nkvh_shard * dh + ) + + # attn_v_w: [nkvh * dh, hs] -> [nkvh_shard * dh, hs] per rank + load_column_sharded( + [w.attn_v_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.v_proj.weight"], + nkvh_shard * dh + ) + load_column_sharded( + [w.attn_v_b[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.v_proj.bias"], + nkvh_shard * dh + ) + + # Row-shard O projection weight + # attn_o_w: [hs, nh * dh] -> [hs, nh_shard * dh] per rank + load_row_sharded( + [w.attn_o_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "self_attn.o_proj.weight"], + nh_shard * dh + ) + + # Column-shard MLP gate and up weights + # mlp_gate_w: [di, hs] -> [di_shard, hs] per rank + load_column_sharded( + [w.mlp_gate_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "mlp.gate_proj.weight"], + di_shard + ) + load_column_sharded( + [w.mlp_up_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "mlp.up_proj.weight"], + di_shard + ) + + # Row-shard MLP down weight + # mlp_down_w: [hs, di] -> [hs, di_shard] per rank + load_row_sharded( + [w.mlp_down_w[i] for w in self._weights_per_rank], + all_tensors[prefix + "mlp.down_proj.weight"], + di_shard + ) + + def generate( + self, + inputs: Sequence[int], + max_new_tokens: int = 128, + top_k: int = 1, + top_p: float = 0.8, + temperature: float = 0.8, + ): + """Generate tokens using tensor parallel inference.""" + # Note: Currently only supports greedy decoding (top_k=1) + # Temperature and top_p are ignored for simplicity + + # Automatic warm-up on first call to initialize NCCL state + if self._needs_warmup: + self._warmup() + self._needs_warmup = False + + # Reset KV cache for new generation + LIB_LLAISYS.llaisysQwen2ModelTPResetCache(self._model) + + # Convert input to ctypes array + input_len = len(inputs) + input_arr = (c_int64 * input_len)(*inputs) + + # Output tokens list (starts with input) + output_tokens = list(inputs) + + # First forward pass with all input tokens + next_token = LIB_LLAISYS.llaisysQwen2ModelTPInfer( + self._model, + input_arr, + c_size_t(input_len) + ) + output_tokens.append(next_token) + + # Generate remaining tokens one by one + for _ in range(max_new_tokens - 1): + if next_token == self._meta.end_token: + break + + # Single token input + single_token = (c_int64 * 1)(next_token) + next_token = LIB_LLAISYS.llaisysQwen2ModelTPInfer( + self._model, + single_token, + c_size_t(1) + ) + output_tokens.append(next_token) + + return output_tokens + + @property + def world_size(self) -> int: + """Get the tensor parallel world size.""" + return self._world_size + + @property + def device_ids(self) -> List[int]: + """Get the list of GPU device IDs.""" + return self._device_ids.copy() diff --git a/src/device/metax/cuda_utils.cuh b/src/device/metax/cuda_utils.cuh new file mode 100644 index 00000000..d948c874 --- /dev/null +++ b/src/device/metax/cuda_utils.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +// Define the custom types directly in CUDA code +// These must match the definitions in utils/types.hpp +struct CustomFloat16 { + uint16_t _v; +}; +typedef struct CustomFloat16 fp16_t_metax; + +struct CustomBFloat16 { + uint16_t _v; +}; +typedef struct CustomBFloat16 bf16_t_metax; + +namespace llaisys::ops::metax { + +// Device functions for type conversion from custom types to float +__device__ inline float to_float_metax(float val) { return val; } + +__device__ inline float to_float_metax(fp16_t_metax val) { + __half h = __ushort_as_half(val._v); + return __half2float(h); +} + +__device__ inline float to_float_metax(bf16_t_metax val) { + uint32_t u = static_cast(val._v) << 16; + return __uint_as_float(u); +} + +// Device functions for type conversion from float to custom types +__device__ inline float from_float_metax(float val, float*) { return val; } + +__device__ inline fp16_t_metax from_float_metax(float val, fp16_t_metax*) { + __half h = __float2half(val); + fp16_t_metax result; + result._v = __half_as_ushort(h); + return result; +} + +__device__ inline bf16_t_metax from_float_metax(float val, bf16_t_metax*) { + uint32_t u = __float_as_uint(val); + bf16_t_metax result; + result._v = static_cast(u >> 16); + return result; +} + +// Template helper for converting to any type from float +template +__device__ inline T from_float_metax(float val) { + return from_float_metax(val, static_cast(nullptr)); +} + +} // namespace llaisys::ops::metax diff --git a/src/device/metax/metax_resource.cu b/src/device/metax/metax_resource.cu new file mode 100644 index 00000000..6be65739 --- /dev/null +++ b/src/device/metax/metax_resource.cu @@ -0,0 +1,27 @@ +#include "metax_resource.cuh" + +#include + +namespace llaisys::device::metax { + +// McblasHandle implementation +mcblasHandle_t& McblasHandle::get() { + static McblasHandle instance; + return instance.handle_; +} + +McblasHandle::McblasHandle() { + mcblasCreate(&handle_); +} + +McblasHandle::~McblasHandle() { + if (handle_) { + mcblasDestroy(handle_); + } +} + +Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_METAX, device_id) {} + +Resource::~Resource() = default; + +} // namespace llaisys::device::metax diff --git a/src/device/metax/metax_resource.cuh b/src/device/metax/metax_resource.cuh new file mode 100644 index 00000000..e1c1478b --- /dev/null +++ b/src/device/metax/metax_resource.cuh @@ -0,0 +1,30 @@ +#pragma once + +#include "../device_resource.hpp" + +// Forward declaration for mcblas types +typedef struct mcblasContext* mcblasHandle_t; + +namespace llaisys::device::metax { + +// Singleton mcblas handle for lazy initialization +class McblasHandle { +public: + static mcblasHandle_t& get(); + +private: + McblasHandle(); + ~McblasHandle(); + mcblasHandle_t handle_; + + McblasHandle(const McblasHandle&) = delete; + McblasHandle& operator=(const McblasHandle&) = delete; +}; + +class Resource : public llaisys::device::DeviceResource { +public: + Resource(int device_id); + ~Resource(); +}; + +} // namespace llaisys::device::metax diff --git a/src/device/metax/metax_runtime_api.cu b/src/device/metax/metax_runtime_api.cu new file mode 100644 index 00000000..dba49bdf --- /dev/null +++ b/src/device/metax/metax_runtime_api.cu @@ -0,0 +1,173 @@ +#include "../runtime_api.hpp" + +#include +#include +#include +#include + +namespace llaisys::device::metax { + +namespace runtime_api { + +constexpr size_t ALIGNMENT = 64; + +cudaMemcpyKind convertMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + return cudaMemcpyHostToHost; + } +} + +int getDeviceCount() { + int count = 0; + cudaGetDeviceCount(&count); + return count; +} + +void setDevice(int device_id) { + cudaSetDevice(device_id); +} + +void deviceSynchronize() { + cudaDeviceSynchronize(); +} + +llaisysStream_t createStream() { + cudaStream_t stream; + cudaStreamCreate(&stream); + return (llaisysStream_t)stream; +} + +void destroyStream(llaisysStream_t stream) { + cudaStreamDestroy((cudaStream_t)stream); +} + +void streamSynchronize(llaisysStream_t stream) { + cudaStreamSynchronize((cudaStream_t)stream); +} + +void *mallocDevice(size_t size) { + void *ptr = nullptr; + size_t aligned_size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + cudaMalloc(&ptr, aligned_size > 0 ? aligned_size : ALIGNMENT); + return ptr; +} + +void freeDevice(void *ptr) { + cudaFree(ptr); +} + +void *mallocHost(size_t size) { + void *ptr = nullptr; + size_t aligned_size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + posix_memalign(&ptr, ALIGNMENT, aligned_size > 0 ? aligned_size : ALIGNMENT); + return ptr; +} + +void freeHost(void *ptr) { + free(ptr); +} + +void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { + bool dst_aligned = (reinterpret_cast(dst) % ALIGNMENT) == 0; + bool src_aligned = (reinterpret_cast(src) % ALIGNMENT) == 0; + bool size_aligned = (size % ALIGNMENT) == 0; + + // For MetaX, we need to handle misaligned D2D copies specially + // because mcMemcpy requires both source and destination to be aligned + if (kind == LLAISYS_MEMCPY_D2D && (!dst_aligned || !src_aligned || !size_aligned)) { + // Use a kernel to copy data element by element to avoid alignment issues + // This is slower but works for misaligned addresses + void *temp_buf = nullptr; + size_t aligned_size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + cudaMalloc(&temp_buf, aligned_size); + + if (temp_buf) { + // D2H to temp buffer (temp_buf is aligned from cudaMalloc) + cudaError_t err = cudaMemcpy(temp_buf, src, size, cudaMemcpyDeviceToHost); + if (err == cudaSuccess) { + // H2D from temp buffer to destination + err = cudaMemcpy(dst, temp_buf, size, cudaMemcpyHostToDevice); + } + if (err != cudaSuccess) { + fprintf(stderr, "[MetaX] D2D copy via temp buffer failed: %s\n", cudaGetErrorString(err)); + } + cudaFree(temp_buf); + } + return; + } + + if (dst_aligned && src_aligned && size_aligned) { + cudaError_t err = cudaMemcpy(dst, src, size, convertMemcpyKind(kind)); + if (err != cudaSuccess) { + fprintf(stderr, "[MetaX] cudaMemcpy failed: %s (dst=%p, src=%p, size=%zu)\n", + cudaGetErrorString(err), dst, src, size); + } + } else { + void *aligned_buf = nullptr; + size_t aligned_size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + posix_memalign(&aligned_buf, ALIGNMENT, aligned_size); + + cudaError_t err = cudaSuccess; + switch (kind) { + case LLAISYS_MEMCPY_H2D: + std::memcpy(aligned_buf, src, size); + err = cudaMemcpy(dst, aligned_buf, size, cudaMemcpyHostToDevice); + break; + case LLAISYS_MEMCPY_D2H: + err = cudaMemcpy(aligned_buf, src, size, cudaMemcpyDeviceToHost); + std::memcpy(dst, aligned_buf, size); + break; + case LLAISYS_MEMCPY_D2D: + // This should not happen due to the check above, but handle it anyway + err = cudaMemcpy(aligned_buf, src, size, cudaMemcpyDeviceToHost); + if (err == cudaSuccess) { + err = cudaMemcpy(dst, aligned_buf, size, cudaMemcpyHostToDevice); + } + break; + default: + std::memcpy(dst, src, size); + break; + } + + if (err != cudaSuccess) { + fprintf(stderr, "[MetaX] cudaMemcpy (fallback) failed: %s (dst=%p, src=%p, size=%zu)\n", + cudaGetErrorString(err), dst, src, size); + } + + free(aligned_buf); + } +} + +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaMemcpyAsync(dst, src, size, convertMemcpyKind(kind), (cudaStream_t)stream); +} + +static const LlaisysRuntimeAPI RUNTIME_API = { + &getDeviceCount, + &setDevice, + &deviceSynchronize, + &createStream, + &destroyStream, + &streamSynchronize, + &mallocDevice, + &freeDevice, + &mallocHost, + &freeHost, + &memcpySync, + &memcpyAsync}; + +} // namespace runtime_api + +const LlaisysRuntimeAPI *getRuntimeAPI() { + return &runtime_api::RUNTIME_API; +} +} // namespace llaisys::device::metax diff --git a/src/device/nvidia/cuda_utils.cuh b/src/device/nvidia/cuda_utils.cuh new file mode 100644 index 00000000..cf849ca5 --- /dev/null +++ b/src/device/nvidia/cuda_utils.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +// Define the custom types directly in CUDA code +// These must match the definitions in utils/types.hpp +struct CustomFloat16 { + uint16_t _v; +}; +typedef struct CustomFloat16 fp16_t_cuda; + +struct CustomBFloat16 { + uint16_t _v; +}; +typedef struct CustomBFloat16 bf16_t_cuda; + +namespace llaisys::ops::nvidia { + +// Device functions for type conversion from custom types to float +__device__ inline float to_float_cuda(float val) { return val; } + +__device__ inline float to_float_cuda(fp16_t_cuda val) { + __half h = __ushort_as_half(val._v); + return __half2float(h); +} + +__device__ inline float to_float_cuda(bf16_t_cuda val) { + uint32_t u = static_cast(val._v) << 16; + return __uint_as_float(u); +} + +// Device functions for type conversion from float to custom types +__device__ inline float from_float_cuda(float val, float*) { return val; } + +__device__ inline fp16_t_cuda from_float_cuda(float val, fp16_t_cuda*) { + __half h = __float2half(val); + fp16_t_cuda result; + result._v = __half_as_ushort(h); + return result; +} + +__device__ inline bf16_t_cuda from_float_cuda(float val, bf16_t_cuda*) { + uint32_t u = __float_as_uint(val); + bf16_t_cuda result; + result._v = static_cast(u >> 16); + return result; +} + +// Template helper for converting to any type from float +template +__device__ inline T from_float_cuda(float val) { + return from_float_cuda(val, static_cast(nullptr)); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/device/nvidia/nccl_communicator.hpp b/src/device/nvidia/nccl_communicator.hpp new file mode 100644 index 00000000..6b74c5a6 --- /dev/null +++ b/src/device/nvidia/nccl_communicator.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include +#include + +// Forward declare NCCL types +typedef struct ncclComm* ncclComm_t; + +// Forward declare CUDA stream type (it's a pointer type in CUDA) +// We use a different name to avoid conflicts +typedef void* LlaisysCudaStream_t; + +namespace llaisys { +namespace device { +namespace nvidia { + +// NCCL data type enum (matches NCCL definitions) +enum class NCCLDataType { + Float32 = 0, + Float64 = 1, + Float16 = 2, + Bfloat16 = 3, + Int8 = 4, + Int32 = 5, + Int64 = 6, + Uint8 = 7, + Uint32 = 8, + Uint64 = 9 +}; + +// NCCL Communicator for tensor parallel operations +// This class manages NCCL communication across multiple GPUs +class NCCLCommunicator { +public: + // Create communicators for a group of GPUs (one per rank) + // This should be called once to create all communicators + static std::vector> createAll( + const std::vector& device_ids); + + ~NCCLCommunicator(); + + // Disable copy + NCCLCommunicator(const NCCLCommunicator&) = delete; + NCCLCommunicator& operator=(const NCCLCommunicator&) = delete; + + // Getters + int world_size() const { return world_size_; } + int rank() const { return rank_; } + int device_id() const { return device_id_; } + const std::vector& device_ids() const { return device_ids_; } + + // Get internal handles (for CUDA code) + ncclComm_t comm() const; + LlaisysCudaStream_t stream() const; + + // Set device for current thread + void setDevice() const; + + // All-Reduce: sum operation + void allReduce(void* buff, size_t count, int dtype); + void allReduce(const void* sendbuff, void* recvbuff, size_t count, int dtype); + + // All-Reduce async + void allReduceAsync(void* buff, size_t count, int dtype); + void allReduceAsync(const void* sendbuff, void* recvbuff, size_t count, int dtype); + + // Synchronize stream + void streamSynchronize(); + + // Synchronize all devices in the group + static void synchronizeAll(const std::vector>& comms); + +public: + // Constructor for direct creation (used by worker threads) + NCCLCommunicator(int rank, int world_size, const std::vector& device_ids, + ncclComm_t comm, LlaisysCudaStream_t stream); + +private: + + // Allow implementation to access private members + class Impl; + friend class Impl; + friend class NCCLCommunicatorImpl; + + std::vector device_ids_; + int world_size_; + int rank_; + int device_id_; + std::unique_ptr impl_; +}; + +// Helper to convert data type to NCCL type +int toNCCLDataType(int dtype); + +// Get element size for NCCL data type +size_t getNCCLDataTypeSize(int dtype); + +} // namespace nvidia +} // namespace device +} // namespace llaisys diff --git a/src/device/nvidia/nccl_communicator_impl.cu b/src/device/nvidia/nccl_communicator_impl.cu new file mode 100644 index 00000000..b5939ad4 --- /dev/null +++ b/src/device/nvidia/nccl_communicator_impl.cu @@ -0,0 +1,223 @@ +#include "nccl_communicator.hpp" +#include "nccl_communicator_impl.cuh" +#include +#include +#include +#include +#include +#include +#include + +namespace llaisys { +namespace device { +namespace nvidia { + +// Impl structure definition +struct NCCLCommunicator::Impl { + ncclComm_t comm; + LlaisysCudaStream_t stream; +}; + +// Simple barrier implementation for C++17 +class SimpleBarrier { +public: + explicit SimpleBarrier(int count) : threshold_(count), count_(count), generation_(0) {} + + void arrive_and_wait() { + std::unique_lock lock(mutex_); + int gen = generation_; + if (--count_ == 0) { + generation_++; + count_ = threshold_; + cv_.notify_all(); + } else { + cv_.wait(lock, [this, gen] { return gen != generation_; }); + } + } + +private: + std::mutex mutex_; + std::condition_variable cv_; + int threshold_; + int count_; + int generation_; +}; + +// Static method to create all communicators +std::vector> NCCLCommunicator::createAll( + const std::vector& device_ids) { + + int world_size = static_cast(device_ids.size()); + if (world_size == 0) { + throw std::invalid_argument("Device IDs cannot be empty"); + } + + // Initialize CUDA driver + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + // Initialize CUDA context on first device + CUDA_CHECK(cudaSetDevice(device_ids[0])); + // Force driver initialization + cudaFree(0); + CUDA_CHECK(cudaDeviceSynchronize()); + + ncclUniqueId id; + NCCL_CHECK(ncclGetUniqueId(&id)); + + std::vector> communicators(world_size); + std::vector threads; + std::mutex comm_mutex; + std::exception_ptr init_exception; + SimpleBarrier sync_barrier(world_size); + + for (int rank = 0; rank < world_size; ++rank) { + threads.emplace_back([&, rank]() { + try { + int device_id = device_ids[rank]; + CUDA_CHECK(cudaSetDevice(device_id)); + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + ncclUniqueId local_id; + std::memcpy(&local_id, &id, sizeof(ncclUniqueId)); + sync_barrier.arrive_and_wait(); + + ncclComm_t comm; + NCCL_CHECK(ncclCommInitRank(&comm, world_size, local_id, rank)); + + auto comm_obj = std::shared_ptr( + new NCCLCommunicator(rank, world_size, device_ids, comm, + reinterpret_cast(stream))); + + { + std::lock_guard lock(comm_mutex); + communicators[rank] = comm_obj; + } + } catch (...) { + std::lock_guard lock(comm_mutex); + if (!init_exception) { + init_exception = std::current_exception(); + } + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + if (init_exception) { + std::rethrow_exception(init_exception); + } + + return communicators; +} + +NCCLCommunicator::NCCLCommunicator(int rank, int world_size, + const std::vector& device_ids, + ncclComm_t comm, LlaisysCudaStream_t stream) + : device_ids_(device_ids), world_size_(world_size), rank_(rank), + device_id_(device_ids[rank]) { + impl_ = std::make_unique(); + impl_->comm = comm; + impl_->stream = stream; +} + +NCCLCommunicator::~NCCLCommunicator() { + if (impl_) { + if (impl_->stream) { + cudaStreamDestroy(reinterpret_cast(impl_->stream)); + } + if (impl_->comm) { + ncclCommDestroy(impl_->comm); + } + } +} + +void NCCLCommunicator::setDevice() const { + CUDA_CHECK(cudaSetDevice(device_id_)); +} + +void NCCLCommunicator::allReduce(void* buff, size_t count, int dtype) { + CUDA_CHECK(cudaSetDevice(device_id_)); + ncclDataType_t nccl_dtype = static_cast(toNCCLDataType(dtype)); + cudaStream_t stream = reinterpret_cast(impl_->stream); + NCCL_CHECK(ncclAllReduce(buff, buff, count, nccl_dtype, ncclSum, impl_->comm, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +void NCCLCommunicator::allReduce(const void* sendbuff, void* recvbuff, + size_t count, int dtype) { + ncclDataType_t nccl_dtype = static_cast(toNCCLDataType(dtype)); + cudaStream_t stream = reinterpret_cast(impl_->stream); + NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, count, nccl_dtype, ncclSum, impl_->comm, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +void NCCLCommunicator::allReduceAsync(void* buff, size_t count, int dtype) { + ncclDataType_t nccl_dtype = static_cast(toNCCLDataType(dtype)); + cudaStream_t stream = reinterpret_cast(impl_->stream); + NCCL_CHECK(ncclAllReduce(buff, buff, count, nccl_dtype, ncclSum, impl_->comm, stream)); +} + +void NCCLCommunicator::allReduceAsync(const void* sendbuff, void* recvbuff, + size_t count, int dtype) { + ncclDataType_t nccl_dtype = static_cast(toNCCLDataType(dtype)); + cudaStream_t stream = reinterpret_cast(impl_->stream); + NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, count, nccl_dtype, ncclSum, impl_->comm, stream)); +} + +void NCCLCommunicator::streamSynchronize() { + CUDA_CHECK(cudaStreamSynchronize(reinterpret_cast(impl_->stream))); +} + +void NCCLCommunicator::synchronizeAll( + const std::vector>& comms) { + for (const auto& comm : comms) { + if (comm) { + CUDA_CHECK(cudaSetDevice(comm->device_id())); + CUDA_CHECK(cudaDeviceSynchronize()); + } + } +} + +// Map llaisys DataType to NCCL data type +// llaisys DataType enum (from python/llaisys/libllaisys/llaisys_types.py): +// F32 = 13, F16 = 12, BF16 = 19, I32 = 5, I64 = 6 +int toNCCLDataType(int dtype) { + switch (dtype) { + case 13: return ncclFloat32; // F32 + case 12: return ncclFloat16; // F16 + case 19: return ncclBfloat16; // BF16 + case 5: return ncclInt32; // I32 + case 6: return ncclInt64; // I64 + default: throw std::invalid_argument("Unsupported data type for NCCL: " + std::to_string(dtype)); + } +} + +size_t getNCCLDataTypeSize(int dtype) { + ncclDataType_t nccl_dtype = static_cast(dtype); + switch (nccl_dtype) { + case ncclFloat32: + case ncclInt32: + return 4; + case ncclFloat16: + case ncclBfloat16: + return 2; + case ncclInt64: + case ncclUint64: + case ncclFloat64: + return 8; + case ncclInt8: + case ncclUint8: + return 1; + default: + throw std::invalid_argument("Unknown NCCL data type"); + } +} + +} // namespace nvidia +} // namespace device +} // namespace llaisys diff --git a/src/device/nvidia/nccl_communicator_impl.cuh b/src/device/nvidia/nccl_communicator_impl.cuh new file mode 100644 index 00000000..2dfdef66 --- /dev/null +++ b/src/device/nvidia/nccl_communicator_impl.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include "nccl_communicator.hpp" +#include +#include + +// NCCL error checking macro +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t result = cmd; \ + if (result != ncclSuccess) { \ + throw std::runtime_error(std::string("NCCL error: ") + \ + ncclGetErrorString(result)); \ + } \ + } while (0) + +// CUDA error checking macro +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t result = cmd; \ + if (result != cudaSuccess) { \ + throw std::runtime_error(std::string("CUDA error: ") + \ + cudaGetErrorString(result)); \ + } \ + } while (0) diff --git a/src/device/nvidia/nvidia_resource.cu b/src/device/nvidia/nvidia_resource.cu index 2e63647e..01ecfd4a 100644 --- a/src/device/nvidia/nvidia_resource.cu +++ b/src/device/nvidia/nvidia_resource.cu @@ -4,4 +4,6 @@ namespace llaisys::device::nvidia { Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) {} +Resource::~Resource() = default; + } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab92826..493517a1 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,81 @@ #include "../runtime_api.hpp" -#include -#include +#include namespace llaisys::device::nvidia { namespace runtime_api { + +// Helper function to convert llaisysMemcpyKind_t to cudaMemcpyKind +cudaMemcpyKind convertMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + return cudaMemcpyHostToHost; + } +} + int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count = 0; + cudaGetDeviceCount(&count); + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + cudaSetDevice(device_id); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + cudaDeviceSynchronize(); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t stream; + cudaStreamCreate(&stream); + return (llaisysStream_t)stream; } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cudaStreamDestroy((cudaStream_t)stream); } + void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + cudaStreamSynchronize((cudaStream_t)stream); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + cudaMalloc(&ptr, size); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + cudaFree(ptr); } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr = nullptr; + cudaMallocHost(&ptr, size); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + cudaFreeHost(ptr); } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + cudaMemcpy(dst, src, size, convertMemcpyKind(kind)); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaMemcpyAsync(dst, src, size, convertMemcpyKind(kind), (cudaStream_t)stream); } static const LlaisysRuntimeAPI RUNTIME_API = { diff --git a/src/device/runtime_api.cpp b/src/device/runtime_api.cpp index 2de3eca0..233afa89 100644 --- a/src/device/runtime_api.cpp +++ b/src/device/runtime_api.cpp @@ -80,6 +80,12 @@ const LlaisysRuntimeAPI *getRuntimeAPI(llaisysDeviceType_t device_type) { return llaisys::device::nvidia::getRuntimeAPI(); #else return getUnsupportedRuntimeAPI(); +#endif + case LLAISYS_DEVICE_METAX: +#ifdef ENABLE_METAX_API + return llaisys::device::metax::getRuntimeAPI(); +#else + return getUnsupportedRuntimeAPI(); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/device/runtime_api.hpp b/src/device/runtime_api.hpp index e6b9f80d..0e94644f 100644 --- a/src/device/runtime_api.hpp +++ b/src/device/runtime_api.hpp @@ -17,4 +17,10 @@ namespace nvidia { const LlaisysRuntimeAPI *getRuntimeAPI(); } #endif + +#ifdef ENABLE_METAX_API +namespace metax { +const LlaisysRuntimeAPI *getRuntimeAPI(); +} +#endif } // namespace llaisys::device diff --git a/src/llaisys/models/qwen2_tp.cc b/src/llaisys/models/qwen2_tp.cc new file mode 100644 index 00000000..74460dd9 --- /dev/null +++ b/src/llaisys/models/qwen2_tp.cc @@ -0,0 +1,8 @@ +// Include the actual TP implementation to ensure symbols are exported +// This file is compiled as part of the shared library + +// Define LLAISYS_BUILDING_SHARED to ensure proper export +#define LLAISYS_BUILDING_SHARED + +// Include the implementation file +#include "../../models/qwen2/qwen2_tp.cpp" diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 00000000..1ba7d5ea --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,119 @@ +#include "llaisys/models/qwen2.h" +#include "../models/qwen2/qwen2.hpp" +#include "llaisys_tensor.hpp" + +__C { + +struct LlaisysQwen2Model { + llaisys::models::Qwen2Model* model; + LlaisysQwen2Weights weights; +}; + +__export struct LlaisysQwen2Model* llaisysQwen2ModelCreate( + const LlaisysQwen2Meta* meta, + llaisysDeviceType_t device, + int* device_ids, + int ndevice) { + + auto* wrapper = new LlaisysQwen2Model(); + int device_id = (device_ids && ndevice > 0) ? device_ids[0] : 0; + wrapper->model = new llaisys::models::Qwen2Model(meta, device, device_id); + + // Set up weights pointers + auto& m = *wrapper->model; + auto& w = wrapper->weights; + size_t nlayer = meta->nlayer; + + w.in_embed = new LlaisysTensor{m.in_embed}; + w.out_embed = new LlaisysTensor{m.out_embed}; + w.out_norm_w = new LlaisysTensor{m.out_norm_w}; + + w.attn_norm_w = new llaisysTensor_t[nlayer]; + w.attn_q_w = new llaisysTensor_t[nlayer]; + w.attn_q_b = new llaisysTensor_t[nlayer]; + w.attn_k_w = new llaisysTensor_t[nlayer]; + w.attn_k_b = new llaisysTensor_t[nlayer]; + w.attn_v_w = new llaisysTensor_t[nlayer]; + w.attn_v_b = new llaisysTensor_t[nlayer]; + w.attn_o_w = new llaisysTensor_t[nlayer]; + w.mlp_norm_w = new llaisysTensor_t[nlayer]; + w.mlp_gate_w = new llaisysTensor_t[nlayer]; + w.mlp_up_w = new llaisysTensor_t[nlayer]; + w.mlp_down_w = new llaisysTensor_t[nlayer]; + + for (size_t i = 0; i < nlayer; ++i) { + w.attn_norm_w[i] = new LlaisysTensor{m.attn_norm_w[i]}; + w.attn_q_w[i] = new LlaisysTensor{m.attn_q_w[i]}; + w.attn_q_b[i] = new LlaisysTensor{m.attn_q_b[i]}; + w.attn_k_w[i] = new LlaisysTensor{m.attn_k_w[i]}; + w.attn_k_b[i] = new LlaisysTensor{m.attn_k_b[i]}; + w.attn_v_w[i] = new LlaisysTensor{m.attn_v_w[i]}; + w.attn_v_b[i] = new LlaisysTensor{m.attn_v_b[i]}; + w.attn_o_w[i] = new LlaisysTensor{m.attn_o_w[i]}; + w.mlp_norm_w[i] = new LlaisysTensor{m.mlp_norm_w[i]}; + w.mlp_gate_w[i] = new LlaisysTensor{m.mlp_gate_w[i]}; + w.mlp_up_w[i] = new LlaisysTensor{m.mlp_up_w[i]}; + w.mlp_down_w[i] = new LlaisysTensor{m.mlp_down_w[i]}; + } + + return wrapper; +} + +__export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model* model) { + if (!model) return; + + size_t nlayer = model->model->meta.nlayer; + auto& w = model->weights; + + delete w.in_embed; + delete w.out_embed; + delete w.out_norm_w; + + for (size_t i = 0; i < nlayer; ++i) { + delete w.attn_norm_w[i]; + delete w.attn_q_w[i]; + delete w.attn_q_b[i]; + delete w.attn_k_w[i]; + delete w.attn_k_b[i]; + delete w.attn_v_w[i]; + delete w.attn_v_b[i]; + delete w.attn_o_w[i]; + delete w.mlp_norm_w[i]; + delete w.mlp_gate_w[i]; + delete w.mlp_up_w[i]; + delete w.mlp_down_w[i]; + } + + delete[] w.attn_norm_w; + delete[] w.attn_q_w; + delete[] w.attn_q_b; + delete[] w.attn_k_w; + delete[] w.attn_k_b; + delete[] w.attn_v_w; + delete[] w.attn_v_b; + delete[] w.attn_o_w; + delete[] w.mlp_norm_w; + delete[] w.mlp_gate_w; + delete[] w.mlp_up_w; + delete[] w.mlp_down_w; + + delete model->model; + delete model; +} + +__export struct LlaisysQwen2Weights* llaisysQwen2ModelWeights(struct LlaisysQwen2Model* model) { + return &model->weights; +} + +__export int64_t llaisysQwen2ModelInfer( + struct LlaisysQwen2Model* model, + int64_t* token_ids, + size_t ntoken) { + return model->model->infer(token_ids, ntoken); +} + +__export void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model* model) { + model->model->cache_len = 0; +} + +} diff --git a/src/llaisys/qwen2_tp.cc b/src/llaisys/qwen2_tp.cc new file mode 100644 index 00000000..76ced296 --- /dev/null +++ b/src/llaisys/qwen2_tp.cc @@ -0,0 +1,108 @@ +#include "llaisys/models/qwen2_tp.h" +#include "llaisys/models/qwen2.h" +#include "../models/qwen2/qwen2_tp.hpp" +#include "llaisys_tensor.hpp" + +#include + +__C { + +struct LlaisysQwen2ModelTP { + llaisys::models::Qwen2ModelTP* model; + std::vector weights_per_rank; +}; + +__export struct LlaisysQwen2ModelTP* llaisysQwen2ModelTPCreate( + const struct LlaisysQwen2Meta* meta, + const int* device_ids, + int world_size) { + + // Convert device_ids to vector + std::vector device_vec(device_ids, device_ids + world_size); + + auto* wrapper = new LlaisysQwen2ModelTP(); + wrapper->model = new llaisys::models::Qwen2ModelTP(meta, device_vec); + + // Get weights for each rank + wrapper->weights_per_rank.resize(world_size); + for (int i = 0; i < world_size; ++i) { + wrapper->weights_per_rank[i] = wrapper->model->getWeights(i); + } + + return wrapper; +} + +__export void llaisysQwen2ModelTPDestroy(struct LlaisysQwen2ModelTP* model) { + if (!model) return; + + // Clean up weights for each rank + for (auto* weights : model->weights_per_rank) { + if (!weights) continue; + + size_t nlayer = model->model->meta.nlayer; + + delete weights->in_embed; + delete weights->out_embed; + delete weights->out_norm_w; + + for (size_t i = 0; i < nlayer; ++i) { + delete weights->attn_norm_w[i]; + delete weights->attn_q_w[i]; + delete weights->attn_q_b[i]; + delete weights->attn_k_w[i]; + delete weights->attn_k_b[i]; + delete weights->attn_v_w[i]; + delete weights->attn_v_b[i]; + delete weights->attn_o_w[i]; + delete weights->mlp_norm_w[i]; + delete weights->mlp_gate_w[i]; + delete weights->mlp_up_w[i]; + delete weights->mlp_down_w[i]; + } + + delete[] weights->attn_norm_w; + delete[] weights->attn_q_w; + delete[] weights->attn_q_b; + delete[] weights->attn_k_w; + delete[] weights->attn_k_b; + delete[] weights->attn_v_w; + delete[] weights->attn_v_b; + delete[] weights->attn_o_w; + delete[] weights->mlp_norm_w; + delete[] weights->mlp_gate_w; + delete[] weights->mlp_up_w; + delete[] weights->mlp_down_w; + + delete weights; + } + + delete model->model; + delete model; +} + +__export struct LlaisysQwen2Weights* llaisysQwen2ModelTPWeights( + struct LlaisysQwen2ModelTP* model, + int rank) { + if (!model || rank < 0 || rank >= static_cast(model->weights_per_rank.size())) { + return nullptr; + } + return model->weights_per_rank[rank]; +} + +__export int64_t llaisysQwen2ModelTPInfer( + struct LlaisysQwen2ModelTP* model, + const int64_t* token_ids, + size_t ntoken) { + // Cast away const for the internal API + return model->model->infer(const_cast(token_ids), ntoken); +} + +__export void llaisysQwen2ModelTPResetCache(struct LlaisysQwen2ModelTP* model) { + model->model->resetCache(); +} + +__export int llaisysQwen2ModelTPGetWorldSize(struct LlaisysQwen2ModelTP* model) { + return model->model->getWorldSize(); +} + +} // __C diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp new file mode 100644 index 00000000..97529ada --- /dev/null +++ b/src/models/qwen2/qwen2.cpp @@ -0,0 +1,265 @@ +#include "qwen2.hpp" +#include +#include + +namespace llaisys { +namespace models { + +Qwen2Model::Qwen2Model(const LlaisysQwen2Meta* meta_, llaisysDeviceType_t device, int dev_id) + : device_type(device), device_id(dev_id), cache_len(0) { + std::memcpy(&meta, meta_, sizeof(LlaisysQwen2Meta)); + allocateWeights(); + allocateCache(); + allocateBuffers(meta.maxseq); +} + +void Qwen2Model::allocateWeights() { + auto dtype = meta.dtype; + size_t nlayer = meta.nlayer; + size_t hs = meta.hs; + size_t nh = meta.nh; + size_t nkvh = meta.nkvh; + size_t dh = meta.dh; + size_t di = meta.di; + size_t voc = meta.voc; + + // Embedding weights + in_embed = Tensor::create({voc, hs}, dtype, device_type, device_id); + out_embed = Tensor::create({voc, hs}, dtype, device_type, device_id); + out_norm_w = Tensor::create({hs}, dtype, device_type, device_id); + + // Per-layer weights + attn_norm_w.resize(nlayer); + attn_q_w.resize(nlayer); + attn_q_b.resize(nlayer); + attn_k_w.resize(nlayer); + attn_k_b.resize(nlayer); + attn_v_w.resize(nlayer); + attn_v_b.resize(nlayer); + attn_o_w.resize(nlayer); + mlp_norm_w.resize(nlayer); + mlp_gate_w.resize(nlayer); + mlp_up_w.resize(nlayer); + mlp_down_w.resize(nlayer); + + for (size_t i = 0; i < nlayer; ++i) { + attn_norm_w[i] = Tensor::create({hs}, dtype, device_type, device_id); + attn_q_w[i] = Tensor::create({nh * dh, hs}, dtype, device_type, device_id); + attn_q_b[i] = Tensor::create({nh * dh}, dtype, device_type, device_id); + attn_k_w[i] = Tensor::create({nkvh * dh, hs}, dtype, device_type, device_id); + attn_k_b[i] = Tensor::create({nkvh * dh}, dtype, device_type, device_id); + attn_v_w[i] = Tensor::create({nkvh * dh, hs}, dtype, device_type, device_id); + attn_v_b[i] = Tensor::create({nkvh * dh}, dtype, device_type, device_id); + attn_o_w[i] = Tensor::create({hs, nh * dh}, dtype, device_type, device_id); + mlp_norm_w[i] = Tensor::create({hs}, dtype, device_type, device_id); + mlp_gate_w[i] = Tensor::create({di, hs}, dtype, device_type, device_id); + mlp_up_w[i] = Tensor::create({di, hs}, dtype, device_type, device_id); + mlp_down_w[i] = Tensor::create({hs, di}, dtype, device_type, device_id); + } +} + +void Qwen2Model::allocateCache() { + size_t nlayer = meta.nlayer; + size_t maxseq = meta.maxseq; + size_t nkvh = meta.nkvh; + size_t dh = meta.dh; + auto dtype = meta.dtype; + + k_cache.resize(nlayer); + v_cache.resize(nlayer); + + for (size_t i = 0; i < nlayer; ++i) { + k_cache[i] = Tensor::create({maxseq, nkvh, dh}, dtype, device_type, device_id); + v_cache[i] = Tensor::create({maxseq, nkvh, dh}, dtype, device_type, device_id); + } +} + +void Qwen2Model::allocateBuffers(size_t max_seqlen) { + auto dtype = meta.dtype; + size_t hs = meta.hs; + size_t nh = meta.nh; + size_t nkvh = meta.nkvh; + size_t dh = meta.dh; + size_t di = meta.di; + size_t voc = meta.voc; + + hidden = Tensor::create({max_seqlen, hs}, dtype, device_type, device_id); + hidden_norm = Tensor::create({max_seqlen, hs}, dtype, device_type, device_id); + q = Tensor::create({max_seqlen, nh * dh}, dtype, device_type, device_id); + k = Tensor::create({max_seqlen, nkvh * dh}, dtype, device_type, device_id); + v = Tensor::create({max_seqlen, nkvh * dh}, dtype, device_type, device_id); + q_rope = Tensor::create({max_seqlen, nh, dh}, dtype, device_type, device_id); + k_rope = Tensor::create({max_seqlen, nkvh, dh}, dtype, device_type, device_id); + attn_out = Tensor::create({max_seqlen, nh, dh}, dtype, device_type, device_id); + attn_proj = Tensor::create({max_seqlen, hs}, dtype, device_type, device_id); + gate = Tensor::create({max_seqlen, di}, dtype, device_type, device_id); + up = Tensor::create({max_seqlen, di}, dtype, device_type, device_id); + mlp_out = Tensor::create({max_seqlen, hs}, dtype, device_type, device_id); + logits = Tensor::create({1, voc}, dtype, device_type, device_id); + max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, device_type, device_id); + max_val = Tensor::create({1}, dtype, device_type, device_id); + pos_ids = Tensor::create({max_seqlen}, LLAISYS_DTYPE_I64, device_type, device_id); +} + +int64_t Qwen2Model::infer(int64_t* token_ids, size_t ntoken) { + // Create input tensor for embedding lookup + tensor_t input_ids = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, device_type, device_id); + // For MetaX, ensure alignment by using an aligned buffer if needed + if (device_type != LLAISYS_DEVICE_CPU && (reinterpret_cast(token_ids) % 64) != 0) { + // Copy to aligned buffer + alignas(64) int64_t aligned_tokens[4096]; // Max sequence length + if (ntoken <= 4096) { + std::memcpy(aligned_tokens, token_ids, ntoken * sizeof(int64_t)); + input_ids->load(aligned_tokens); + } else { + input_ids->load(token_ids); + } + } else { + input_ids->load(token_ids); + } + + // Get hidden states view for current sequence + tensor_t hidden_view = hidden->slice(0, 0, ntoken); + tensor_t in_embed_view = in_embed; + + // Embedding lookup + ops::embedding(hidden_view, input_ids, in_embed_view); + + // Forward pass through all layers + forward(ntoken, cache_len); + + // Update cache length + cache_len += ntoken; + + // Final layer norm on last token + tensor_t last_hidden = hidden->slice(0, ntoken - 1, ntoken); + tensor_t last_norm = hidden_norm->slice(0, 0, 1); + ops::rms_norm(last_norm, last_hidden, out_norm_w, meta.epsilon); + + // Compute logits for last token + ops::linear(logits, last_norm, out_embed, nullptr); + + // Get last row of logits + tensor_t last_logits = logits->view({meta.voc}); + + // Argmax to get predicted token + ops::argmax(max_idx, max_val, last_logits); + + // Read result + int64_t result = 0; + if (device_type == LLAISYS_DEVICE_CPU) { + result = *reinterpret_cast(max_idx->data()); + } else { + // Copy from device to host + // Use a properly aligned buffer for MetaX + alignas(64) int64_t aligned_result; + core::context().setDevice(device_type, device_id); + core::context().runtime().api()->memcpy_sync( + &aligned_result, max_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H); + result = aligned_result; + } + + return result; +} + +void Qwen2Model::forward(size_t seqlen, size_t start_pos) { + for (size_t layer = 0; layer < meta.nlayer; ++layer) { + forwardLayer(layer, seqlen, start_pos); + } +} + +void Qwen2Model::forwardLayer(size_t layer, size_t seqlen, size_t start_pos) { + size_t nh = meta.nh; + size_t nkvh = meta.nkvh; + size_t dh = meta.dh; + size_t total_len = start_pos + seqlen; + + // Get views for current sequence length + tensor_t hidden_view = hidden->slice(0, 0, seqlen); + tensor_t norm_view = hidden_norm->slice(0, 0, seqlen); + tensor_t q_view = q->slice(0, 0, seqlen); + tensor_t k_view = k->slice(0, 0, seqlen); + tensor_t v_view = v->slice(0, 0, seqlen); + tensor_t q_rope_view = q_rope->slice(0, 0, seqlen); + tensor_t k_rope_view = k_rope->slice(0, 0, seqlen); + tensor_t attn_out_view = attn_out->slice(0, 0, seqlen); + tensor_t attn_proj_view = attn_proj->slice(0, 0, seqlen); + tensor_t gate_view = gate->slice(0, 0, seqlen); + tensor_t up_view = up->slice(0, 0, seqlen); + tensor_t mlp_out_view = mlp_out->slice(0, 0, seqlen); + + // 1. Pre-attention layer norm + ops::rms_norm(norm_view, hidden_view, attn_norm_w[layer], meta.epsilon); + + // 2. Compute Q, K, V projections + ops::linear(q_view, norm_view, attn_q_w[layer], attn_q_b[layer]); + ops::linear(k_view, norm_view, attn_k_w[layer], attn_k_b[layer]); + ops::linear(v_view, norm_view, attn_v_w[layer], attn_v_b[layer]); + + // 3. Reshape Q, K, V to [seqlen, nhead, dh] + tensor_t q_reshaped = q_view->view({seqlen, nh, dh}); + tensor_t k_reshaped = k_view->view({seqlen, nkvh, dh}); + tensor_t v_reshaped = v_view->view({seqlen, nkvh, dh}); + + // 4. Set up position ids + tensor_t pos_view = pos_ids->slice(0, 0, seqlen); + // Use aligned buffer for MetaX + alignas(64) int64_t pos_data[4096]; + for (size_t i = 0; i < seqlen; ++i) { + pos_data[i] = static_cast(start_pos + i); + } + pos_view->load(pos_data); + + // 5. Apply RoPE + ops::rope(q_rope_view, q_reshaped, pos_view, meta.theta); + ops::rope(k_rope_view, k_reshaped, pos_view, meta.theta); + + // 6. Update KV cache + tensor_t k_cache_slice = k_cache[layer]->slice(0, start_pos, total_len); + tensor_t v_cache_slice = v_cache[layer]->slice(0, start_pos, total_len); + + // Copy new K, V to cache + size_t kv_bytes = seqlen * nkvh * dh * k_rope_view->elementSize(); + if (device_type == LLAISYS_DEVICE_CPU) { + std::memcpy(k_cache_slice->data(), k_rope_view->data(), kv_bytes); + std::memcpy(v_cache_slice->data(), v_reshaped->data(), kv_bytes); + } else { + core::context().setDevice(device_type, device_id); + auto api = core::context().runtime().api(); + api->memcpy_sync(k_cache_slice->data(), k_rope_view->data(), kv_bytes, LLAISYS_MEMCPY_D2D); + api->memcpy_sync(v_cache_slice->data(), v_reshaped->data(), kv_bytes, LLAISYS_MEMCPY_D2D); + } + + // 7. Self-attention with full KV cache + tensor_t k_full = k_cache[layer]->slice(0, 0, total_len); + tensor_t v_full = v_cache[layer]->slice(0, 0, total_len); + + float scale = 1.0f / std::sqrt(static_cast(dh)); + ops::self_attention(attn_out_view, q_rope_view, k_full, v_full, scale); + + // 8. Reshape attention output and project + tensor_t attn_out_flat = attn_out_view->view({seqlen, nh * dh}); + ops::linear(attn_proj_view, attn_out_flat, attn_o_w[layer], nullptr); + + // 9. Residual connection + ops::add(hidden_view, hidden_view, attn_proj_view); + + // 10. Post-attention layer norm + ops::rms_norm(norm_view, hidden_view, mlp_norm_w[layer], meta.epsilon); + + // 11. MLP: gate and up projections + ops::linear(gate_view, norm_view, mlp_gate_w[layer], nullptr); + ops::linear(up_view, norm_view, mlp_up_w[layer], nullptr); + + // 12. SwiGLU activation + ops::swiglu(gate_view, gate_view, up_view); + + // 13. Down projection (reuse attn_proj_view as output buffer, shape [seqlen, hs]) + ops::linear(attn_proj_view, gate_view, mlp_down_w[layer], nullptr); + + // 14. Residual connection + ops::add(hidden_view, hidden_view, attn_proj_view); +} + +} // namespace models +} // namespace llaisys diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp new file mode 100644 index 00000000..e5d55ccd --- /dev/null +++ b/src/models/qwen2/qwen2.hpp @@ -0,0 +1,78 @@ +#pragma once + +#include "llaisys/models/qwen2.h" +#include "../../tensor/tensor.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/add/op.hpp" + +#include +#include + +namespace llaisys { +namespace models { + +struct Qwen2Model { + LlaisysQwen2Meta meta; + llaisysDeviceType_t device_type; + int device_id; + + // Weights + tensor_t in_embed; + tensor_t out_embed; + tensor_t out_norm_w; + std::vector attn_norm_w; + std::vector attn_q_w; + std::vector attn_q_b; + std::vector attn_k_w; + std::vector attn_k_b; + std::vector attn_v_w; + std::vector attn_v_b; + std::vector attn_o_w; + std::vector mlp_norm_w; + std::vector mlp_gate_w; + std::vector mlp_up_w; + std::vector mlp_down_w; + + // KV Cache: [nlayer][max_seq, nkvh, dh] + std::vector k_cache; + std::vector v_cache; + size_t cache_len; + + // Intermediate buffers + tensor_t hidden; + tensor_t hidden_norm; + tensor_t q; + tensor_t k; + tensor_t v; + tensor_t q_rope; + tensor_t k_rope; + tensor_t attn_out; + tensor_t attn_proj; + tensor_t gate; + tensor_t up; + tensor_t mlp_out; + tensor_t logits; + tensor_t max_idx; + tensor_t max_val; + tensor_t pos_ids; + + Qwen2Model(const LlaisysQwen2Meta* meta, llaisysDeviceType_t device, int device_id); + ~Qwen2Model() = default; + + void allocateWeights(); + void allocateCache(); + void allocateBuffers(size_t max_seqlen); + + int64_t infer(int64_t* token_ids, size_t ntoken); + void forward(size_t seqlen, size_t start_pos); + void forwardLayer(size_t layer, size_t seqlen, size_t start_pos); +}; + +} // namespace models +} // namespace llaisys diff --git a/src/models/qwen2/qwen2_tp.cpp b/src/models/qwen2/qwen2_tp.cpp new file mode 100644 index 00000000..29475c3c --- /dev/null +++ b/src/models/qwen2/qwen2_tp.cpp @@ -0,0 +1,458 @@ +#include "qwen2_tp.hpp" +#include +#include +#include + +// Windows does not support NCCL, provide stub implementation +// Also use stub when NVIDIA API is not enabled (CPU-only builds) +// Or when NCCL is not available +#if defined(_WIN32) || !defined(ENABLE_NVIDIA_API) || !defined(ENABLE_NCCL) + +namespace llaisys { +namespace models { + +// Stub implementation for Windows or CPU-only builds - TP not supported +Qwen2ModelShard::Qwen2ModelShard(int rank_, int device_id_) + : rank(rank_), device_id(device_id_), communicator(nullptr) { +} + +void Qwen2ModelShard::allocateWeights(const LlaisysQwen2Meta& meta, int world_size) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support"); +} + +void Qwen2ModelShard::allocateCache(const LlaisysQwen2Meta& meta, int world_size) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support"); +} + +void Qwen2ModelShard::allocateBuffers(const LlaisysQwen2Meta& meta, size_t max_seqlen, int world_size) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support"); +} + +Qwen2ModelTP::Qwen2ModelTP(const LlaisysQwen2Meta* meta_, const std::vector& device_ids_) + : world_size(static_cast(device_ids_.size())), + device_ids(device_ids_), + cache_len(0) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support. NCCL is Linux-only."); +} + +Qwen2ModelTP::~Qwen2ModelTP() = default; + +void Qwen2ModelTP::initialize() {} + +int64_t Qwen2ModelTP::infer(int64_t* token_ids, size_t ntoken) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support"); + return 0; +} + +void Qwen2ModelTP::resetCache() { + cache_len.store(0); +} + +LlaisysQwen2Weights* Qwen2ModelTP::getWeights(int rank) { + return nullptr; +} + +void Qwen2ModelTP::allReduce(int rank, void* buffer, size_t count, int dtype) { + throw std::runtime_error("Tensor Parallel requires NVIDIA GPU support"); +} + +} // namespace models +} // namespace llaisys + +#else // Linux with NVIDIA API enabled - Full implementation + +namespace llaisys { +namespace models { + +// Forward declare CUDA runtime functions +typedef int cudaError_t; +#define cudaSuccess 0 +extern "C" cudaError_t cudaSetDevice(int device); +extern "C" cudaError_t cudaDeviceSynchronize(void); +extern "C" cudaError_t cudaGetLastError(void); +extern "C" cudaError_t cudaGetDeviceCount(int* count); +extern "C" cudaError_t cudaFree(void* devPtr); +extern "C" const char* cudaGetErrorString(cudaError_t error); + +#define CUDA_CHECK(call) do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << cudaGetErrorString(err) << std::endl; \ + } \ +} while(0) + +// ==================== Qwen2ModelShard ==================== + +Qwen2ModelShard::Qwen2ModelShard(int rank_, int device_id_) + : rank(rank_), device_id(device_id_), communicator(nullptr) { +} + +void Qwen2ModelShard::allocateWeights(const LlaisysQwen2Meta& meta, int world_size) { + auto dtype = meta.dtype; + size_t nlayer = meta.nlayer; + size_t hs = meta.hs; + size_t di = meta.di; + size_t voc = meta.voc; + size_t dh = meta.dh; + size_t nh = meta.nh; + size_t nkvh = meta.nkvh; + + size_t nh_shard = nh / world_size; + size_t nkvh_shard = nkvh / world_size; + size_t di_shard = di / world_size; + + in_embed = Tensor::create({voc, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + out_embed = Tensor::create({voc, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + out_norm_w = Tensor::create({hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + + attn_norm_w.resize(nlayer); + attn_q_w.resize(nlayer); + attn_q_b.resize(nlayer); + attn_k_w.resize(nlayer); + attn_k_b.resize(nlayer); + attn_v_w.resize(nlayer); + attn_v_b.resize(nlayer); + attn_o_w.resize(nlayer); + mlp_norm_w.resize(nlayer); + mlp_gate_w.resize(nlayer); + mlp_up_w.resize(nlayer); + mlp_down_w.resize(nlayer); + + for (size_t i = 0; i < nlayer; ++i) { + attn_norm_w[i] = Tensor::create({hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_q_w[i] = Tensor::create({nh_shard * dh, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_q_b[i] = Tensor::create({nh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_k_w[i] = Tensor::create({nkvh_shard * dh, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_k_b[i] = Tensor::create({nkvh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_v_w[i] = Tensor::create({nkvh_shard * dh, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_v_b[i] = Tensor::create({nkvh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_o_w[i] = Tensor::create({hs, nh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + mlp_norm_w[i] = Tensor::create({hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + mlp_gate_w[i] = Tensor::create({di_shard, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + mlp_up_w[i] = Tensor::create({di_shard, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + mlp_down_w[i] = Tensor::create({hs, di_shard}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + } +} + +void Qwen2ModelShard::allocateCache(const LlaisysQwen2Meta& meta, int world_size) { + size_t nlayer = meta.nlayer; + size_t maxseq = meta.maxseq; + size_t nkvh = meta.nkvh; + size_t dh = meta.dh; + size_t nkvh_shard = nkvh / world_size; + auto dtype_val = meta.dtype; + + k_cache.resize(nlayer); + v_cache.resize(nlayer); + + for (size_t i = 0; i < nlayer; ++i) { + k_cache[i] = Tensor::create({maxseq, nkvh_shard, dh}, dtype_val, LLAISYS_DEVICE_NVIDIA, device_id); + v_cache[i] = Tensor::create({maxseq, nkvh_shard, dh}, dtype_val, LLAISYS_DEVICE_NVIDIA, device_id); + } +} + +void Qwen2ModelShard::allocateBuffers(const LlaisysQwen2Meta& meta, size_t max_seqlen, int world_size) { + auto dtype = meta.dtype; + size_t hs = meta.hs; + size_t dh = meta.dh; + size_t di = meta.di; + size_t voc = meta.voc; + size_t nh = meta.nh; + size_t nkvh = meta.nkvh; + size_t nh_shard = nh / world_size; + size_t nkvh_shard = nkvh / world_size; + size_t di_shard = di / world_size; + + hidden = Tensor::create({max_seqlen, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + hidden_norm = Tensor::create({max_seqlen, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + q = Tensor::create({max_seqlen, nh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + k = Tensor::create({max_seqlen, nkvh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + v = Tensor::create({max_seqlen, nkvh_shard * dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + q_rope = Tensor::create({max_seqlen, nh_shard, dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + k_rope = Tensor::create({max_seqlen, nkvh_shard, dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_out = Tensor::create({max_seqlen, nh_shard, dh}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + attn_proj = Tensor::create({max_seqlen, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + gate = Tensor::create({max_seqlen, di_shard}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + up = Tensor::create({max_seqlen, di_shard}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + mlp_out = Tensor::create({max_seqlen, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + logits = Tensor::create({1, voc}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, LLAISYS_DEVICE_NVIDIA, device_id); + max_val = Tensor::create({1}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); + pos_ids = Tensor::create({max_seqlen}, LLAISYS_DTYPE_I64, LLAISYS_DEVICE_NVIDIA, device_id); + allreduce_buffer = Tensor::create({max_seqlen, hs}, dtype, LLAISYS_DEVICE_NVIDIA, device_id); +} + +// ==================== Qwen2ModelTP ==================== + +Qwen2ModelTP::Qwen2ModelTP(const LlaisysQwen2Meta* meta_, const std::vector& device_ids_) + : world_size(static_cast(device_ids_.size())), + device_ids(device_ids_), + cache_len(0) { + + std::memcpy(&meta, meta_, sizeof(LlaisysQwen2Meta)); + + if (meta.nh % world_size != 0) { + throw std::invalid_argument("nh must be divisible by world_size"); + } + if (meta.nkvh % world_size != 0) { + throw std::invalid_argument("nkvh must be divisible by world_size"); + } + if (meta.di % world_size != 0) { + throw std::invalid_argument("di must be divisible by world_size"); + } + + initialize(); +} + +Qwen2ModelTP::~Qwen2ModelTP() = default; + +void Qwen2ModelTP::initialize() { + // Initialize CUDA driver and get device count + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + // Initialize CUDA context on first device + CUDA_CHECK(cudaSetDevice(device_ids[0])); + cudaFree(0); + CUDA_CHECK(cudaDeviceSynchronize()); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, device_ids[0]); + + // Create NCCL communicators + communicators = device::nvidia::NCCLCommunicator::createAll(device_ids); + + // Reset to first device after NCCL init + CUDA_CHECK(cudaSetDevice(device_ids[0])); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, device_ids[0]); + + // Create shards with proper device context + shards.reserve(world_size); + for (int i = 0; i < world_size; ++i) { + CUDA_CHECK(cudaSetDevice(device_ids[i])); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, device_ids[i]); + + shards.push_back(std::make_unique(i, device_ids[i])); + shards[i]->communicator = communicators[i].get(); + shards[i]->allocateWeights(meta, world_size); + shards[i]->allocateCache(meta, world_size); + shards[i]->allocateBuffers(meta, meta.maxseq, world_size); + } + + // Reset to first device after initialization + CUDA_CHECK(cudaSetDevice(device_ids[0])); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, device_ids[0]); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +void Qwen2ModelTP::allReduce(int rank, void* buffer, size_t count, int dtype) { + cudaSetDevice(shards[rank]->device_id); + cudaDeviceSynchronize(); + communicators[rank]->allReduce(buffer, count, dtype); +} + +int64_t Qwen2ModelTP::infer(int64_t* token_ids, size_t ntoken) { + std::atomic result_token(0); + std::vector threads; + + // Create barrier for synchronizing all threads + SimpleBarrier sync_barrier(world_size); + + // Fork: launch threads for each rank + for (int rank = 0; rank < world_size; ++rank) { + threads.emplace_back([&, rank]() { + auto& shard = *shards[rank]; + + // Set CUDA device and runtime context for this thread + cudaSetDevice(shard.device_id); + cudaDeviceSynchronize(); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, shard.device_id); + + tensor_t input_ids = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, + LLAISYS_DEVICE_NVIDIA, shard.device_id); + + alignas(64) int64_t aligned_tokens[4096]; + if (ntoken <= 4096) { + std::memcpy(aligned_tokens, token_ids, ntoken * sizeof(int64_t)); + input_ids->load(aligned_tokens); + } else { + input_ids->load(token_ids); + } + + tensor_t hidden_view = shard.hidden->slice(0, 0, ntoken); + ops::embedding(hidden_view, input_ids, shard.in_embed); + + for (size_t layer = 0; layer < meta.nlayer; ++layer) { + forwardLayerWithBarrier(rank, layer, ntoken, cache_len.load(), sync_barrier); + } + + if (rank == 0) { + cache_len += ntoken; + } + + tensor_t last_hidden = shard.hidden->slice(0, ntoken - 1, ntoken); + tensor_t last_norm = shard.hidden_norm->slice(0, 0, 1); + + ops::rms_norm(last_norm, last_hidden, shard.out_norm_w, meta.epsilon); + ops::linear(shard.logits, last_norm, shard.out_embed, nullptr); + + if (rank == 0) { + tensor_t last_logits = shard.logits->view({meta.voc}); + ops::argmax(shard.max_idx, shard.max_val, last_logits); + + alignas(64) int64_t result; + core::context().runtime().api()->memcpy_sync( + &result, shard.max_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H); + + result_token.store(result); + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + device::nvidia::NCCLCommunicator::synchronizeAll(communicators); + + return result_token.load(); +} + +void Qwen2ModelTP::resetCache() { + cache_len.store(0); +} + +LlaisysQwen2Weights* Qwen2ModelTP::getWeights(int rank) { + auto* weights = new LlaisysQwen2Weights(); + auto& shard = *shards[rank]; + + weights->in_embed = new LlaisysTensor{shard.in_embed}; + weights->out_embed = new LlaisysTensor{shard.out_embed}; + weights->out_norm_w = new LlaisysTensor{shard.out_norm_w}; + + weights->attn_norm_w = new llaisysTensor_t[meta.nlayer]; + weights->attn_q_w = new llaisysTensor_t[meta.nlayer]; + weights->attn_q_b = new llaisysTensor_t[meta.nlayer]; + weights->attn_k_w = new llaisysTensor_t[meta.nlayer]; + weights->attn_k_b = new llaisysTensor_t[meta.nlayer]; + weights->attn_v_w = new llaisysTensor_t[meta.nlayer]; + weights->attn_v_b = new llaisysTensor_t[meta.nlayer]; + weights->attn_o_w = new llaisysTensor_t[meta.nlayer]; + weights->mlp_norm_w = new llaisysTensor_t[meta.nlayer]; + weights->mlp_gate_w = new llaisysTensor_t[meta.nlayer]; + weights->mlp_up_w = new llaisysTensor_t[meta.nlayer]; + weights->mlp_down_w = new llaisysTensor_t[meta.nlayer]; + + for (size_t i = 0; i < meta.nlayer; ++i) { + weights->attn_norm_w[i] = new LlaisysTensor{shard.attn_norm_w[i]}; + weights->attn_q_w[i] = new LlaisysTensor{shard.attn_q_w[i]}; + weights->attn_q_b[i] = new LlaisysTensor{shard.attn_q_b[i]}; + weights->attn_k_w[i] = new LlaisysTensor{shard.attn_k_w[i]}; + weights->attn_k_b[i] = new LlaisysTensor{shard.attn_k_b[i]}; + weights->attn_v_w[i] = new LlaisysTensor{shard.attn_v_w[i]}; + weights->attn_v_b[i] = new LlaisysTensor{shard.attn_v_b[i]}; + weights->attn_o_w[i] = new LlaisysTensor{shard.attn_o_w[i]}; + weights->mlp_norm_w[i] = new LlaisysTensor{shard.mlp_norm_w[i]}; + weights->mlp_gate_w[i] = new LlaisysTensor{shard.mlp_gate_w[i]}; + weights->mlp_up_w[i] = new LlaisysTensor{shard.mlp_up_w[i]}; + weights->mlp_down_w[i] = new LlaisysTensor{shard.mlp_down_w[i]}; + } + + return weights; +} + +void Qwen2ModelTP::forwardLayerWithBarrier(int rank, size_t layer, size_t seqlen, size_t start_pos, SimpleBarrier& barrier) { + auto& shard = *shards[rank]; + size_t total_len = start_pos + seqlen; + size_t nh_shard = meta.nh / world_size; + size_t nkvh_shard = meta.nkvh / world_size; + size_t dh = meta.dh; + + tensor_t hidden_view = shard.hidden->slice(0, 0, seqlen); + tensor_t norm_view = shard.hidden_norm->slice(0, 0, seqlen); + tensor_t q_view = shard.q->slice(0, 0, seqlen); + tensor_t k_view = shard.k->slice(0, 0, seqlen); + tensor_t v_view = shard.v->slice(0, 0, seqlen); + tensor_t q_rope_view = shard.q_rope->slice(0, 0, seqlen); + tensor_t k_rope_view = shard.k_rope->slice(0, 0, seqlen); + tensor_t attn_out_view = shard.attn_out->slice(0, 0, seqlen); + tensor_t attn_proj_view = shard.attn_proj->slice(0, 0, seqlen); + tensor_t gate_view = shard.gate->slice(0, 0, seqlen); + tensor_t up_view = shard.up->slice(0, 0, seqlen); + + ops::rms_norm(norm_view, hidden_view, shard.attn_norm_w[layer], meta.epsilon); + ops::linear(q_view, norm_view, shard.attn_q_w[layer], shard.attn_q_b[layer]); + ops::linear(k_view, norm_view, shard.attn_k_w[layer], shard.attn_k_b[layer]); + ops::linear(v_view, norm_view, shard.attn_v_w[layer], shard.attn_v_b[layer]); + + tensor_t q_reshaped = q_view->view({seqlen, nh_shard, dh}); + tensor_t k_reshaped = k_view->view({seqlen, nkvh_shard, dh}); + tensor_t v_reshaped = v_view->view({seqlen, nkvh_shard, dh}); + + tensor_t pos_view = shard.pos_ids->slice(0, 0, seqlen); + alignas(64) int64_t pos_data[4096]; + for (size_t i = 0; i < seqlen; ++i) { + pos_data[i] = static_cast(start_pos + i); + } + pos_view->load(pos_data); + + ops::rope(q_rope_view, q_reshaped, pos_view, meta.theta); + ops::rope(k_rope_view, k_reshaped, pos_view, meta.theta); + + tensor_t k_cache_slice = shard.k_cache[layer]->slice(0, start_pos, total_len); + tensor_t v_cache_slice = shard.v_cache[layer]->slice(0, start_pos, total_len); + + size_t kv_bytes = seqlen * nkvh_shard * dh * k_rope_view->elementSize(); + auto api = core::context().runtime().api(); + api->memcpy_sync(k_cache_slice->data(), k_rope_view->data(), kv_bytes, LLAISYS_MEMCPY_D2D); + api->memcpy_sync(v_cache_slice->data(), v_reshaped->data(), kv_bytes, LLAISYS_MEMCPY_D2D); + + tensor_t k_full = shard.k_cache[layer]->slice(0, 0, total_len); + tensor_t v_full = shard.v_cache[layer]->slice(0, 0, total_len); + + float scale = 1.0f / std::sqrt(static_cast(dh)); + ops::self_attention(attn_out_view, q_rope_view, k_full, v_full, scale); + + tensor_t attn_out_flat = attn_out_view->view({seqlen, nh_shard * dh}); + ops::linear(attn_proj_view, attn_out_flat, shard.attn_o_w[layer], nullptr); + + size_t hidden_bytes = seqlen * meta.hs * hidden_view->elementSize(); + api->memcpy_sync(shard.allreduce_buffer->data(), attn_proj_view->data(), + hidden_bytes, LLAISYS_MEMCPY_D2D); + + barrier.arrive_and_wait(); + + CUDA_CHECK(cudaSetDevice(shard.device_id)); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, shard.device_id); + shard.communicator->allReduce(shard.allreduce_buffer->data(), seqlen * meta.hs, meta.dtype); + barrier.arrive_and_wait(); + + api->memcpy_sync(attn_proj_view->data(), shard.allreduce_buffer->data(), + hidden_bytes, LLAISYS_MEMCPY_D2D); + + ops::add(hidden_view, hidden_view, attn_proj_view); + ops::rms_norm(norm_view, hidden_view, shard.mlp_norm_w[layer], meta.epsilon); + ops::linear(gate_view, norm_view, shard.mlp_gate_w[layer], nullptr); + ops::linear(up_view, norm_view, shard.mlp_up_w[layer], nullptr); + ops::swiglu(gate_view, gate_view, up_view); + ops::linear(attn_proj_view, gate_view, shard.mlp_down_w[layer], nullptr); + + api->memcpy_sync(shard.allreduce_buffer->data(), attn_proj_view->data(), + hidden_bytes, LLAISYS_MEMCPY_D2D); + + barrier.arrive_and_wait(); + + CUDA_CHECK(cudaSetDevice(shard.device_id)); + core::context().setDevice(LLAISYS_DEVICE_NVIDIA, shard.device_id); + shard.communicator->allReduce(shard.allreduce_buffer->data(), seqlen * meta.hs, meta.dtype); + barrier.arrive_and_wait(); + + api->memcpy_sync(attn_proj_view->data(), shard.allreduce_buffer->data(), + hidden_bytes, LLAISYS_MEMCPY_D2D); + + ops::add(hidden_view, hidden_view, attn_proj_view); +} + +} // namespace models +} // namespace llaisys + +#endif // _WIN32 diff --git a/src/models/qwen2/qwen2_tp.hpp b/src/models/qwen2/qwen2_tp.hpp new file mode 100644 index 00000000..2021d81d --- /dev/null +++ b/src/models/qwen2/qwen2_tp.hpp @@ -0,0 +1,167 @@ +#pragma once + +#include "llaisys/models/qwen2_tp.h" +#include "qwen2.hpp" +#include "../../llaisys/llaisys_tensor.hpp" +#include "../../device/nvidia/nccl_communicator.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace llaisys { +namespace models { + +// Simple barrier for C++17 +class SimpleBarrier { +public: + explicit SimpleBarrier(size_t count) : count_(count), generation_(0), waiting_(0) {} + + void arrive_and_wait() { + std::unique_lock lock(mutex_); + size_t gen = generation_; + if (++waiting_ == count_) { + generation_++; + waiting_ = 0; + cv_.notify_all(); + } else { + cv_.wait(lock, [this, gen] { return gen != generation_; }); + } + } + +private: + std::mutex mutex_; + std::condition_variable cv_; + size_t count_; + size_t generation_; + size_t waiting_; +}; + +// Model shard for a single GPU +struct Qwen2ModelShard { + int rank; + int device_id; + + // NCCL communicator for this rank + device::nvidia::NCCLCommunicator* communicator; + + // Weight shards + tensor_t in_embed; + tensor_t out_embed; + tensor_t out_norm_w; + std::vector attn_norm_w; + std::vector attn_q_w; + std::vector attn_q_b; + std::vector attn_k_w; + std::vector attn_k_b; + std::vector attn_v_w; + std::vector attn_v_b; + std::vector attn_o_w; + std::vector mlp_norm_w; + std::vector mlp_gate_w; + std::vector mlp_up_w; + std::vector mlp_down_w; + + // KV Cache + std::vector k_cache; + std::vector v_cache; + + // Intermediate buffers + tensor_t hidden; + tensor_t hidden_norm; + tensor_t q; + tensor_t k; + tensor_t v; + tensor_t q_rope; + tensor_t k_rope; + tensor_t attn_out; + tensor_t attn_proj; + tensor_t gate; + tensor_t up; + tensor_t mlp_out; + tensor_t logits; + tensor_t max_idx; + tensor_t max_val; + tensor_t pos_ids; + tensor_t allreduce_buffer; + + Qwen2ModelShard(int rank_, int device_id_); + + void allocateWeights(const LlaisysQwen2Meta& meta, int world_size); + void allocateCache(const LlaisysQwen2Meta& meta, int world_size); + void allocateBuffers(const LlaisysQwen2Meta& meta, size_t max_seqlen, int world_size); + + size_t getShardedAttentionHeads(size_t nh, int world_size) const { + return nh / world_size; + } + size_t getShardedKVHeads(size_t nkvh, int world_size) const { + return nkvh / world_size; + } + size_t getShardedIntermediateSize(size_t di, int world_size) const { + return di / world_size; + } +}; + +// Inference task for worker threads +struct InferTask { + int64_t* token_ids; + size_t ntoken; + size_t cache_len; + std::promise* result_promise; // Only rank 0 sets this +}; + +// Tensor Parallel Qwen2 Model +struct Qwen2ModelTP { + LlaisysQwen2Meta meta; + int world_size; + std::vector device_ids; + + // NCCL communicators (created in main thread, used by worker threads) + std::vector> communicators; + + // Model shards + std::vector> shards; + + // Current cache length + std::atomic cache_len; + + // Persistent worker threads + std::vector worker_threads; + std::atomic workers_running; + + // Task queue for inference + std::vector> task_queues; + std::vector queue_mutexes; + std::vector queue_cvs; + + // Completion synchronization + SimpleBarrier* infer_barrier; + std::atomic infer_result; + + Qwen2ModelTP(const LlaisysQwen2Meta* meta_, const std::vector& device_ids); + ~Qwen2ModelTP(); + + void initialize(); + void startWorkerThreads(); + void stopWorkerThreads(); + void workerLoop(int rank); + + LlaisysQwen2Weights* getWeights(int rank); + + int64_t infer(int64_t* token_ids, size_t ntoken); + void resetCache(); + int getWorldSize() const { return world_size; } + +private: + void forwardLayerWithBarrier(int rank, size_t layer, size_t seqlen, size_t start_pos, SimpleBarrier& barrier); + void allReduce(int rank, void* buffer, size_t count, int dtype); +}; + +} // namespace models +} // namespace llaisys diff --git a/src/ops/add/metax/add_metax.cu b/src/ops/add/metax/add_metax.cu new file mode 100644 index 00000000..d532c70a --- /dev/null +++ b/src/ops/add/metax/add_metax.cu @@ -0,0 +1,45 @@ +#include "add_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::metax { + +template +__global__ void addKernel(T *c, const T *a, const T *b, size_t numel) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + float af = to_float_metax(a[idx]); + float bf = to_float_metax(b[idx]); + c[idx] = from_float_metax(af + bf); + } +} + +template +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) { + auto c_ptr = reinterpret_cast(c); + auto a_ptr = reinterpret_cast(a); + auto b_ptr = reinterpret_cast(b); + + const int blockSize = 256; + const int numBlocks = (numel + blockSize - 1) / blockSize; + + addKernel<<>>(c_ptr, a_ptr, b_ptr, numel); +} + +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return add_(c, a, b, size); + case LLAISYS_DTYPE_BF16: + return add_(c, a, b, size); + case LLAISYS_DTYPE_F16: + return add_(c, a, b, size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/add/metax/add_metax.cuh b/src/ops/add/metax/add_metax.cuh new file mode 100644 index 00000000..cc648077 --- /dev/null +++ b/src/ops/add/metax/add_metax.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size); + +} // namespace llaisys::ops::metax diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu new file mode 100644 index 00000000..8bb18320 --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cu @@ -0,0 +1,45 @@ +#include "add_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void addKernel(T *c, const T *a, const T *b, size_t numel) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + float af = to_float_cuda(a[idx]); + float bf = to_float_cuda(b[idx]); + c[idx] = from_float_cuda(af + bf); + } +} + +template +void add_(std::byte *c, const std::byte *a, const std::byte *b, size_t numel) { + auto c_ptr = reinterpret_cast(c); + auto a_ptr = reinterpret_cast(a); + auto b_ptr = reinterpret_cast(b); + + const int blockSize = 256; + const int numBlocks = (numel + blockSize - 1) / blockSize; + + addKernel<<>>(c_ptr, a_ptr, b_ptr, numel); +} + +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return add_(c, a, b, size); + case LLAISYS_DTYPE_BF16: + return add_(c, a, b, size); + case LLAISYS_DTYPE_F16: + return add_(c, a, b, size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/add/nvidia/add_nvidia.cuh b/src/ops/add/nvidia/add_nvidia.cuh new file mode 100644 index 00000000..a2e9144a --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size); +} diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d..0f05156e 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -5,6 +5,14 @@ #include "cpu/add_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/add_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/add_metax.cuh" +#endif + namespace llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { CHECK_SAME_DEVICE(c, a, b); @@ -25,8 +33,11 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..5c0e18a0 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,42 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { + +template +void argmax_(std::byte *max_idx_bytes, std::byte *max_val_bytes, const std::byte *vals_bytes, size_t size) { + auto max_idx = reinterpret_cast(max_idx_bytes); + auto max_val = reinterpret_cast(max_val_bytes); + auto vals = reinterpret_cast(vals_bytes); + + *max_idx = 0; + *max_val = vals[0]; + + for (size_t i = 1; i < size; ++i) { + float curr_val = llaisys::utils::cast(vals[i]); + float current_max = llaisys::utils::cast(*max_val); + if (curr_val > current_max) { + *max_idx = static_cast(i); + *max_val = vals[i]; + } + } +} + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(max_idx, max_val, vals, size); + case LLAISYS_DTYPE_BF16: + return argmax_(max_idx, max_val, vals, size); + case LLAISYS_DTYPE_F16: + return argmax_(max_idx, max_val, vals, size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 00000000..0c362ee4 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size); +} \ No newline at end of file diff --git a/src/ops/argmax/metax/argmax_metax.cu b/src/ops/argmax/metax/argmax_metax.cu new file mode 100644 index 00000000..f829cf2b --- /dev/null +++ b/src/ops/argmax/metax/argmax_metax.cu @@ -0,0 +1,193 @@ +#include "argmax_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::metax { + +// Kernel for f16 - output to float buffer (workaround for MetaX scalar write bug) +__global__ void argmaxF16Kernel(const fp16_t_metax *vals, float *max_val_float, int64_t *max_idx, size_t size) { + __shared__ float shared_vals[256]; + __shared__ int64_t shared_idxs[256]; + + unsigned int tid = threadIdx.x; + + shared_vals[tid] = -FLT_MAX; + shared_idxs[tid] = -1; + + for (unsigned int i = tid; i < size; i += blockDim.x) { + float v = to_float_metax(vals[i]); + if (v > shared_vals[tid]) { + shared_vals[tid] = v; + shared_idxs[tid] = i; + } + } + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + if (shared_vals[tid + s] > shared_vals[tid]) { + shared_vals[tid] = shared_vals[tid + s]; + shared_idxs[tid] = shared_idxs[tid + s]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *max_idx = shared_idxs[0]; + *max_val_float = shared_vals[0]; + } +} + +// Kernel for bf16 +__global__ void argmaxBF16Kernel(const bf16_t_metax *vals, bf16_t_metax *max_val, int64_t *max_idx, size_t size) { + __shared__ float shared_vals[256]; + __shared__ int64_t shared_idxs[256]; + + unsigned int tid = threadIdx.x; + + shared_vals[tid] = -FLT_MAX; + shared_idxs[tid] = -1; + + for (unsigned int i = tid; i < size; i += blockDim.x) { + float v = to_float_metax(vals[i]); + if (v > shared_vals[tid]) { + shared_vals[tid] = v; + shared_idxs[tid] = i; + } + } + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + if (shared_vals[tid + s] > shared_vals[tid]) { + shared_vals[tid] = shared_vals[tid + s]; + shared_idxs[tid] = shared_idxs[tid + s]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *max_idx = shared_idxs[0]; + max_val[0] = from_float_metax(shared_vals[0]); + } +} + +// Kernel for f32 +__global__ void argmaxF32Kernel(const float *vals, float *max_val, int64_t *max_idx, size_t size) { + __shared__ float shared_vals[256]; + __shared__ int64_t shared_idxs[256]; + + unsigned int tid = threadIdx.x; + + shared_vals[tid] = -FLT_MAX; + shared_idxs[tid] = -1; + + for (unsigned int i = tid; i < size; i += blockDim.x) { + float v = vals[i]; + if (v > shared_vals[tid]) { + shared_vals[tid] = v; + shared_idxs[tid] = i; + } + } + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + if (shared_vals[tid + s] > shared_vals[tid]) { + shared_vals[tid] = shared_vals[tid + s]; + shared_idxs[tid] = shared_idxs[tid + s]; + } + } + __syncthreads(); + } + + if (tid == 0) { + max_val[0] = shared_vals[0]; + *max_idx = shared_idxs[0]; + } +} + +void argmax(std::byte *max_idx_bytes, std::byte *max_val_bytes, const std::byte *vals_bytes, llaisysDataType_t type, size_t size) { + auto max_idx = reinterpret_cast(max_idx_bytes); + auto max_val = reinterpret_cast(max_val_bytes); + auto vals = reinterpret_cast(vals_bytes); + + const int blockSize = 256; + + switch (type) { + case LLAISYS_DTYPE_F32: + argmaxF32Kernel<<<1, blockSize>>>( + reinterpret_cast(vals), + reinterpret_cast(max_val), + max_idx, size); + break; + case LLAISYS_DTYPE_F16: { + // Workaround: Output to temp float buffer, then convert on host + float *d_max_val_float; + cudaMalloc(&d_max_val_float, sizeof(float)); + + argmaxF16Kernel<<<1, blockSize>>>( + reinterpret_cast(vals), + d_max_val_float, + max_idx, size); + cudaDeviceSynchronize(); + + float h_max_val; + cudaMemcpy(&h_max_val, d_max_val_float, sizeof(float), cudaMemcpyDeviceToHost); + cudaFree(d_max_val_float); + + // Host-side float to f16 conversion + union { float f; uint32_t u; } fu; + fu.f = h_max_val; + uint32_t f32_bits = fu.u; + + uint32_t sign = (f32_bits >> 31) & 0x1; + int32_t exp = ((f32_bits >> 23) & 0xFF) - 127; + uint32_t mant = f32_bits & 0x7FFFFF; + + uint16_t f16_bits; + if (exp == 128) { + f16_bits = (sign << 15) | (mant ? 0x7E00 : 0x7C00); + } else if (exp <= -15) { + f16_bits = sign << 15; + } else if (exp >= 16) { + f16_bits = (sign << 15) | 0x7C00; + } else { + int32_t new_exp = exp + 15; + uint32_t new_mant = (mant + 0xFFF) >> 13; + if (new_mant >= 0x400) { + new_mant = 0; + if (++new_exp >= 31) { + f16_bits = (sign << 15) | 0x7C00; + } else { + f16_bits = (sign << 15) | (new_exp << 10); + } + } else { + f16_bits = (sign << 15) | (new_exp << 10) | (new_mant & 0x3FF); + } + } + + cudaMemcpy(max_val, &f16_bits, sizeof(uint16_t), cudaMemcpyHostToDevice); + break; + } + case LLAISYS_DTYPE_BF16: + argmaxBF16Kernel<<<1, blockSize>>>( + reinterpret_cast(vals), + reinterpret_cast(max_val), + max_idx, size); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + cudaDeviceSynchronize(); +} + +} diff --git a/src/ops/argmax/metax/argmax_metax.cuh b/src/ops/argmax/metax/argmax_metax.cuh new file mode 100644 index 00000000..e7915a10 --- /dev/null +++ b/src/ops/argmax/metax/argmax_metax.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size); + +// Debug functions +void test_f16_write(std::byte *out, int mode); +void test_f32_write(std::byte *out); + +} // namespace llaisys::ops::metax diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu new file mode 100644 index 00000000..3c6a7454 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cu @@ -0,0 +1,77 @@ +#include "argmax_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::nvidia { + +// Simpler approach: single block reduction for small sizes +template +__global__ void argmaxSimpleKernel(const T *vals, T *max_val, int64_t *max_idx, size_t size) { + extern __shared__ float svals[]; + int64_t *sidxs = (int64_t*)&svals[blockDim.x]; + + unsigned int tid = threadIdx.x; + + // Initialize + svals[tid] = -FLT_MAX; + sidxs[tid] = 0; + + // Each thread processes multiple elements + for (unsigned int i = tid; i < size; i += blockDim.x) { + float v = to_float_cuda(vals[i]); + if (v > svals[tid]) { + svals[tid] = v; + sidxs[tid] = i; + } + } + __syncthreads(); + + // Reduction within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + if (svals[tid + s] > svals[tid]) { + svals[tid] = svals[tid + s]; + sidxs[tid] = sidxs[tid + s]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *max_val = from_float_cuda(svals[0]); + *max_idx = sidxs[0]; + } +} + +template +void argmax_(std::byte *max_idx_bytes, std::byte *max_val_bytes, const std::byte *vals_bytes, size_t size) { + auto max_idx = reinterpret_cast(max_idx_bytes); + auto max_val = reinterpret_cast(max_val_bytes); + auto vals = reinterpret_cast(vals_bytes); + + // Use single block with 256 threads for simplicity + // This works well for typical LLM vocab sizes (up to ~100k) + const int blockSize = 256; + size_t sharedMemSize = blockSize * sizeof(float) + blockSize * sizeof(int64_t); + argmaxSimpleKernel<<<1, blockSize, sharedMemSize>>>(vals, max_val, max_idx, size); +} + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(max_idx, max_val, vals, size); + case LLAISYS_DTYPE_BF16: + return argmax_(max_idx, max_val, vals, size); + case LLAISYS_DTYPE_F16: + return argmax_(max_idx, max_val, vals, size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cuh b/src/ops/argmax/nvidia/argmax_nvidia.cuh new file mode 100644 index 00000000..f4ba25e1 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size); +} diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..a5634109 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,48 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/argmax_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/argmax_metax.cuh" +#endif + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + // Only support 1D tensors for now + ASSERT(vals->ndim() == 1, "Argmax: vals must be 1D tensor."); + ASSERT(max_idx->ndim() == 1 && max_idx->numel() == 1, "Argmax: max_idx must be 1D tensor with 1 element."); + ASSERT(max_val->ndim() == 1 && max_val->numel() == 1, "Argmax: max_val must be 1D tensor with 1 element."); + ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax: max_idx must be int64 dtype."); + ASSERT(max_val->dtype() == vals->dtype(), "Argmax: max_val and vals must have same dtype."); + ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(), "Argmax: all tensors must be contiguous."); + + if (vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..df40dcd9 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,37 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { + +template +void embedding_(std::byte *out_bytes, const std::byte *index_bytes, const std::byte *weight_bytes, size_t index_size, size_t embed_dim) { + auto out = reinterpret_cast(out_bytes); + auto index = reinterpret_cast(index_bytes); + auto weight = reinterpret_cast(weight_bytes); + + for (size_t i = 0; i < index_size; ++i) { + int64_t idx = index[i]; + const T *src = weight + idx * embed_dim; + T *dst = out + i * embed_dim; + std::memcpy(dst, src, embed_dim * sizeof(T)); + } +} + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_size, size_t embed_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_F16: + return embedding_(out, index, weight, index_size, embed_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..fe77592d --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_size, size_t embed_dim); +} \ No newline at end of file diff --git a/src/ops/embedding/metax/embedding_metax.cu b/src/ops/embedding/metax/embedding_metax.cu new file mode 100644 index 00000000..6273c406 --- /dev/null +++ b/src/ops/embedding/metax/embedding_metax.cu @@ -0,0 +1,55 @@ +#include "embedding_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::metax { + +template +__global__ void embeddingKernel(T *out, const int64_t *index, const T *weight, + size_t index_size, size_t embed_dim) { + // Each thread handles one element of the embedding + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total_elements = index_size * embed_dim; + + if (idx < total_elements) { + size_t i = idx / embed_dim; // Which index + size_t j = idx % embed_dim; // Which dimension + + int64_t vocab_idx = index[i]; + out[idx] = weight[vocab_idx * embed_dim + j]; + } +} + +template +void embedding_(std::byte *out_bytes, const std::byte *index_bytes, const std::byte *weight_bytes, + size_t index_size, size_t embed_dim) { + auto out = reinterpret_cast(out_bytes); + auto index = reinterpret_cast(index_bytes); + auto weight = reinterpret_cast(weight_bytes); + + size_t total_elements = index_size * embed_dim; + const int blockSize = 256; + const int numBlocks = (total_elements + blockSize - 1) / blockSize; + + embeddingKernel<<>>(out, index, weight, index_size, embed_dim); +} + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_size, size_t embed_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_F16: + return embedding_(out, index, weight, index_size, embed_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/embedding/metax/embedding_metax.cuh b/src/ops/embedding/metax/embedding_metax.cuh new file mode 100644 index 00000000..9899172a --- /dev/null +++ b/src/ops/embedding/metax/embedding_metax.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_size, size_t embed_dim); + +} // namespace llaisys::ops::metax diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 00000000..5fec6876 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,55 @@ +#include "embedding_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void embeddingKernel(T *out, const int64_t *index, const T *weight, + size_t index_size, size_t embed_dim) { + // Each thread handles one element of the embedding + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total_elements = index_size * embed_dim; + + if (idx < total_elements) { + size_t i = idx / embed_dim; // Which index + size_t j = idx % embed_dim; // Which dimension + + int64_t vocab_idx = index[i]; + out[idx] = weight[vocab_idx * embed_dim + j]; + } +} + +template +void embedding_(std::byte *out_bytes, const std::byte *index_bytes, const std::byte *weight_bytes, + size_t index_size, size_t embed_dim) { + auto out = reinterpret_cast(out_bytes); + auto index = reinterpret_cast(index_bytes); + auto weight = reinterpret_cast(weight_bytes); + + size_t total_elements = index_size * embed_dim; + const int blockSize = 256; + const int numBlocks = (total_elements + blockSize - 1) / blockSize; + + embeddingKernel<<>>(out, index, weight, index_size, embed_dim); +} + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_size, size_t embed_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(out, index, weight, index_size, embed_dim); + case LLAISYS_DTYPE_F16: + return embedding_(out, index, weight, index_size, embed_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cuh b/src/ops/embedding/nvidia/embedding_nvidia.cuh new file mode 100644 index 00000000..5cb16285 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_size, size_t embed_dim); +} diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..0bb3c265 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,50 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/embedding_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/embedding_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/embedding_metax.cuh" +#endif + namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + // Only support 1D index and 2D weight for now + ASSERT(index->ndim() == 1, "Embedding: index must be 1D tensor."); + ASSERT(weight->ndim() == 2, "Embedding: weight must be 2D tensor."); + ASSERT(out->ndim() == 2, "Embedding: out must be 2D tensor."); + ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index must be int64 dtype."); + ASSERT(out->dtype() == weight->dtype(), "Embedding: out and weight must have same dtype."); + ASSERT(out->shape()[0] == index->numel(), "Embedding: out shape[0] must match index size."); + ASSERT(out->shape()[1] == weight->shape()[1], "Embedding: out shape[1] must match weight shape[1]."); + ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(), "Embedding: all tensors must be contiguous."); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(), weight->dtype(), index->numel(), weight->shape()[1]); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(), weight->dtype(), index->numel(), weight->shape()[1]); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::embedding(out->data(), index->data(), weight->data(), weight->dtype(), index->numel(), weight->shape()[1]); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::embedding(out->data(), index->data(), weight->data(), weight->dtype(), index->numel(), weight->shape()[1]); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..9af33ec0 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,69 @@ +#include "linear_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { + +template +void linear_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + auto bias = reinterpret_cast(bias_bytes); + + // Initialize output with bias if provided + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + if (bias) { + out[i * out_dim + j] = bias[j]; + } else { + out[i * out_dim + j] = llaisys::utils::cast(0.0f); + } + } + } + + // Compute Y = xW^T + b + if constexpr (std::is_same_v) { + // For float32, compute directly without casting + for (size_t i = 0; i < batch_size; ++i) { + for (size_t k = 0; k < in_dim; ++k) { + float x = in[i * in_dim + k]; + for (size_t j = 0; j < out_dim; ++j) { + float w = weight[j * in_dim + k]; + out[i * out_dim + j] += x * w; + } + } + } + } else { + // For float16 and bfloat16, compute in float32 for better precision + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + float sum = llaisys::utils::cast(out[i * out_dim + j]); + for (size_t k = 0; k < in_dim; ++k) { + float x = llaisys::utils::cast(in[i * in_dim + k]); + float w = llaisys::utils::cast(weight[j * in_dim + k]); + sum += x * w; + } + out[i * out_dim + j] = llaisys::utils::cast(sum); + } + } + } +} + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_(out, in, weight, bias, batch_size, in_dim, out_dim); + case LLAISYS_DTYPE_BF16: + return linear_(out, in, weight, bias, batch_size, in_dim, out_dim); + case LLAISYS_DTYPE_F16: + return linear_(out, in, weight, bias, batch_size, in_dim, out_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..699e6840 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim); +} \ No newline at end of file diff --git a/src/ops/linear/metax/linear_metax.cu b/src/ops/linear/metax/linear_metax.cu new file mode 100644 index 00000000..eda4cb73 --- /dev/null +++ b/src/ops/linear/metax/linear_metax.cu @@ -0,0 +1,276 @@ +#include "linear_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../device/metax/metax_resource.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::metax { + +// Bias add kernel for mcblas version +template +__global__ void addBiasKernel(T *out, const T *bias, size_t batch_size, size_t out_dim) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = batch_size * out_dim; + + for (size_t i = idx; i < total; i += blockDim.x * gridDim.x) { + size_t col = i % out_dim; + float val = to_float_metax(out[i]) + to_float_metax(bias[col]); + out[i] = from_float_metax(val); + } +} + +// Helper to get mcblas data type (mapped via cuBLAS wrapper) +macaDataType getMacaDataType(llaisysDataType_t type) { + switch (type) { + case LLAISYS_DTYPE_F32: + return MACA_R_32F; + case LLAISYS_DTYPE_F16: + return MACA_R_16F; + case LLAISYS_DTYPE_BF16: + return MACA_R_16BF; + default: + return MACA_R_32F; + } +} + +// mcBLAS GEMM wrapper +template +void gemm_mcblas(T *out, const T *weight, const T *in, + size_t out_dim, size_t batch_size, size_t in_dim, + llaisysDataType_t dtype) { + // Get singleton mcblas handle + mcblasHandle_t handle = llaisys::device::metax::McblasHandle::get(); + + float alpha = 1.0f; + float beta = 0.0f; + + macaDataType maca_dtype = getMacaDataType(dtype); + + // Linear: Y = X * W^T + // X: [batch_size, in_dim] + // W: [out_dim, in_dim] + // Y: [batch_size, out_dim] + // + // mcBLAS is column-major, so we compute: + // Y^T = W * X^T + // C = A * B where: + // A = W [in_dim x out_dim] (transposed to [out_dim x in_dim]) + // B = X^T [in_dim x batch_size] + // C = Y^T [out_dim x batch_size] + + mcblasGemmEx( + handle, + MCBLAS_OP_T, // transa: transpose weight + MCBLAS_OP_N, // transb: no transpose input + static_cast(out_dim), // m + static_cast(batch_size), // n + static_cast(in_dim), // k + &alpha, + weight, // A + maca_dtype, // A type + static_cast(in_dim), // lda + in, // B + maca_dtype, // B type + static_cast(in_dim), // ldb + &beta, + out, // C + maca_dtype, // C type + static_cast(out_dim), // ldc + MCBLAS_COMPUTE_32F, // compute type + MCBLAS_GEMM_DEFAULT // algorithm + ); +} + +// Template specializations for mcblas version +template +void linear_mcblas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype); + +template <> +void linear_mcblas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_mcblas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + // Add bias + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +template <> +void linear_mcblas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_mcblas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +template <> +void linear_mcblas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_mcblas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +// Fallback: Improved linear kernel with more accurate accumulation +// Uses Kahan summation for better precision +template +__global__ void linearKernel(T *out, const T *in, const T *weight, const T *bias, + size_t batch_size, size_t in_dim, size_t out_dim) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = batch_size * out_dim; + + for (size_t linear_idx = idx; linear_idx < total; linear_idx += blockDim.x * gridDim.x) { + size_t i = linear_idx / out_dim; + size_t j = linear_idx % out_dim; + + // Kahan summation for better precision + float sum = 0.0f; + float c = 0.0f; // Compensation for lost low-order bits + + for (size_t k = 0; k < in_dim; ++k) { + float x = to_float_metax(in[i * in_dim + k]); + float w = to_float_metax(weight[j * in_dim + k]); + float y = x * w - c; + float t = sum + y; + c = (t - sum) - y; + sum = t; + } + + if (bias) { + sum += to_float_metax(bias[j]); + } + + out[linear_idx] = from_float_metax(sum); + } +} + +// Alternative: blocked accumulation for very large dimensions +template +__global__ void linearKernelBlocked(T *out, const T *in, const T *weight, const T *bias, + size_t batch_size, size_t in_dim, size_t out_dim) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = batch_size * out_dim; + + const size_t BLOCK_SIZE = 256; // Process in blocks for better precision + + for (size_t linear_idx = idx; linear_idx < total; linear_idx += blockDim.x * gridDim.x) { + size_t i = linear_idx / out_dim; + size_t j = linear_idx % out_dim; + + float sum = 0.0f; + + // Process in blocks + for (size_t k_start = 0; k_start < in_dim; k_start += BLOCK_SIZE) { + float block_sum = 0.0f; + size_t k_end = min(k_start + BLOCK_SIZE, in_dim); + + for (size_t k = k_start; k < k_end; ++k) { + float x = to_float_metax(in[i * in_dim + k]); + float w = to_float_metax(weight[j * in_dim + k]); + block_sum += x * w; + } + sum += block_sum; + } + + if (bias) { + sum += to_float_metax(bias[j]); + } + + out[linear_idx] = from_float_metax(sum); + } +} + +template +void linear_fallback(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + auto bias = reinterpret_cast(bias_bytes); + + size_t total_outputs = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total_outputs + blockSize - 1) / blockSize; + + // Use blocked kernel for large dimensions (common in transformers) + if (in_dim > 512) { + linearKernelBlocked<<>>(out, in, weight, bias, batch_size, in_dim, out_dim); + } else { + linearKernel<<>>(out, in, weight, bias, batch_size, in_dim, out_dim); + } +} + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim) { + // Use mcBLAS for better performance on large matrices + // Fallback to custom kernel for small matrices or special cases + const size_t USE_MCBLAS_THRESHOLD = 256; // Use mcBLAS for dim >= 256 + + bool use_mcblas = (batch_size >= USE_MCBLAS_THRESHOLD || in_dim >= USE_MCBLAS_THRESHOLD || + out_dim >= USE_MCBLAS_THRESHOLD); + + if (use_mcblas) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_mcblas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + case LLAISYS_DTYPE_BF16: + return linear_mcblas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + case LLAISYS_DTYPE_F16: + return linear_mcblas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + } else { + // Fallback to custom kernel + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_fallback(out, in, weight, bias, batch_size, in_dim, out_dim); + case LLAISYS_DTYPE_BF16: + return linear_fallback(out, in, weight, bias, batch_size, in_dim, out_dim); + case LLAISYS_DTYPE_F16: + return linear_fallback(out, in, weight, bias, batch_size, in_dim, out_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/linear/metax/linear_metax.cuh b/src/ops/linear/metax/linear_metax.cuh new file mode 100644 index 00000000..1959be76 --- /dev/null +++ b/src/ops/linear/metax/linear_metax.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim); + +} // namespace llaisys::ops::metax diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu new file mode 100644 index 00000000..5a3d119b --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.cu @@ -0,0 +1,183 @@ +#include "linear_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::nvidia { + +// Bias add kernel +template +__global__ void addBiasKernel(T *out, const T *bias, size_t batch_size, size_t out_dim) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = batch_size * out_dim; + + for (size_t i = idx; i < total; i += blockDim.x * gridDim.x) { + size_t col = i % out_dim; + float val = to_float_cuda(out[i]) + to_float_cuda(bias[col]); + out[i] = from_float_cuda(val); + } +} + +// cuBLAS handle wrapper for lazy initialization +class CublasHandle { +public: + static cublasHandle_t& get() { + static CublasHandle instance; + return instance.handle; + } +private: + cublasHandle_t handle; + CublasHandle() { + cublasCreate(&handle); + } + ~CublasHandle() { + cublasDestroy(handle); + } + CublasHandle(const CublasHandle&) = delete; + CublasHandle& operator=(const CublasHandle&) = delete; +}; + +// Helper to get CUDA data type +cudaDataType getCudaDataType(llaisysDataType_t type) { + switch (type) { + case LLAISYS_DTYPE_F32: + return CUDA_R_32F; + case LLAISYS_DTYPE_F16: + return CUDA_R_16F; + case LLAISYS_DTYPE_BF16: + return CUDA_R_16BF; // CUDA 11.0+ supports bfloat16 + default: + return CUDA_R_32F; + } +} + +// cuBLAS GEMM wrapper +template +void gemm_cublas(T *out, const T *weight, const T *in, + size_t out_dim, size_t batch_size, size_t in_dim, + llaisysDataType_t dtype) { + cublasHandle_t handle = CublasHandle::get(); + + float alpha = 1.0f; + float beta = 0.0f; + + cudaDataType cuda_dtype = getCudaDataType(dtype); + + // Linear: Y = X * W^T + // X: [batch_size, in_dim] + // W: [out_dim, in_dim] + // Y: [batch_size, out_dim] + // + // cuBLAS is column-major, so we compute: + // Y^T = W * X^T + // C = A * B where: + // A = W [in_dim x out_dim] (transposed to [out_dim x in_dim]) + // B = X^T [in_dim x batch_size] + // C = Y^T [out_dim x batch_size] + + cublasGemmEx( + handle, + CUBLAS_OP_T, // transA: transpose weight + CUBLAS_OP_N, // transB: no transpose input + static_cast(out_dim), // m + static_cast(batch_size), // n + static_cast(in_dim), // k + &alpha, + weight, // A + cuda_dtype, // A type + static_cast(in_dim), // lda + in, // B + cuda_dtype, // B type + static_cast(in_dim), // ldb + &beta, + out, // C + cuda_dtype, // C type + static_cast(out_dim), // ldc + CUBLAS_COMPUTE_32F, // compute type + CUBLAS_GEMM_DEFAULT_TENSOR_OP // use Tensor Core + ); +} + +// Template specializations for different types +template +void linear_cublas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype); + +template <> +void linear_cublas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_cublas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + // Add bias + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +template <> +void linear_cublas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_cublas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +template <> +void linear_cublas(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + const std::byte *bias_bytes, size_t batch_size, size_t in_dim, size_t out_dim, + llaisysDataType_t dtype) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + gemm_cublas(out, weight, in, out_dim, batch_size, in_dim, dtype); + + if (bias_bytes) { + auto bias = reinterpret_cast(bias_bytes); + size_t total = batch_size * out_dim; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + addBiasKernel<<>>(out, bias, batch_size, out_dim); + } +} + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_cublas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + case LLAISYS_DTYPE_BF16: + return linear_cublas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + case LLAISYS_DTYPE_F16: + return linear_cublas(out, in, weight, bias, batch_size, in_dim, out_dim, type); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/linear/nvidia/linear_nvidia.cuh b/src/ops/linear/nvidia/linear_nvidia.cuh new file mode 100644 index 00000000..738f70e6 --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch_size, size_t in_dim, size_t out_dim); +} diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f865..b16f854c 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,102 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/linear_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/linear_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/linear_metax.cuh" +#endif + namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + if (bias) { + CHECK_SAME_DEVICE(out, bias); + } + + // Only support 2D tensors for now + ASSERT(in->ndim() == 2, "Linear: in must be 2D tensor."); + ASSERT(weight->ndim() == 2, "Linear: weight must be 2D tensor."); + ASSERT(out->ndim() == 2, "Linear: out must be 2D tensor."); + ASSERT(in->dtype() == out->dtype() && in->dtype() == weight->dtype(), "Linear: in, out, weight must have same dtype."); + ASSERT(in->shape()[1] == weight->shape()[1], "Linear: in shape[1] must match weight shape[1]."); + ASSERT(out->shape()[0] == in->shape()[0], "Linear: out shape[0] must match in shape[0]."); + ASSERT(out->shape()[1] == weight->shape()[0], "Linear: out shape[1] must match weight shape[0]."); + + if (bias) { + ASSERT(bias->ndim() == 1, "Linear: bias must be 1D tensor."); + ASSERT(bias->shape()[0] == out->shape()[1], "Linear: bias shape[0] must match out shape[1]."); + ASSERT(bias->dtype() == in->dtype(), "Linear: bias must have same dtype as in."); + ASSERT(bias->isContiguous(), "Linear: bias must be contiguous."); + } + + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "Linear: all tensors must be contiguous."); + + size_t batch_size = in->shape()[0]; + size_t in_dim = in->shape()[1]; + size_t out_dim = out->shape()[1]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::linear( + out->data(), + in->data(), + weight->data(), + bias ? bias->data() : nullptr, + in->dtype(), + batch_size, + in_dim, + out_dim + ); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear( + out->data(), + in->data(), + weight->data(), + bias ? bias->data() : nullptr, + in->dtype(), + batch_size, + in_dim, + out_dim + ); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::linear( + out->data(), + in->data(), + weight->data(), + bias ? bias->data() : nullptr, + in->dtype(), + batch_size, + in_dim, + out_dim + ); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::linear( + out->data(), + in->data(), + weight->data(), + bias ? bias->data() : nullptr, + in->dtype(), + batch_size, + in_dim, + out_dim + ); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 00000000..0eeb5bd9 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,50 @@ +#include "rms_norm_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { + +template +void rms_norm_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, float eps, size_t batch_size, size_t hidden_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + for (size_t i = 0; i < batch_size; ++i) { + // Compute sum of squares + float sum_sq = 0.0f; + for (size_t j = 0; j < hidden_dim; ++j) { + float x = llaisys::utils::cast(in[i * hidden_dim + j]); + sum_sq += x * x; + } + + // Compute RMS + float rms = std::sqrt(sum_sq / hidden_dim + eps); + + // Normalize and apply weight + for (size_t j = 0; j < hidden_dim; ++j) { + float x = llaisys::utils::cast(in[i * hidden_dim + j]); + float w = llaisys::utils::cast(weight[j]); + float y = (x / rms) * w; + out[i * hidden_dim + j] = llaisys::utils::cast(y); + } + } +} + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_BF16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_F16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 00000000..48ab3415 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim); +} \ No newline at end of file diff --git a/src/ops/rms_norm/metax/rms_norm_metax.cu b/src/ops/rms_norm/metax/rms_norm_metax.cu new file mode 100644 index 00000000..53b02745 --- /dev/null +++ b/src/ops/rms_norm/metax/rms_norm_metax.cu @@ -0,0 +1,80 @@ +#include "rms_norm_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::metax { + +template +__global__ void rmsNormKernel(T *out, const T *in, const T *weight, float eps, size_t hidden_dim) { + // Each block processes one row + size_t row = blockIdx.x; + size_t tid = threadIdx.x; + + // Pointer to the start of this row + const T *in_row = in + row * hidden_dim; + T *out_row = out + row * hidden_dim; + + // Compute sum of squares using shared memory + extern __shared__ float shared_sum[]; + + float local_sum = 0.0f; + for (size_t i = tid; i < hidden_dim; i += blockDim.x) { + float val = to_float_metax(in_row[i]); + local_sum += val * val; + } + + shared_sum[tid] = local_sum; + __syncthreads(); + + // Parallel reduction in shared memory + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_sum[tid] += shared_sum[tid + s]; + } + __syncthreads(); + } + + // Compute RMS + float rms = sqrtf(shared_sum[0] / hidden_dim + eps); + + // Normalize and apply weight + for (size_t i = tid; i < hidden_dim; i += blockDim.x) { + float val = to_float_metax(in_row[i]); + float w = to_float_metax(weight[i]); + float result = (val / rms) * w; + out_row[i] = from_float_metax(result); + } +} + +template +void rms_norm_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + float eps, size_t batch_size, size_t hidden_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + const int blockSize = 256; + size_t sharedMemSize = blockSize * sizeof(float); + + rmsNormKernel<<>>(out, in, weight, eps, hidden_dim); +} + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_BF16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_F16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/rms_norm/metax/rms_norm_metax.cuh b/src/ops/rms_norm/metax/rms_norm_metax.cuh new file mode 100644 index 00000000..74b3ce60 --- /dev/null +++ b/src/ops/rms_norm/metax/rms_norm_metax.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim); + +} // namespace llaisys::ops::metax diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu new file mode 100644 index 00000000..d4d4185c --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -0,0 +1,80 @@ +#include "rms_norm_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void rmsNormKernel(T *out, const T *in, const T *weight, float eps, size_t hidden_dim) { + // Each block processes one row + size_t row = blockIdx.x; + size_t tid = threadIdx.x; + + // Pointer to the start of this row + const T *in_row = in + row * hidden_dim; + T *out_row = out + row * hidden_dim; + + // Compute sum of squares using shared memory + extern __shared__ float shared_sum[]; + + float local_sum = 0.0f; + for (size_t i = tid; i < hidden_dim; i += blockDim.x) { + float val = to_float_cuda(in_row[i]); + local_sum += val * val; + } + + shared_sum[tid] = local_sum; + __syncthreads(); + + // Parallel reduction in shared memory + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_sum[tid] += shared_sum[tid + s]; + } + __syncthreads(); + } + + // Compute RMS + float rms = sqrtf(shared_sum[0] / hidden_dim + eps); + + // Normalize and apply weight + for (size_t i = tid; i < hidden_dim; i += blockDim.x) { + float val = to_float_cuda(in_row[i]); + float w = to_float_cuda(weight[i]); + float result = (val / rms) * w; + out_row[i] = from_float_cuda(result); + } +} + +template +void rms_norm_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *weight_bytes, + float eps, size_t batch_size, size_t hidden_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto weight = reinterpret_cast(weight_bytes); + + const int blockSize = 256; + size_t sharedMemSize = blockSize * sizeof(float); + + rmsNormKernel<<>>(out, in, weight, eps, hidden_dim); +} + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_BF16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + case LLAISYS_DTYPE_F16: + return rms_norm_(out, in, weight, eps, batch_size, hidden_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh new file mode 100644 index 00000000..6cd1d30f --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, float eps, size_t batch_size, size_t hidden_dim); +} diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9..adeea0bb 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,85 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rms_norm_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rms_norm_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/rms_norm_metax.cuh" +#endif + namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + + // Only support 2D tensors for now + ASSERT(in->ndim() == 2, "RMSNorm: in must be 2D tensor."); + ASSERT(out->ndim() == 2, "RMSNorm: out must be 2D tensor."); + ASSERT(weight->ndim() == 1, "RMSNorm: weight must be 1D tensor."); + ASSERT(in->dtype() == out->dtype() && in->dtype() == weight->dtype(), "RMSNorm: in, out, weight must have same dtype."); + ASSERT(in->shape() == out->shape(), "RMSNorm: in and out must have same shape."); + ASSERT(weight->shape()[0] == in->shape()[1], "RMSNorm: weight shape[0] must match in shape[1]."); + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "RMSNorm: all tensors must be contiguous."); + + size_t batch_size = in->shape()[0]; + size_t hidden_dim = in->shape()[1]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rms_norm( + out->data(), + in->data(), + weight->data(), + in->dtype(), + eps, + batch_size, + hidden_dim + ); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm( + out->data(), + in->data(), + weight->data(), + in->dtype(), + eps, + batch_size, + hidden_dim + ); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rms_norm( + out->data(), + in->data(), + weight->data(), + in->dtype(), + eps, + batch_size, + hidden_dim + ); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::rms_norm( + out->data(), + in->data(), + weight->data(), + in->dtype(), + eps, + batch_size, + hidden_dim + ); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 00000000..9637193b --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,70 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::cpu { + +template +void rope_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *pos_ids_bytes, float theta, size_t seq_len, size_t num_heads, size_t head_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto pos_ids = reinterpret_cast(pos_ids_bytes); + + size_t half_dim = head_dim / 2; + + for (size_t i = 0; i < seq_len; ++i) { + int64_t p = pos_ids[i]; + + for (size_t h = 0; h < num_heads; ++h) { + for (size_t j = 0; j < half_dim; ++j) { + // Compute rotation angle following PyTorch implementation + // freqs = positions / (theta ** (2 * i / head_dim)) + float dim = static_cast(j); + float exponent = 2.0f * dim / head_dim; + float theta_pow = std::pow(theta, exponent); + float phi = static_cast(p) / theta_pow; + + // Compute sin and cos + float cos_phi = std::cos(phi); + float sin_phi = std::sin(phi); + + // Get input values [a, b] + size_t index_a = i * num_heads * head_dim + h * head_dim + j; + size_t index_b = index_a + half_dim; + + T a = in[index_a]; + T b = in[index_b]; + + // Apply rotation + float a_f = llaisys::utils::cast(a); + float b_f = llaisys::utils::cast(b); + + float a_prime = a_f * cos_phi - b_f * sin_phi; + float b_prime = b_f * cos_phi + a_f * sin_phi; + + // Store result + out[index_a] = llaisys::utils::cast(a_prime); + out[index_b] = llaisys::utils::cast(b_prime); + } + } + } +} + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, float theta, size_t seq_len, size_t num_heads, size_t head_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + case LLAISYS_DTYPE_BF16: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + case LLAISYS_DTYPE_F16: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..554e1aab --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, float theta, size_t seq_len, size_t num_heads, size_t head_dim); +} \ No newline at end of file diff --git a/src/ops/rope/metax/rope_metax.cu b/src/ops/rope/metax/rope_metax.cu new file mode 100644 index 00000000..a5024031 --- /dev/null +++ b/src/ops/rope/metax/rope_metax.cu @@ -0,0 +1,84 @@ +#include "rope_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::metax { + +template +__global__ void ropeKernel(T *out, const T *in, const int64_t *pos_ids, float theta, + size_t seqlen, size_t nhead, size_t d) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = seqlen * nhead * d; + + if (idx >= total) return; + + // Calculate indices + size_t i = idx / (nhead * d); // sequence index + size_t h = (idx / d) % nhead; // head index + size_t j = idx % d; // dimension index + + int64_t pos = pos_ids[i]; + + // Split dimension into two halves + size_t half_d = d / 2; + + // Get the position in the first or second half + if (j < half_d) { + // First half: a (cos part) + float angle = pos / powf(theta, (2.0f * j) / d); + float cos_val = cosf(angle); + float sin_val = sinf(angle); + + float a = to_float_metax(in[i * nhead * d + h * d + j]); + float b = to_float_metax(in[i * nhead * d + h * d + j + half_d]); + + float result = a * cos_val - b * sin_val; + out[idx] = from_float_metax(result); + } else { + // Second half: b (sin part) + size_t j_in_half = j - half_d; + float angle = pos / powf(theta, (2.0f * j_in_half) / d); + float cos_val = cosf(angle); + float sin_val = sinf(angle); + + float a = to_float_metax(in[i * nhead * d + h * d + j_in_half]); + float b = to_float_metax(in[i * nhead * d + h * d + j]); + + float result = b * cos_val + a * sin_val; + out[idx] = from_float_metax(result); + } +} + +template +void rope_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *pos_ids_bytes, + float theta, size_t seqlen, size_t nhead, size_t d) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto pos_ids = reinterpret_cast(pos_ids_bytes); + + size_t total = seqlen * nhead * d; + const int blockSize = 256; + const int numBlocks = (total + blockSize - 1) / blockSize; + + ropeKernel<<>>(out, in, pos_ids, theta, seqlen, nhead, d); +} + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, float theta, size_t seqlen, size_t nhead, size_t d) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_(out, in, pos_ids, theta, seqlen, nhead, d); + case LLAISYS_DTYPE_BF16: + return rope_(out, in, pos_ids, theta, seqlen, nhead, d); + case LLAISYS_DTYPE_F16: + return rope_(out, in, pos_ids, theta, seqlen, nhead, d); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/rope/metax/rope_metax.cuh b/src/ops/rope/metax/rope_metax.cuh new file mode 100644 index 00000000..0959ea94 --- /dev/null +++ b/src/ops/rope/metax/rope_metax.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, float theta, size_t seqlen, size_t nhead, size_t d); + +} // namespace llaisys::ops::metax diff --git a/src/ops/rope/nvidia/rope_nvidia.cu b/src/ops/rope/nvidia/rope_nvidia.cu new file mode 100644 index 00000000..811747f9 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cu @@ -0,0 +1,84 @@ +#include "rope_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void ropeKernel(T *out, const T *in, const int64_t *pos_ids, float theta, + size_t seq_len, size_t num_heads, size_t head_dim) { + size_t half_dim = head_dim / 2; + + // Each thread processes one (seq, head, half_dim) position + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total_pos = seq_len * num_heads * half_dim; + + if (idx >= total_pos) return; + + size_t h = idx % half_dim; + size_t tmp = idx / half_dim; + size_t head = tmp % num_heads; + size_t s = tmp / num_heads; + + int64_t p = pos_ids[s]; + + // Compute rotation angle + float dim = static_cast(h); + float exponent = 2.0f * dim / head_dim; + float theta_pow = powf(theta, exponent); + float phi = static_cast(p) / theta_pow; + + float cos_phi = cosf(phi); + float sin_phi = sinf(phi); + + // Get input values [a, b] + size_t index_a = s * num_heads * head_dim + head * head_dim + h; + size_t index_b = index_a + half_dim; + + float a = to_float_cuda(in[index_a]); + float b = to_float_cuda(in[index_b]); + + // Apply rotation + float a_prime = a * cos_phi - b * sin_phi; + float b_prime = b * cos_phi + a * sin_phi; + + // Store result + out[index_a] = from_float_cuda(a_prime); + out[index_b] = from_float_cuda(b_prime); +} + +template +void rope_(std::byte *out_bytes, const std::byte *in_bytes, const std::byte *pos_ids_bytes, + float theta, size_t seq_len, size_t num_heads, size_t head_dim) { + auto out = reinterpret_cast(out_bytes); + auto in = reinterpret_cast(in_bytes); + auto pos_ids = reinterpret_cast(pos_ids_bytes); + + size_t half_dim = head_dim / 2; + size_t total_pos = seq_len * num_heads * half_dim; + + const int blockSize = 256; + const int numBlocks = (total_pos + blockSize - 1) / blockSize; + + ropeKernel<<>>(out, in, pos_ids, theta, seq_len, num_heads, head_dim); +} + +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, + llaisysDataType_t type, float theta, size_t seq_len, size_t num_heads, size_t head_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + case LLAISYS_DTYPE_BF16: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + case LLAISYS_DTYPE_F16: + return rope_(out, in, pos_ids, theta, seq_len, num_heads, head_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/rope/nvidia/rope_nvidia.cuh b/src/ops/rope/nvidia/rope_nvidia.cuh new file mode 100644 index 00000000..d0410919 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, float theta, size_t seq_len, size_t num_heads, size_t head_dim); +} diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..e30c5481 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,93 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rope_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rope_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/rope_metax.cuh" +#endif + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DEVICE(out, pos_ids); + + // Only support 3D tensors for now: [seqlen, nhead, d] or [seqlen, nkvhead, d] + ASSERT(in->ndim() == 3, "RoPE: in must be 3D tensor."); + ASSERT(out->ndim() == 3, "RoPE: out must be 3D tensor."); + ASSERT(pos_ids->ndim() == 1, "RoPE: pos_ids must be 1D tensor."); + ASSERT(in->dtype() == out->dtype(), "RoPE: in and out must have same dtype."); + ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64 dtype."); + ASSERT(in->shape() == out->shape(), "RoPE: in and out must have same shape."); + ASSERT(pos_ids->shape()[0] == in->shape()[0], "RoPE: pos_ids shape[0] must match in shape[0]."); + ASSERT(in->shape()[2] % 2 == 0, "RoPE: in shape[2] must be even."); + ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), "RoPE: all tensors must be contiguous."); + + size_t seq_len = in->shape()[0]; + size_t num_heads = in->shape()[1]; + size_t head_dim = in->shape()[2]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope( + out->data(), + in->data(), + pos_ids->data(), + in->dtype(), + theta, + seq_len, + num_heads, + head_dim + ); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope( + out->data(), + in->data(), + pos_ids->data(), + in->dtype(), + theta, + seq_len, + num_heads, + head_dim + ); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rope( + out->data(), + in->data(), + pos_ids->data(), + in->dtype(), + theta, + seq_len, + num_heads, + head_dim + ); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::rope( + out->data(), + in->data(), + pos_ids->data(), + in->dtype(), + theta, + seq_len, + num_heads, + head_dim + ); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..3823b3e6 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,125 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +namespace llaisys::ops::cpu { + +// Helper function for softmax with numerical stability +void softmax(float *x, size_t size) { + float max_val = -std::numeric_limits::infinity(); + for (size_t i = 0; i < size; ++i) { + if (x[i] > max_val) { + max_val = x[i]; + } + } + + if (max_val == -std::numeric_limits::infinity()) { + for (size_t i = 0; i < size; ++i) { + x[i] = 1.0f / size; + } + return; + } + + float sum = 0.0f; + for (size_t i = 0; i < size; ++i) { + x[i] = std::exp(x[i] - max_val); + sum += x[i]; + } + + for (size_t i = 0; i < size; ++i) { + x[i] /= sum; + } +} + +template +void self_attention_(std::byte *attn_val_bytes, const std::byte *q_bytes, const std::byte *k_bytes, const std::byte *v_bytes, float scale, size_t seq_len, size_t total_len, size_t nhead, size_t nkvhead, size_t d, size_t dv) { + auto attn_val = reinterpret_cast(attn_val_bytes); + auto q = reinterpret_cast(q_bytes); + auto k = reinterpret_cast(k_bytes); + auto v = reinterpret_cast(v_bytes); + + // Calculate repeats for Grouped Query Attention (GQA) + size_t num_repeats = nhead / nkvhead; + + // Outer loop over heads + for (size_t h = 0; h < nhead; ++h) { + // Map current query head `h` to the corresponding key/value head `kvh` + size_t kvh = h / num_repeats; + + // Loop over Sequence Length (Time) + for (size_t i = 0; i < seq_len; ++i) { + float* attn_scores = new float[total_len]; + + // 1. Compute Q @ K^T + for (size_t j = 0; j < total_len; ++j) { + double dot_product = 0.0; + for (size_t k_idx = 0; k_idx < d; ++k_idx) { + // Correct Indexing for [seq_len, nhead, d] layout + // Stride(i) = nhead * d + // Stride(h) = d + size_t q_pos = i * nhead * d + h * d + k_idx; + double q_val = static_cast(llaisys::utils::cast(q[q_pos])); + + // Correct Indexing for [total_len, nkvhead, d] layout + // Stride(j) = nkvhead * d + // Stride(kvh) = d + size_t k_pos = j * nkvhead * d + kvh * d + k_idx; + double k_val = static_cast(llaisys::utils::cast(k[k_pos])); + + dot_product += q_val * k_val; + } + + attn_scores[j] = static_cast(dot_product * static_cast(scale)); + } + + // 2. Apply Causal Mask + size_t diagonal = total_len - seq_len; + for (size_t j = 0; j < total_len; ++j) { + if (j > i + diagonal) { + attn_scores[j] = -std::numeric_limits::infinity(); + } + } + + // 3. Softmax + softmax(attn_scores, total_len); + + // 4. Compute Weighted Sum (Scores @ V) + for (size_t v_idx = 0; v_idx < dv; ++v_idx) { + double weighted_sum = 0.0; + for (size_t j = 0; j < total_len; ++j) { + // Correct Indexing for [total_len, nkvhead, dv] layout + size_t v_pos = j * nkvhead * dv + kvh * dv + v_idx; + double v_val = static_cast(llaisys::utils::cast(v[v_pos])); + + weighted_sum += static_cast(attn_scores[j]) * v_val; + } + + // Output layout is [seq_len, nhead, dv] + size_t out_pos = i * nhead * dv + h * dv + v_idx; + attn_val[out_pos] = llaisys::utils::cast(static_cast(weighted_sum)); + } + + delete[] attn_scores; + } + } +} + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t type, float scale, size_t seq_len, size_t total_len, size_t nhead, size_t nkvhead, size_t d, size_t dv) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + case LLAISYS_DTYPE_BF16: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + case LLAISYS_DTYPE_F16: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..c5fc1db3 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t type, float scale, size_t seq_len, size_t total_len, size_t nhead, size_t nkvhead, size_t d, size_t dv); +} \ No newline at end of file diff --git a/src/ops/self_attention/metax/self_attention_metax.cu b/src/ops/self_attention/metax/self_attention_metax.cu new file mode 100644 index 00000000..224976d1 --- /dev/null +++ b/src/ops/self_attention/metax/self_attention_metax.cu @@ -0,0 +1,182 @@ +#include "self_attention_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::metax { + +// Kernel to compute Q*K^T and apply causal mask +template +__global__ void qkKernel(T *attn_scores, const T *q, const T *k, float scale, + size_t seqlen, size_t nhead, size_t nkvhead, size_t d, size_t total_len) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = seqlen * nhead * total_len; + if (idx >= total) return; + + size_t i = idx / (nhead * total_len); + size_t h = (idx / total_len) % nhead; + size_t j = idx % total_len; + + size_t kv_offset = total_len - seqlen; + if (j > i + kv_offset) { + attn_scores[idx] = from_float_metax(-FLT_MAX); + return; + } + + size_t kv_h = h / (nhead / nkvhead); + + // Blocked accumulation for better precision + float sum = 0.0f; + const size_t BLOCK_SIZE = 64; + for (size_t dim_start = 0; dim_start < d; dim_start += BLOCK_SIZE) { + float block_sum = 0.0f; + size_t dim_end = min(dim_start + BLOCK_SIZE, d); + for (size_t dim = dim_start; dim < dim_end; ++dim) { + float q_val = to_float_metax(q[i * nhead * d + h * d + dim]); + float k_val = to_float_metax(k[j * nkvhead * d + kv_h * d + dim]); + block_sum += q_val * k_val; + } + sum += block_sum; + } + + attn_scores[idx] = from_float_metax(sum * scale); +} + +// Kernel for softmax (per row) +template +__global__ void softmaxKernel(T *attn_scores, size_t seqlen, size_t nhead, size_t total_len) { + size_t row = blockIdx.x; + size_t tid = threadIdx.x; + + if (row >= seqlen * nhead) return; + + T *row_ptr = attn_scores + row * total_len; + extern __shared__ float shared_mem[]; + float *shared_max = shared_mem; + float *shared_sum = shared_mem + blockDim.x; + + float local_max = -FLT_MAX; + for (size_t i = tid; i < total_len; i += blockDim.x) { + float val = to_float_metax(row_ptr[i]); + if (val != -FLT_MAX) { + local_max = fmaxf(local_max, val); + } + } + shared_max[tid] = local_max; + __syncthreads(); + + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + s]); + } + __syncthreads(); + } + float row_max = shared_max[0]; + + float local_sum = 0.0f; + for (size_t i = tid; i < total_len; i += blockDim.x) { + float val = to_float_metax(row_ptr[i]); + if (val != -FLT_MAX) { + float exp_val = expf(val - row_max); + local_sum += exp_val; + row_ptr[i] = from_float_metax(exp_val); + } else { + row_ptr[i] = from_float_metax(0.0f); + } + } + shared_sum[tid] = local_sum; + __syncthreads(); + + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_sum[tid] += shared_sum[tid + s]; + } + __syncthreads(); + } + float row_sum = shared_sum[0]; + + for (size_t i = tid; i < total_len; i += blockDim.x) { + float val = to_float_metax(row_ptr[i]); + row_ptr[i] = from_float_metax(val / row_sum); + } +} + +// Kernel for attention * V with blocked accumulation +template +__global__ void attnVKernel(T *attn_val, const T *attn_scores, const T *v, + size_t seqlen, size_t nhead, size_t nkvhead, size_t dv, size_t total_len) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total = seqlen * nhead * dv; + if (idx >= total) return; + + size_t i = idx / (nhead * dv); + size_t h = (idx / dv) % nhead; + size_t dim = idx % dv; + size_t kv_h = h / (nhead / nkvhead); + + // Blocked accumulation for better precision + float sum = 0.0f; + const size_t BLOCK_SIZE = 64; + for (size_t j_start = 0; j_start < total_len; j_start += BLOCK_SIZE) { + float block_sum = 0.0f; + size_t j_end = min(j_start + BLOCK_SIZE, total_len); + for (size_t j = j_start; j < j_end; ++j) { + float score = to_float_metax(attn_scores[i * nhead * total_len + h * total_len + j]); + float v_val = to_float_metax(v[j * nkvhead * dv + kv_h * dv + dim]); + block_sum += score * v_val; + } + sum += block_sum; + } + + attn_val[idx] = from_float_metax(sum); +} + +template +void self_attention_(std::byte *attn_val_bytes, const std::byte *q_bytes, const std::byte *k_bytes, + const std::byte *v_bytes, float scale, size_t seqlen, size_t nhead, + size_t nkvhead, size_t d, size_t dv, size_t total_len) { + auto attn_val = reinterpret_cast(attn_val_bytes); + auto q = reinterpret_cast(q_bytes); + auto k = reinterpret_cast(k_bytes); + auto v = reinterpret_cast(v_bytes); + + T *attn_scores; + size_t scores_size = seqlen * nhead * total_len * sizeof(T); + cudaMalloc(&attn_scores, scores_size); + + const int blockSize = 256; + + size_t qk_total = seqlen * nhead * total_len; + int qk_blocks = (qk_total + blockSize - 1) / blockSize; + qkKernel<<>>(attn_scores, q, k, scale, seqlen, nhead, nkvhead, d, total_len); + + size_t softmax_shared_mem = 2 * blockSize * sizeof(float); + softmaxKernel<<>>(attn_scores, seqlen, nhead, total_len); + + size_t attn_v_total = seqlen * nhead * dv; + int attn_v_blocks = (attn_v_total + blockSize - 1) / blockSize; + attnVKernel<<>>(attn_val, attn_scores, v, seqlen, nhead, nkvhead, dv, total_len); + + cudaFree(attn_scores); +} + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, float scale, size_t seqlen, size_t nhead, size_t nkvhead, + size_t d, size_t dv, size_t total_len) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attention_(attn_val, q, k, v, scale, seqlen, nhead, nkvhead, d, dv, total_len); + case LLAISYS_DTYPE_BF16: + return self_attention_(attn_val, q, k, v, scale, seqlen, nhead, nkvhead, d, dv, total_len); + case LLAISYS_DTYPE_F16: + return self_attention_(attn_val, q, k, v, scale, seqlen, nhead, nkvhead, d, dv, total_len); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/self_attention/metax/self_attention_metax.cuh b/src/ops/self_attention/metax/self_attention_metax.cuh new file mode 100644 index 00000000..d95466c6 --- /dev/null +++ b/src/ops/self_attention/metax/self_attention_metax.cuh @@ -0,0 +1,14 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, float scale, size_t seqlen, size_t nhead, size_t nkvhead, + size_t d, size_t dv, size_t total_len); + +} // namespace llaisys::ops::metax diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cu b/src/ops/self_attention/nvidia/self_attention_nvidia.cu new file mode 100644 index 00000000..b18d11b5 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cu @@ -0,0 +1,227 @@ +#include "self_attention_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include +#include + +namespace llaisys::ops::nvidia { + +// Compute Q * K^T for one attention head +// Q: [seq_len, d], K: [total_len, d], Output: [seq_len, total_len] +template +__global__ void qkT_kernel(const T *q, const T *k, float *scores, + size_t seq_len, size_t total_len, size_t d, + size_t nhead, size_t nkvhead, float scale) { + // Each thread computes one (i, j) position in the score matrix + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total_scores = seq_len * total_len; + + if (idx >= total_scores) return; + + size_t j = idx % total_len; // Key position + size_t i = idx / total_len; // Query position + + // Map query head to key/value head (GQA) + size_t num_repeats = nhead / nkvhead; + + float dot_product = 0.0f; + for (size_t k_idx = 0; k_idx < d; ++k_idx) { + // Q layout: [seq_len, nhead, d] + size_t q_pos = i * nhead * d + blockIdx.y * d + k_idx; + float q_val = to_float_cuda(q[q_pos]); + + // K layout: [total_len, nkvhead, d] + size_t kvh = blockIdx.y / num_repeats; + size_t k_pos = j * nkvhead * d + kvh * d + k_idx; + float k_val = to_float_cuda(k[k_pos]); + + dot_product += q_val * k_val; + } + + // Apply scale + scores[idx * nhead + blockIdx.y] = dot_product * scale; +} + +// Apply causal mask and softmax +// scores: [seq_len, total_len, nhead] +template +__global__ void causal_softmax_kernel(float *scores, T *attn_weights, + size_t seq_len, size_t total_len, size_t nhead) { + // Each block handles one (query position, head) + size_t i = blockIdx.x; // Query position + size_t h = blockIdx.y; // Head + + size_t diagonal = total_len - seq_len; + + extern __shared__ float shared_mem[]; + float *sdata = shared_mem; + float *smax = &shared_mem[blockDim.x]; + + // Load scores into shared memory with causal mask + float thread_max = -FLT_MAX; + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + size_t idx = (i * total_len + j) * nhead + h; + float val = scores[idx]; + + // Apply causal mask + if (j > i + diagonal) { + val = -FLT_MAX; + } + sdata[j] = val; + if (val > thread_max) thread_max = val; + } + __syncthreads(); + + // Parallel reduction for max + smax[threadIdx.x] = thread_max; + __syncthreads(); + + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + if (smax[threadIdx.x + s] > smax[threadIdx.x]) { + smax[threadIdx.x] = smax[threadIdx.x + s]; + } + } + __syncthreads(); + } + float max_val = smax[0]; + + // Compute exp and thread-local sum + float thread_sum = 0.0f; + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + float exp_val = expf(sdata[j] - max_val); + sdata[j] = exp_val; + thread_sum += exp_val; + } + __syncthreads(); + + // Parallel reduction for sum + __shared__ float ssum[256]; + ssum[threadIdx.x] = thread_sum; + __syncthreads(); + + for (size_t s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + ssum[threadIdx.x] += ssum[threadIdx.x + s]; + } + __syncthreads(); + } + float sum_val = ssum[0]; + + // Normalize and write output + for (size_t j = threadIdx.x; j < total_len; j += blockDim.x) { + size_t idx = (i * total_len + j) * nhead + h; + scores[idx] = sdata[j] / sum_val; + } +} + +// Compute attention output: attn_weights * V +// attn_weights: [seq_len, total_len, nhead], V: [total_len, nkvhead, dv] +// Output: [seq_len, nhead, dv] +template +__global__ void attn_v_kernel(const float *attn_weights, const T *v, T *out, + size_t seq_len, size_t total_len, size_t nhead, size_t nkvhead, size_t dv) { + // Each thread computes one output element + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t total_outputs = seq_len * nhead * dv; + + if (idx >= total_outputs) return; + + size_t tmp = idx; + size_t v_idx = tmp % dv; + tmp /= dv; + size_t h = tmp % nhead; + size_t i = tmp / nhead; + + // Map query head to key/value head (GQA) + size_t num_repeats = nhead / nkvhead; + size_t kvh = h / num_repeats; + + float weighted_sum = 0.0f; + for (size_t j = 0; j < total_len; ++j) { + size_t attn_idx = (i * total_len + j) * nhead + h; + float w = attn_weights[attn_idx]; + + // V layout: [total_len, nkvhead, dv] + size_t v_pos = j * nkvhead * dv + kvh * dv + v_idx; + float v_val = to_float_cuda(v[v_pos]); + + weighted_sum += w * v_val; + } + + // Output layout: [seq_len, nhead, dv] + size_t out_pos = i * nhead * dv + h * dv + v_idx; + out[out_pos] = from_float_cuda(weighted_sum); +} + +template +void self_attention_(std::byte *attn_val_bytes, const std::byte *q_bytes, + const std::byte *k_bytes, const std::byte *v_bytes, + float scale, size_t seq_len, size_t total_len, + size_t nhead, size_t nkvhead, size_t d, size_t dv) { + auto attn_val = reinterpret_cast(attn_val_bytes); + auto q = reinterpret_cast(q_bytes); + auto k = reinterpret_cast(k_bytes); + auto v = reinterpret_cast(v_bytes); + + // Use thread_local static buffer to avoid repeated cudaMalloc + // Max size: 4096 * 4096 * 12 * 4 = 768MB + static thread_local float *d_scores_buffer = nullptr; + static thread_local size_t buffer_size = 0; + + size_t scores_size = seq_len * total_len * nhead * sizeof(float); + + // Allocate or reallocate if needed + if (d_scores_buffer == nullptr || buffer_size < scores_size) { + if (d_scores_buffer != nullptr) { + cudaFree(d_scores_buffer); + } + cudaMalloc(&d_scores_buffer, scores_size); + buffer_size = scores_size; + } + + float *d_scores = d_scores_buffer; + + // Step 1: Compute Q * K^T + size_t total_scores = seq_len * total_len; + const int blockSize = 256; + const int numBlocks = (total_scores + blockSize - 1) / blockSize; + + dim3 qk_grid(numBlocks, nhead); + qkT_kernel<<>>(q, k, d_scores, seq_len, total_len, d, nhead, nkvhead, scale); + + // Step 2: Apply causal softmax + // Shared memory: sdata[total_len] + smax[blockSize] + dim3 softmax_grid(seq_len, nhead); + size_t shared_mem_size = total_len * sizeof(float) + blockSize * sizeof(float); + causal_softmax_kernel<<>>( + d_scores, attn_val, seq_len, total_len, nhead); + + // Step 3: Compute attention output + size_t total_outputs = seq_len * nhead * dv; + const int out_numBlocks = (total_outputs + blockSize - 1) / blockSize; + attn_v_kernel<<>>( + d_scores, v, attn_val, seq_len, total_len, nhead, nkvhead, dv); + + // Buffer is cached for reuse, no need to free +} + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, float scale, size_t seq_len, size_t total_len, + size_t nhead, size_t nkvhead, size_t d, size_t dv) { + switch (type) { + case LLAISYS_DTYPE_F32: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + case LLAISYS_DTYPE_BF16: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + case LLAISYS_DTYPE_F16: + return self_attention_(attn_val, q, k, v, scale, seq_len, total_len, nhead, nkvhead, d, dv); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cuh b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh new file mode 100644 index 00000000..5cedf3c2 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh @@ -0,0 +1,10 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t type, float scale, size_t seq_len, size_t total_len, + size_t nhead, size_t nkvhead, size_t d, size_t dv); +} diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..ff33630b 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,124 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/self_attention_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/self_attention_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/self_attention_metax.cuh" +#endif + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + + // Check tensor dimensions + ASSERT(q->ndim() == 3, "SelfAttention: q must be 3D tensor."); + ASSERT(k->ndim() == 3, "SelfAttention: k must be 3D tensor."); + ASSERT(v->ndim() == 3, "SelfAttention: v must be 3D tensor."); + ASSERT(attn_val->ndim() == 3, "SelfAttention: attn_val must be 3D tensor."); + + // Check dtypes + ASSERT(q->dtype() == k->dtype() && q->dtype() == v->dtype() && q->dtype() == attn_val->dtype(), + "SelfAttention: all tensors must have same dtype."); + + // Get dimensions + size_t seq_len = q->shape()[0]; + size_t total_len = k->shape()[0]; + size_t nhead = q->shape()[1]; + size_t nkvhead = k->shape()[1]; + size_t d = q->shape()[2]; + size_t dv = v->shape()[2]; + + // Check shapes + ASSERT(k->shape()[2] == d, "SelfAttention: k shape[2] must match q shape[2]."); + ASSERT(v->shape()[0] == total_len, "SelfAttention: v shape[0] must match k shape[0]."); + ASSERT(v->shape()[1] == nkvhead, "SelfAttention: v shape[1] must match k shape[1]."); + ASSERT(attn_val->shape()[0] == seq_len, "SelfAttention: attn_val shape[0] must match q shape[0]."); + ASSERT(attn_val->shape()[1] == nhead, "SelfAttention: attn_val shape[1] must match q shape[1]."); + ASSERT(attn_val->shape()[2] == dv, "SelfAttention: attn_val shape[2] must match v shape[2]."); + ASSERT(nhead % nkvhead == 0, "SelfAttention: nhead must be divisible by nkvhead."); + + // Check contiguity + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "SelfAttention: all tensors must be contiguous."); + + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::self_attention( + attn_val->data(), + q->data(), + k->data(), + v->data(), + q->dtype(), + scale, + seq_len, + total_len, + nhead, + nkvhead, + d, + dv + ); + } + + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention( + attn_val->data(), + q->data(), + k->data(), + v->data(), + q->dtype(), + scale, + seq_len, + total_len, + nhead, + nkvhead, + d, + dv + ); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::self_attention( + attn_val->data(), + q->data(), + k->data(), + v->data(), + q->dtype(), + scale, + seq_len, + total_len, + nhead, + nkvhead, + d, + dv + ); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::self_attention( + attn_val->data(), + q->data(), + k->data(), + v->data(), + q->dtype(), + scale, + seq_len, + nhead, + nkvhead, + d, + dv, + total_len + ); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..b1118c9c --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,48 @@ +#include "swiglu_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::cpu { + +template +void swiglu_(std::byte *out_bytes, const std::byte *gate_bytes, const std::byte *up_bytes, size_t seq_len, size_t intermediate_size) { + auto out = reinterpret_cast(out_bytes); + auto gate = reinterpret_cast(gate_bytes); + auto up = reinterpret_cast(up_bytes); + + for (size_t i = 0; i < seq_len; ++i) { + for (size_t j = 0; j < intermediate_size; ++j) { + size_t index = i * intermediate_size + j; + float gate_val = llaisys::utils::cast(gate[index]); + float up_val = llaisys::utils::cast(up[index]); + + // Compute gate / (1 + e^{-gate}) + float exp_neg_gate = std::exp(-gate_val); + float denominator = 1.0f + exp_neg_gate; + float gate_div = gate_val / denominator; + + // Compute out = up * (gate / (1 + e^{-gate})) + float out_val = up_val * gate_div; + + out[index] = llaisys::utils::cast(out_val); + } + } +} + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t seq_len, size_t intermediate_size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return swiglu_(out, gate, up, seq_len, intermediate_size); + case LLAISYS_DTYPE_BF16: + return swiglu_(out, gate, up, seq_len, intermediate_size); + case LLAISYS_DTYPE_F16: + return swiglu_(out, gate, up, seq_len, intermediate_size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..9f749462 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t seq_len, size_t intermediate_size); +} \ No newline at end of file diff --git a/src/ops/swiglu/metax/swiglu_metax.cu b/src/ops/swiglu/metax/swiglu_metax.cu new file mode 100644 index 00000000..12e27719 --- /dev/null +++ b/src/ops/swiglu/metax/swiglu_metax.cu @@ -0,0 +1,53 @@ +#include "swiglu_metax.cuh" + +#include "../../../device/metax/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::metax { + +// SiLU (Swish) activation: x * sigmoid(x) +__device__ inline float silu(float x) { + return x / (1.0f + expf(-x)); +} + +template +__global__ void swigluKernel(T *out, const T *gate, const T *up, size_t numel) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + float g = to_float_metax(gate[idx]); + float u = to_float_metax(up[idx]); + float result = u * silu(g); + out[idx] = from_float_metax(result); + } +} + +template +void swiglu_(std::byte *out_bytes, const std::byte *gate_bytes, const std::byte *up_bytes, size_t numel) { + auto out = reinterpret_cast(out_bytes); + auto gate = reinterpret_cast(gate_bytes); + auto up = reinterpret_cast(up_bytes); + + const int blockSize = 256; + const int numBlocks = (numel + blockSize - 1) / blockSize; + + swigluKernel<<>>(out, gate, up, numel); +} + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + return swiglu_(out, gate, up, numel); + case LLAISYS_DTYPE_BF16: + return swiglu_(out, gate, up, numel); + case LLAISYS_DTYPE_F16: + return swiglu_(out, gate, up, numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/swiglu/metax/swiglu_metax.cuh b/src/ops/swiglu/metax/swiglu_metax.cuh new file mode 100644 index 00000000..d0e086d3 --- /dev/null +++ b/src/ops/swiglu/metax/swiglu_metax.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "../../../device/runtime_api.hpp" +#include "llaisys/tensor.h" + +#include + +namespace llaisys::ops::metax { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel); + +} // namespace llaisys::ops::metax diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cu b/src/ops/swiglu/nvidia/swiglu_nvidia.cu new file mode 100644 index 00000000..6167610c --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cu @@ -0,0 +1,54 @@ +#include "swiglu_nvidia.cuh" + +#include "../../../device/nvidia/cuda_utils.cuh" +#include "../../../utils.hpp" + +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void swigluKernel(T *out, const T *gate, const T *up, size_t numel) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + float g = to_float_cuda(gate[idx]); + float u = to_float_cuda(up[idx]); + + // Swish: x * sigmoid(x) + float sigmoid = 1.0f / (1.0f + expf(-g)); + float result = u * g * sigmoid; + + out[idx] = from_float_cuda(result); + } +} + +template +void swiglu_(std::byte *out_bytes, const std::byte *gate_bytes, const std::byte *up_bytes, + size_t seq_len, size_t intermediate_size) { + auto out = reinterpret_cast(out_bytes); + auto gate = reinterpret_cast(gate_bytes); + auto up = reinterpret_cast(up_bytes); + + size_t numel = seq_len * intermediate_size; + const int blockSize = 256; + const int numBlocks = (numel + blockSize - 1) / blockSize; + + swigluKernel<<>>(out, gate, up, numel); +} + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t seq_len, size_t intermediate_size) { + switch (type) { + case LLAISYS_DTYPE_F32: + return swiglu_(out, gate, up, seq_len, intermediate_size); + case LLAISYS_DTYPE_BF16: + return swiglu_(out, gate, up, seq_len, intermediate_size); + case LLAISYS_DTYPE_F16: + return swiglu_(out, gate, up, seq_len, intermediate_size); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cuh b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh new file mode 100644 index 00000000..85c8a69e --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t seq_len, size_t intermediate_size); +} diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..59ed27b1 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,79 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/swiglu_nvidia.cuh" +#endif + +#ifdef ENABLE_METAX_API +#include "metax/swiglu_metax.cuh" +#endif + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, gate, up); + + // Only support 2D tensors for now + ASSERT(gate->ndim() == 2, "SwiGLU: gate must be 2D tensor."); + ASSERT(up->ndim() == 2, "SwiGLU: up must be 2D tensor."); + ASSERT(out->ndim() == 2, "SwiGLU: out must be 2D tensor."); + ASSERT(gate->dtype() == up->dtype() && gate->dtype() == out->dtype(), "SwiGLU: all tensors must have same dtype."); + ASSERT(gate->shape() == up->shape() && gate->shape() == out->shape(), "SwiGLU: all tensors must have same shape."); + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), "SwiGLU: all tensors must be contiguous."); + + size_t seq_len = gate->shape()[0]; + size_t intermediate_size = gate->shape()[1]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::swiglu( + out->data(), + gate->data(), + up->data(), + gate->dtype(), + seq_len, + intermediate_size + ); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu( + out->data(), + gate->data(), + up->data(), + gate->dtype(), + seq_len, + intermediate_size + ); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::swiglu( + out->data(), + gate->data(), + up->data(), + gate->dtype(), + seq_len, + intermediate_size + ); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::swiglu( + out->data(), + gate->data(), + up->data(), + gate->dtype(), + seq_len * intermediate_size + ); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..9e089883 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace llaisys { @@ -164,27 +166,168 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + if (this->ndim() == 0) { + return true; + } + + ptrdiff_t expected_stride = 1; + for (size_t i = this->ndim(); i-- > 0;) { + if (this->strides()[i] != expected_stride) { + return false; + } + expected_stride *= this->shape()[i]; + } + return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t ndim = this->ndim(); + CHECK_ARGUMENT(order.size() == ndim, "Invalid permute order size"); + + std::vector visited(ndim, false); + for (size_t i = 0; i < ndim; ++i) { + CHECK_ARGUMENT(order[i] < ndim && !visited[order[i]], "Invalid permute order"); + visited[order[i]] = true; + } + + TensorMeta new_meta; + new_meta.dtype = this->_meta.dtype; + new_meta.shape.resize(ndim); + new_meta.strides.resize(ndim); + + for (size_t i = 0; i < ndim; ++i) { + new_meta.shape[i] = this->_meta.shape[order[i]]; + new_meta.strides[i] = this->_meta.strides[order[i]]; + } + + return std::shared_ptr(new Tensor(new_meta, this->_storage, this->_offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t new_numel = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); + CHECK_ARGUMENT(new_numel == this->numel(), "View shape has different number of elements"); + + TensorMeta new_meta; + new_meta.dtype = this->_meta.dtype; + new_meta.shape = shape; + new_meta.strides.resize(shape.size()); + + if (this->isContiguous()) { + // 连续张量可以直接重塑 + size_t ndim = shape.size(); + if (ndim == 0) { + return std::shared_ptr(new Tensor(new_meta, this->_storage, this->_offset)); + } + + new_meta.strides[ndim - 1] = 1; + for (size_t i = ndim - 1; i-- > 0;) { + new_meta.strides[i] = new_meta.strides[i + 1] * new_meta.shape[i + 1]; + } + } else { + // 非连续张量需要检查是否可以重塑 + // 这里只处理可以直接重塑的情况,更复杂的情况需要更详细的检查 + // 目前只支持将连续的维度合并或拆分 + size_t this_idx = 0; + size_t new_idx = 0; + + while (this_idx < this->ndim() && new_idx < shape.size()) { + size_t this_size = this->shape()[this_idx]; + ptrdiff_t this_stride = this->strides()[this_idx]; + + size_t new_size = shape[new_idx]; + size_t combined_size = this_size; + ptrdiff_t expected_stride = this_stride * this_size; + + // 尝试合并多个维度 + while (this_idx + 1 < this->ndim() && combined_size * this->shape()[this_idx + 1] == new_size) { + combined_size *= this->shape()[this_idx + 1]; + CHECK_ARGUMENT(this->strides()[this_idx + 1] == expected_stride, "Cannot view non-contiguous tensor with this shape"); + expected_stride *= this->shape()[this_idx + 1]; + this_idx++; + } + + if (combined_size != new_size) { + // 尝试拆分维度 + CHECK_ARGUMENT(this_size % new_size == 0, "Cannot view non-contiguous tensor with this shape"); + + size_t split_size = this_size / new_size; + new_meta.strides[new_idx] = this_stride * split_size; + new_idx++; + + CHECK_ARGUMENT(new_idx < shape.size(), "Cannot view non-contiguous tensor with this shape"); + + new_meta.strides[new_idx] = this_stride; + new_meta.shape[new_idx] = split_size; + } else { + new_meta.strides[new_idx] = this_stride; + } + + this_idx++; + new_idx++; + } + + CHECK_ARGUMENT(this_idx == this->ndim() && new_idx == shape.size(), "Cannot view non-contiguous tensor with this shape"); + } + + return std::shared_ptr(new Tensor(new_meta, this->_storage, this->_offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + CHECK_ARGUMENT(dim < this->ndim(), "Invalid dimension for slice"); + CHECK_ARGUMENT(start < end && end <= this->shape()[dim], "Invalid slice range"); + + TensorMeta new_meta = this->_meta; + new_meta.shape[dim] = end - start; + + size_t offset = this->_offset + start * this->strides()[dim] * this->elementSize(); + + return std::shared_ptr(new Tensor(new_meta, this->_storage, offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + size_t size = this->numel() * this->elementSize(); + core::context().setDevice(this->deviceType(), this->deviceId()); + + if (this->deviceType() == LLAISYS_DEVICE_CPU) { + std::memcpy(this->data(), src_, size); + } else { + const size_t ALIGNMENT = 64; + void *src_aligned = const_cast(src_); + void *temp_buf = nullptr; + + bool src_aligned_ok = (reinterpret_cast(src_) % ALIGNMENT) == 0; + bool dst_aligned_ok = (reinterpret_cast(this->data()) % ALIGNMENT) == 0; + bool size_aligned_ok = (size % ALIGNMENT) == 0; + + if (!src_aligned_ok || !dst_aligned_ok || !size_aligned_ok) { + size_t aligned_size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; +#ifdef _WIN32 + temp_buf = _aligned_malloc(aligned_size > 0 ? aligned_size : ALIGNMENT, ALIGNMENT); + if (temp_buf) { +#else + int ret = posix_memalign(&temp_buf, ALIGNMENT, aligned_size > 0 ? aligned_size : ALIGNMENT); + if (ret == 0 && temp_buf) { +#endif + std::memcpy(temp_buf, src_, size); + src_aligned = temp_buf; + } + } + + core::context().runtime().api()->memcpy_sync( + this->data(), + src_aligned, + size, + LLAISYS_MEMCPY_H2D); + + if (temp_buf) { +#ifdef _WIN32 + _aligned_free(temp_buf); +#else + free(temp_buf); +#endif + } + } } tensor_t Tensor::contiguous() const { diff --git a/test/ops/add.py b/test/ops/add.py index bb8bf8ca..993c4f61 100644 --- a/test/ops/add.py +++ b/test/ops/add.py @@ -5,7 +5,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import random_tensor, check_equal, benchmark +from test_utils import random_tensor, check_equal, benchmark, get_tolerance def torch_add(ans, a, b): @@ -42,19 +42,15 @@ def test_op_add( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-5, 1e-5), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-3, 1e-3), - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.add on {args.device}") for shape in testShapes: - for dtype_name, atol, rtol in testDtypePrec: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_add(shape, dtype_name, atol, rtol, args.device, args.profile) print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/argmax.py b/test/ops/argmax.py index d0f7ee29..78a3a9d5 100644 --- a/test/ops/argmax.py +++ b/test/ops/argmax.py @@ -6,7 +6,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import random_tensor, check_equal, benchmark, zero_tensor +from test_utils import random_tensor, check_equal, benchmark, zero_tensor, get_tolerance def torch_argmax(max_idx, max_val, vals): @@ -43,7 +43,7 @@ def test_op_argmax( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(4,), (4096,)] diff --git a/test/ops/embedding.py b/test/ops/embedding.py index 99cadc1b..4539daf6 100644 --- a/test/ops/embedding.py +++ b/test/ops/embedding.py @@ -4,7 +4,7 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, parent_dir) import llaisys -from test_utils import random_int_tensor, random_tensor, check_equal, benchmark +from test_utils import random_int_tensor, random_tensor, check_equal, benchmark, get_tolerance def torch_embedding(out, idx, embd): @@ -15,6 +15,8 @@ def test_op_embedding( idx_shape, embd_shape, dtype_name="f32", + atol=1e-5, + rtol=1e-5, device_name="cpu", profile=False, ): @@ -25,7 +27,7 @@ def test_op_embedding( torch_embedding(out, idx, embd) llaisys.Ops.embedding(out_, idx_, embd_) - check_equal(out_, out, strict=True) + assert check_equal(out_, out, atol=atol, rtol=rtol) if profile: benchmark( @@ -39,24 +41,20 @@ def test_op_embedding( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ ((1,), (2, 3)), ((50,), (512, 4096)), ] - testDtype = [ - # type - "f32", - "f16", - "bf16", - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.embedding on {args.device}") for idx_shape, embd_shape in testShapes: - for dtype_name in testDtype: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_embedding( - idx_shape, embd_shape, dtype_name, args.device, args.profile + idx_shape, embd_shape, dtype_name, atol, rtol, args.device, args.profile ) print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/linear.py b/test/ops/linear.py index 38897331..0490ccf3 100644 --- a/test/ops/linear.py +++ b/test/ops/linear.py @@ -5,7 +5,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import random_tensor, check_equal, benchmark +from test_utils import random_tensor, check_equal, benchmark, get_tolerance def torch_linear(out, x, w, bias): @@ -49,22 +49,18 @@ def test_op_linear( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ ((2, 3), (2, 4), (3, 4), True), ((512, 4096), (512, 4096), (4096, 4096), True), ] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-5, 1e-5), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-2, 1e-2), - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.linear on {args.device}") for shapes in testShapes: - for dtype_name, atol, rtol in testDtypePrec: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_linear(*shapes, dtype_name, atol, rtol, args.device, args.profile) print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/rms_norm.py b/test/ops/rms_norm.py index 67b789e3..74805418 100644 --- a/test/ops/rms_norm.py +++ b/test/ops/rms_norm.py @@ -5,7 +5,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import random_tensor, check_equal, benchmark +from test_utils import random_tensor, check_equal, benchmark, get_tolerance def torch_rms_norm(ans, x, w, eps): @@ -48,19 +48,15 @@ def test_op_rms_norm( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(1, 4), (512, 4096)] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-5, 1e-5), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-2, 1e-2), - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.rms_norm on {args.device}") for shape in testShapes: - for dtype_name, atol, rtol in testDtypePrec: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_rms_norm(shape, dtype_name, atol, rtol, args.device, args.profile) print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/rope.py b/test/ops/rope.py index fe59dd11..1339a2c6 100644 --- a/test/ops/rope.py +++ b/test/ops/rope.py @@ -5,7 +5,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import arrange_tensor, random_tensor, check_equal, benchmark +from test_utils import arrange_tensor, random_tensor, check_equal, benchmark, get_tolerance def torch_rope(y: torch.Tensor, x: torch.Tensor, pos_ids: torch.Tensor, theta: float): @@ -63,21 +63,17 @@ def test_op_rope( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ ((2, 1, 4), (0, 2)), ((512, 4, 4096), (512, 1024))] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-4, 1e-4), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-2, 1e-2), - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.rope on {args.device}") for shape, start_end in testShapes: - for dtype_name, atol, rtol in testDtypePrec: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_rope(shape, start_end, dtype_name, atol, rtol, args.device, args.profile) print("\033[92mTest passed!\033[0m\n") diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51b..8b478952 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -65,7 +65,7 @@ def test_op_self_attention( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/swiglu.py b/test/ops/swiglu.py index 1fa08f73..de894a82 100644 --- a/test/ops/swiglu.py +++ b/test/ops/swiglu.py @@ -5,7 +5,7 @@ sys.path.insert(0, parent_dir) import llaisys import torch -from test_utils import random_tensor, check_equal, benchmark +from test_utils import random_tensor, check_equal, benchmark, get_tolerance def torch_swiglu(out, gate, up): @@ -42,19 +42,15 @@ def test_op_swiglu( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-5, 1e-5), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-2, 1e-2), - ] + testDtypes = ["f32", "f16", "bf16"] print(f"Testing Ops.swiglu on {args.device}") for shape in testShapes: - for dtype_name, atol, rtol in testDtypePrec: + for dtype_name in testDtypes: + atol, rtol = get_tolerance(dtype_name, args.device) test_op_swiglu(shape, dtype_name, atol, rtol, args.device, args.profile) print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b87..ee740a00 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -23,10 +23,14 @@ def load_hf_model(model_path=None, device_name="cpu"): print(f"Loading model from Hugging Face: {model_id}") model_path = snapshot_download(model_id) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # For MetaX, use CPU to load PyTorch reference model + # because PyTorch doesn't support MetaX directly + torch_device_name = "cpu" if device_name == "metax" else device_name model = AutoModelForCausalLM.from_pretrained( model_path, - torch_dtype=torch.bfloat16, - device_map=torch_device(device_name), + dtype=torch.bfloat16, + device_map=torch_device(torch_device_name), trust_remote_code=True, ) @@ -81,8 +85,9 @@ def llaisys_infer( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) - parser.add_argument("--model", default=None, type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) + # parser.add_argument("--model", default=None, type=str) + parser.add_argument("--model", default="./models/deepseek-r1-distill-qwen-1.5b/", type=str) parser.add_argument("--prompt", default="Who are you?", type=str) parser.add_argument("--max_steps", default=128, type=int) parser.add_argument("--top_p", default=0.8, type=float) diff --git a/test/test_infer_tp.py b/test/test_infer_tp.py new file mode 100644 index 00000000..5b94487a --- /dev/null +++ b/test/test_infer_tp.py @@ -0,0 +1,179 @@ +import gc +from test_utils import * + +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from huggingface_hub import snapshot_download +import os +import time +import sys +import io + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'python')) +import llaisys + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def load_hf_model(model_path=None, device_name="cuda"): + # Check if local path exists + if model_path and os.path.isdir(model_path): + print(f"Loading model from local path: {model_path}") + use_path = model_path + else: + # Try alternative local paths + alt_paths = [ + "./models/deepseek-r1-distill-qwen-1.5b/", + "../models/deepseek-r1-distill-qwen-1.5b/", + "/home/hanson/llaisys/models/deepseek-r1-distill-qwen-1.5b/", + ] + use_path = None + for p in alt_paths: + if os.path.isdir(p): + print(f"Loading model from local path: {p}") + use_path = p + model_path = p + break + + if use_path is None: + # Fall back to HuggingFace + model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + print(f"Loading model from Hugging Face: {model_id}") + use_path = model_id + model_path = snapshot_download(model_id) + + tokenizer = AutoTokenizer.from_pretrained(use_path, trust_remote_code=True) + + model = AutoModelForCausalLM.from_pretrained( + use_path, + dtype=torch.bfloat16, + device_map=torch_device(device_name), + trust_remote_code=True, + ) + + return tokenizer, model, model_path + + +def hf_infer( + prompt, tokenizer, model, max_new_tokens=128, top_p=0.8, top_k=50, temperature=0.8 +): + input_content = tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + inputs = tokenizer.encode(input_content, return_tensors="pt").to(model.device) + with torch.no_grad(): + outputs = model.generate( + inputs, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + result = tokenizer.decode(outputs[0], skip_special_tokens=True) + return outputs[0].tolist(), result + + +def load_llaisys_tp_model(model_path, device_ids): + print(f"Loading Tensor Parallel model on GPUs: {device_ids}") + model = llaisys.models.Qwen2TP(model_path, device_ids=device_ids) + return model + + +def llaisys_tp_infer( + prompt, tokenizer, model, max_new_tokens=128, top_p=0.8, top_k=50, temperature=0.8 +): + input_content = tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + inputs = tokenizer.encode(input_content) + outputs = model.generate( + inputs, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + + return outputs, tokenizer.decode(outputs, skip_special_tokens=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device-ids", default="0,1", type=str, + help="Comma-separated GPU IDs for tensor parallelism") + parser.add_argument("--model", default="./models/deepseek-r1-distill-qwen-1.5b/", type=str) + parser.add_argument("--prompt", default="Who are you?", type=str) + parser.add_argument("--max_steps", default=128, type=int) + parser.add_argument("--top_p", default=0.8, type=float) + parser.add_argument("--top_k", default=50, type=int) + parser.add_argument("--temperature", default=1.0, type=float) + parser.add_argument("--test", action="store_true") + + args = parser.parse_args() + + # Parse device IDs + device_ids = [int(x.strip()) for x in args.device_ids.split(",")] + if len(device_ids) < 2: + print("Error: Tensor parallelism requires at least 2 GPUs") + sys.exit(1) + + top_p, top_k, temperature = args.top_p, args.top_k, args.temperature + if args.test: + top_p, top_k, temperature = 1.0, 1, 1.0 + + tokenizer, model, model_path = load_hf_model(args.model, "nvidia") + + # Example prompt + start_time = time.time() + tokens, output = hf_infer( + args.prompt, + tokenizer, + model, + max_new_tokens=args.max_steps, + top_p=top_p, + top_k=top_k, + temperature=temperature, + ) + end_time = time.time() + + del model + gc.collect() + torch.cuda.empty_cache() + + print("\n=== Answer ===\n") + print("Tokens:") + print(tokens) + print("\nContents:") + print(output) + print("\n") + print(f"Time elapsed: {(end_time - start_time):.2f}s\n") + + model = load_llaisys_tp_model(model_path, device_ids) + start_time = time.time() + llaisys_tokens, llaisys_output = llaisys_tp_infer( + args.prompt, + tokenizer, + model, + max_new_tokens=args.max_steps, + top_p=top_p, + top_k=top_k, + temperature=temperature, + ) + end_time = time.time() + + print("\n=== Your Result ===\n") + print("Tokens:") + print(llaisys_tokens) + print("\nContents:") + print(llaisys_output) + print("\n") + print(f"Time elapsed: {(end_time - start_time):.2f}s\n") + + if args.test: + assert llaisys_tokens == tokens + print("\033[92mTest passed!\033[0m\n") diff --git a/test/test_runtime.py b/test/test_runtime.py index e2ac218a..1961a1aa 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -55,7 +55,7 @@ def test_memcpy(api, size_bytes: int): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) args = parser.parse_args() test_basic_runtime_api(args.device) diff --git a/test/test_utils.py b/test/test_utils.py index 0f38f0c8..ee87e90e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -24,11 +24,13 @@ def random_tensor( api = llaisys.RuntimeAPI(llaisys_device(device_name)) bytes_ = torch_tensor.numel() * torch_tensor.element_size() + # Use H2D for metax since torch tensor is on CPU + memcpy_kind = llaisys.MemcpyKind.H2D if device_name == "metax" else llaisys.MemcpyKind.D2D api.memcpy_sync( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + memcpy_kind, ) return torch_tensor, llaisys_tensor @@ -52,11 +54,13 @@ def random_int_tensor(shape, device_name, dtype_name="i64", device_id=0, low=0, api = llaisys.RuntimeAPI(llaisys_device(device_name)) bytes_ = torch_tensor.numel() * torch_tensor.element_size() + # Use H2D for metax since torch tensor is on CPU + memcpy_kind = llaisys.MemcpyKind.H2D if device_name == "metax" else llaisys.MemcpyKind.D2D api.memcpy_sync( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + memcpy_kind, ) return torch_tensor, llaisys_tensor @@ -80,11 +84,13 @@ def zero_tensor( api = llaisys.RuntimeAPI(llaisys_device(device_name)) bytes_ = torch_tensor.numel() * torch_tensor.element_size() + # Use H2D for metax since torch tensor is on CPU + memcpy_kind = llaisys.MemcpyKind.H2D if device_name == "metax" else llaisys.MemcpyKind.D2D api.memcpy_sync( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + memcpy_kind, ) return torch_tensor, llaisys_tensor @@ -103,11 +109,13 @@ def arrange_tensor( api = llaisys.RuntimeAPI(llaisys_device(device_name)) bytes_ = torch_tensor.numel() * torch_tensor.element_size() + # Use H2D for metax since torch tensor is on CPU + memcpy_kind = llaisys.MemcpyKind.H2D if device_name == "metax" else llaisys.MemcpyKind.D2D api.memcpy_sync( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + memcpy_kind, ) return torch_tensor, llaisys_tensor @@ -132,20 +140,21 @@ def check_equal( else: # TODO: Support negative strides in the future raise ValueError("Negative strides are not supported yet") + device_name_str = device_name(llaisys_result.device_type()) tmp = torch.zeros( (right + 1,), dtype=torch_answer.dtype, - device=torch_device( - device_name(llaisys_result.device_type()), llaisys_result.device_id() - ), + device=torch_device(device_name_str, llaisys_result.device_id()), ) result = torch.as_strided(tmp, shape, strides) api = llaisys.RuntimeAPI(llaisys_result.device_type()) + # Use D2H for metax since result tensor is on CPU + memcpy_kind = llaisys.MemcpyKind.D2H if device_name_str == "metax" else llaisys.MemcpyKind.D2D api.memcpy_sync( result.data_ptr(), llaisys_result.data_ptr(), (right + 1) * tmp.element_size(), - llaisys.MemcpyKind.D2D, + memcpy_kind, ) if strict: @@ -188,6 +197,10 @@ def torch_device(device_name: str, device_id=0): return torch.device("cpu") elif device_name == "nvidia": return torch.device(f"cuda:{device_id}") + elif device_name == "metax": + # MetaX uses CPU for PyTorch reference (no NVIDIA driver) + # Data will be transferred via LLAISYS runtime API + return torch.device("cpu") else: raise ValueError(f"Unsupported device name: {device_name}") @@ -197,6 +210,8 @@ def llaisys_device(device_name: str): return llaisys.DeviceType.CPU elif device_name == "nvidia": return llaisys.DeviceType.NVIDIA + elif device_name == "metax": + return llaisys.DeviceType.METAX else: raise ValueError(f"Unsupported device name: {device_name}") @@ -206,6 +221,8 @@ def device_name(llaisys_device: llaisys.DeviceType): return "cpu" elif llaisys_device == llaisys.DeviceType.NVIDIA: return "nvidia" + elif llaisys_device == llaisys.DeviceType.METAX: + return "metax" else: raise ValueError(f"Unsupported llaisys device: {llaisys_device}") @@ -277,3 +294,28 @@ def dtype_name(llaisys_dtype: llaisys.DataType): return "bool" else: raise ValueError(f"Unsupported llaisys dtype: {llaisys_dtype}") + + +def get_tolerance(dtype_name: str, device_name: str): + """ + Get the tolerance (atol, rtol) for numerical comparison based on dtype and device. + + Args: + dtype_name: Data type name (f32, f16, bf16) + device_name: Device name (cpu, nvidia, metax) + + Returns: + tuple: (atol, rtol) + """ + base = { + "f32": (1e-4, 1e-4), + "f16": (1e-3, 1e-3), + "bf16": (1e-2, 1e-2), + } + + # MetaX GPU has larger numerical differences compared to CPU reference + if device_name == "metax": + base["f32"] = (1e-3, 1e-3) + base["bf16"] = (1e-2, 1e-2) + + return base.get(dtype_name, (1e-4, 1e-4)) diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..2429ee3c 100644 --- a/xmake.lua +++ b/xmake.lua @@ -15,9 +15,40 @@ option_end() if has_config("nv-gpu") then add_defines("ENABLE_NVIDIA_API") + + -- Check NCCL availability (also needed in this file) + nccl_available = false + if is_plat("linux") then + local nccl_paths = { + "/usr/lib/x86_64-linux-gnu", + "/usr/local/cuda/lib64", + os.getenv("NCCL_ROOT") and (os.getenv("NCCL_ROOT") .. "/lib") or nil, + os.getenv("NCCL_HOME") and (os.getenv("NCCL_HOME") .. "/lib") or nil, + } + for _, path in ipairs(nccl_paths) do + if path and os.isfile(path .. "/libnccl.so") then + nccl_available = true + add_defines("ENABLE_NCCL") + break + end + end + end + includes("xmake/nvidia.lua") end +-- MetaX -- +option("metax-gpu") + set_default(false) + set_showmenu(true) + set_description("Whether to compile implementations for MetaX GPU") +option_end() + +if has_config("metax-gpu") then + add_defines("ENABLE_METAX_API") + includes("xmake/metax.lua") +end + target("llaisys-utils") set_kind("static") @@ -37,6 +68,9 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-device-nvidia") + end set_languages("cxx17") set_warnings("all", "error") @@ -83,18 +117,42 @@ target_end() target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-ops-nvidia") + end set_languages("cxx17") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC", "-Wno-unknown-pragmas") end - + add_files("src/ops/*/*.cpp") on_install(function (target) end) target_end() +target("llaisys-models") + set_kind("static") + add_deps("llaisys-tensor") + add_deps("llaisys-ops") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/models/*/*.cpp") + + -- Add NCCL define if available + if has_config("nv-gpu") and nccl_available then + add_defines("ENABLE_NCCL") + end + + on_install(function (target) end) +target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,12 +160,51 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") + add_deps("llaisys-models") set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") set_installdir(".") + if has_config("nv-gpu") then + -- Add CUDA files directly to main target for proper device linking + -- Conditionally exclude NCCL files if NCCL is not available + if nccl_available then + add_files("src/device/nvidia/*.cu") + else + add_files("src/device/nvidia/nvidia_resource.cu") + add_files("src/device/nvidia/nvidia_runtime_api.cu") + end + add_files("src/ops/*/nvidia/*.cu") + add_linkdirs("/usr/local/cuda/lib64") + add_syslinks("cudart", "cublas") + add_shflags("-Wl,--no-as-needed", "-lcudart", "-lcublas", {force = true}) + set_toolchains("cuda") + add_cugencodes("native") + add_cuflags("-rdc=true", {force = true}) + add_includedirs("/usr/include") -- For NCCL headers + + -- Try to find NCCL in common locations + if nccl_available and os.isdir("/usr/lib/x86_64-linux-gnu") then + add_linkdirs("/usr/lib/x86_64-linux-gnu") + add_shflags("-Wl,--no-as-needed", "-lnccl", {force = true}) + end + end + + if has_config("metax-gpu") then + -- Directly add CUDA object files instead of static libraries to avoid RDC issues + add_files("src/device/metax/*.cu") + add_files("src/ops/*/metax/*.cu") + add_linkdirs("/opt/maca/lib", "/opt/maca/tools/cu-bridge/lib") + add_syslinks("mcblas", "mcruntime", "cuda") + add_shflags("-Wl,--no-as-needed", "-lmcblas", "-lmcruntime", "-lcuda", {force = true}) + set_toolchains("cuda") + add_cugencodes("native") + -- Disable device link for MetaX + set_policy("build.cuda.devlink", false) + end + after_install(function (target) -- copy shared library to python package @@ -119,4 +216,4 @@ target("llaisys") os.cp("lib/*.so", "python/llaisys/libllaisys/") end end) -target_end() \ No newline at end of file +target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua new file mode 100644 index 00000000..8cf51815 --- /dev/null +++ b/xmake/metax.lua @@ -0,0 +1,44 @@ +-- MetaX GPU support configuration +-- This file defines separate static libraries for MetaX device and ops +-- which are then linked into the main llaisys shared library + +target("llaisys-device-metax") + set_kind("static") + add_deps("llaisys-utils") + + set_toolchains("cuda") + set_languages("cxx17") + add_cugencodes("native") + add_cuflags("-rdc=true", {force = true}) + + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + end + + add_includedirs("../include") + add_files("../src/device/metax/*.cu") + + on_install(function (target) end) +target_end() + +target("llaisys-ops-metax") + set_kind("static") + add_deps("llaisys-tensor") + add_deps("llaisys-device-metax") + + set_toolchains("cuda") + set_languages("cxx17") + add_cugencodes("native") + add_cuflags("-rdc=true", {force = true}) + + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + end + + add_includedirs("../include") + add_files("../src/ops/*/metax/*.cu") + + on_install(function (target) end) +target_end() diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 00000000..9c2e2a9a --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,124 @@ +-- NVIDIA GPU support configuration + +-- Check if NCCL is available (Linux only) +local nccl_available = false +local nccl_lib = nil + +if is_plat("linux") then + -- Try to find NCCL in common locations + local nccl_paths = { + "/usr/lib/x86_64-linux-gnu", + "/usr/local/cuda/lib64", + os.getenv("NCCL_ROOT") and (os.getenv("NCCL_ROOT") .. "/lib") or nil, + os.getenv("NCCL_HOME") and (os.getenv("NCCL_HOME") .. "/lib") or nil, + } + + for _, path in ipairs(nccl_paths) do + if path and os.isfile(path .. "/libnccl.so") then + nccl_available = true + nccl_lib = path + break + end + end +end + +-- NVIDIA GPU implementation +target("llaisys-device-nvidia") + set_kind("static") + -- Device files are added in xmake.lua to avoid duplication + -- This target is currently empty but kept for future use + add_includedirs("$(projectdir)/src") + + -- CUDA settings + add_cugencodes("sm_80") + add_cuflags("-rdc=true", {force = true}) + + -- Platform specific + if is_plat("linux") then + add_links("cudart", "cublas") + if nccl_available then + add_links("nccl") + add_defines("ENABLE_NCCL") + if nccl_lib then + add_linkdirs(nccl_lib) + end + end + else + add_links("cudart", "cublas") + end + + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + + on_install(function (target) end) +target_end() + +-- NVIDIA operator implementations +-- Note: Operator CUDA files are added in xmake.lua to avoid duplication with llaisys target +target("llaisys-ops-nvidia") + set_kind("static") + -- Files are compiled in xmake.lua's llaisys target for proper device linking + add_includedirs("$(projectdir)/src") + + add_cugencodes("sm_80") + add_cuflags("-rdc=true", {force = true}) + + if is_plat("linux") then + add_links("cudart", "cublas") + else + add_links("cudart", "cublas") + end + + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + + on_install(function (target) end) +target_end() + +-- Main library with NVIDIA support +target("llaisys-nvidia") + set_kind("shared") + add_deps("llaisys-models", "llaisys-ops", "llaisys-ops-cpu", + "llaisys-tensor", "llaisys-core", "llaisys-device", + "llaisys-device-cpu", "llaisys-utils") + + if is_plat("linux") then + add_deps("llaisys-device-nvidia") + end + + add_linkdirs("/usr/local/cuda/lib64") + add_links("cudadevrt", "rt", "pthread", "dl") + + if is_plat("linux") and nccl_available then + add_links("nccl") + end + + add_cuflags("-rdc=true", {force = true}) + + if is_plat("linux") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler=-fPIC", "-Wno-unknown-pragmas") + end + + on_install(function (target) end) +target_end() + +-- NCCL availability message +if is_plat("linux") and nccl_available then + print("NCCL found: " .. (nccl_lib or "system")) +elseif is_plat("linux") then + print("Warning: NCCL not found. TP will be disabled.") +else + print("NCCL not available on Windows. TP is disabled.") +end + +-- Export NCCL availability for other scripts +if nccl_available then + set_configvar("NCCL_AVAILABLE", 1) +end + +-- Return NCCL status for use in other xmake files +return { + nccl_available = nccl_available, + nccl_lib = nccl_lib +}