Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions include/llaisys/models/qwen2.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LLAISYS_MODELS_QWEN2_H

#include "../tensor.h"
#include <vector>

__C {
struct LlaisysQwen2Meta {
Expand All @@ -14,19 +15,20 @@ __C {
// Per-model weight handles. NOTE(review): this span is unified-diff residue --
// it shows BOTH the removed pointer members (below, with *) and the added
// std::vector members; in the real post-change file only the vector form
// and the plain `out_norm_w` remain.
struct LlaisysQwen2Weights {
llaisysTensor_t in_embed;
llaisysTensor_t out_embed;
llaisysTensor_t out_norm_w; // a.k.a. model.norm.weight
llaisysTensor_t *attn_norm_w; // a.k.a. input_layernorm.weight
llaisysTensor_t *attn_q_w;
llaisysTensor_t *attn_q_b;
llaisysTensor_t *attn_k_w;
llaisysTensor_t *attn_k_b;
llaisysTensor_t *attn_v_w;
llaisysTensor_t *attn_v_b;
llaisysTensor_t *attn_o_w;
llaisysTensor_t *mlp_norm_w; // a.k.a. post_attention_layernorm.weight
llaisysTensor_t *mlp_gate_w;
llaisysTensor_t *mlp_up_w;
llaisysTensor_t *mlp_down_w;
llaisysTensor_t out_norm_w;
// changed to vector
// NOTE(review): std::vector members inside the `__C` (extern "C") block make
// this header unusable from a C compiler and give the struct a C++-only
// layout, so it can no longer be a C ABI type. Consider keeping the
// `llaisysTensor_t *` per-layer arrays (length = nlayer) instead -- confirm
// how `__C` is defined before merging.
std::vector<llaisysTensor_t> attn_norm_w;
std::vector<llaisysTensor_t> attn_q_w;
std::vector<llaisysTensor_t> attn_q_b;
std::vector<llaisysTensor_t> attn_k_w;
std::vector<llaisysTensor_t> attn_k_b;
std::vector<llaisysTensor_t> attn_v_w;
std::vector<llaisysTensor_t> attn_v_b;
std::vector<llaisysTensor_t> attn_o_w;
std::vector<llaisysTensor_t> mlp_norm_w;
std::vector<llaisysTensor_t> mlp_gate_w;
std::vector<llaisysTensor_t> mlp_up_w;
std::vector<llaisysTensor_t> mlp_down_w;
};

struct LlaisysQwen2Model;
Expand All @@ -37,6 +39,18 @@ __C {

__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);

// NOTE(review): the next declaration is the REMOVED pre-change Infer signature
// (diff residue); the replacement with `start_pos` follows below.
__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
// Upload one named weight tensor (host `data`, `ndim`-dimensional `shape`,
// element type `dtype`) into the model, keyed by its checkpoint name.
__export void llaisysQwen2LoadWeight(
struct LlaisysQwen2Model * model,
const char * name,
const void * data,
size_t * shape,
size_t ndim,
llaisysDataType_t dtype);

// Run a forward pass over `ntoken` ids starting at KV-cache position
// `start_pos`; returns the next token id (argmax, per the Python caller).
__export int64_t llaisysQwen2ModelInfer(
struct LlaisysQwen2Model * model,
int64_t * token_ids,
size_t ntoken,
size_t start_pos);
}
#endif // LLAISYS_MODELS_QWEN2_H
142 changes: 131 additions & 11 deletions python/llaisys/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,117 @@
from typing import Sequence
from ..libllaisys import LIB_LLAISYS
from ..libllaisys import DeviceType
from ..libllaisys import DeviceType, DataType

from .qwen2_binding import register_qwen2_lib, Qwen2MetaCStruct

from pathlib import Path
import safetensors
import safetensors.torch
import torch

import os
import json
import ctypes

class Qwen2:
    """Python wrapper around the native llaisys Qwen2 model."""

    def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
        """Build the native model from a HuggingFace-style checkpoint directory.

        Args:
            model_path: directory containing ``config.json`` and one or more
                ``*.safetensors`` shards.
            device: llaisys device to create the model on (default CPU).

        Raises:
            FileNotFoundError: if ``config.json`` is missing.
            RuntimeError: if the native model handle could not be created.
        """
        self.lib = LIB_LLAISYS
        register_qwen2_lib(self.lib)  # attach ctypes signatures to the C API

        model_path = Path(model_path)
        config_path = model_path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"Config not found at {config_path}")

        with open(config_path, "r") as f:
            config = json.load(f)

        # Populate the C meta struct. Fallback values mirror a typical
        # Qwen2-1.5B config; the real values come from config.json.
        self.meta = Qwen2MetaCStruct()
        self.meta.hs = config.get("hidden_size", 1536)
        self.meta.nlayer = config.get("num_hidden_layers", 28)
        self.meta.nh = config.get("num_attention_heads", 12)
        self.meta.nkvh = config.get("num_key_value_heads", 2)
        self.meta.voc = config.get("vocab_size", 151936)
        self.meta.maxseq = config.get("max_position_embeddings", 32768)
        self.meta.di = config.get("intermediate_size", 8960)
        self.meta.epsilon = config.get("rms_norm_eps", 1e-6)
        self.meta.theta = config.get("rope_theta", 10000.0)
        self.meta.dh = self.meta.hs // self.meta.nh  # per-head dimension

        # The config may give eos_token_id as a single id or a list; the C
        # side expects exactly one id (default: <|endoftext|>).
        eos_id = config.get("eos_token_id", 151643)
        self.meta.end_token = eos_id[0] if isinstance(eos_id, list) else eos_id

        # Match the model dtype to the checkpoint dtype when possible.
        torch_dtype = str(config.get("torch_dtype", "float32")).lower()
        if "bfloat16" in torch_dtype or "bf16" in torch_dtype:
            self.meta.dtype = DataType.BF16
        elif "float16" in torch_dtype or "fp16" in torch_dtype:
            self.meta.dtype = DataType.F16
        else:
            self.meta.dtype = DataType.F32

        # Create the native model on a single device (id 0).
        device_ids = (ctypes.c_int * 1)(0)
        self.model = self.lib.llaisysQwen2ModelCreate(
            ctypes.byref(self.meta),
            device,
            device_ids,
            1,
        )
        if not self.model:
            raise RuntimeError("Failed to create native Qwen2 model instance.")

        # Stream every tensor of every shard into the native model.
        # NOTE(review): use the documented top-level safetensors.safe_open
        # (not safetensors.torch.safe_open); framework="pt" yields torch
        # tensors, which preserves BF16.
        for file in sorted(model_path.glob("*.safetensors")):
            print(f"Loading weights from {file}...")
            with safetensors.safe_open(file, framework="pt", device="cpu") as f:
                for name in f.keys():
                    self._load_one_weight(name, f.get_tensor(name))

    def _load_one_weight(self, name, tensor):
        """Hand one named torch tensor to llaisysQwen2LoadWeight."""
        # Map torch dtype -> llaisys DataType (anything unlisted falls back
        # to F32, matching the original behavior).
        dtype_map = {
            torch.float16: DataType.F16,
            torch.float32: DataType.F32,
            torch.bfloat16: DataType.BF16,
            torch.int64: DataType.I64,
        }
        dt = dtype_map.get(tensor.dtype, DataType.F32)

        # The C side receives a raw pointer, so the data must be dense.
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        shape = tensor.shape
        c_shape = (ctypes.c_size_t * len(shape))(*shape)

        # `tensor` stays referenced for the duration of the call, so the
        # data pointer remains valid while C copies from it.
        self.lib.llaisysQwen2LoadWeight(
            self.model,
            name.encode("utf-8"),
            ctypes.c_void_p(tensor.data_ptr()),
            c_shape,
            len(shape),
            dt,
        )

def __del__(self):
    """Release the native model handle when the wrapper is collected."""
    handle = getattr(self, "model", None)
    if handle:
        # Destroy exactly once; clear the attribute so a second finalization
        # (or an explicit call) is a no-op.
        self.lib.llaisysQwen2ModelDestroy(handle)
        self.model = None

# Greedy token generation. NOTE(review): part of the parameter list is hidden
# behind a collapsed diff hunk below ("Expand All @@ -27,7 +121,33") -- the
# `inputs`, `max_new_tokens`, and (presumably) `top_k` parameters are not
# visible here; confirm the full signature in the actual file.
def generate(
self,
Expand All @@ -27,7 +121,33 @@ def generate(
top_p: float = 0.8,
temperature: float = 0.8,
):

# TODO: Implement generate function

return []
# (the stub above is the REMOVED pre-change body; diff residue)
# Default generation budget when the caller gives none.
if max_new_tokens is None:
max_new_tokens = 20

# Running token sequence: prompt ids first, generated ids appended.
tokens = list(inputs)
start_pos = 0

# NOTE(review): top_k / top_p / temperature are accepted but unused --
# the backend performs argmax decoding (see comment below); confirm
# whether sampling is intended to be added later.
for _ in range(max_new_tokens):
if start_pos == 0:
current_input = tokens
else:
current_input = tokens[-1:] # Next token generation: use only last token

n_tokens = len(current_input)
c_inputs = (ctypes.c_int64 * n_tokens)(*current_input)

# Infer (argmax inside backend)
next_token_id = self.lib.llaisysQwen2ModelInfer(
self.model,
c_inputs,
n_tokens,
start_pos
)

tokens.append(next_token_id)
# Advance the KV-cache position by however many tokens were fed.
start_pos += n_tokens

# Stop as soon as the model emits the end-of-sequence token.
if next_token_id == self.meta.end_token:
break

# Returns prompt + generated ids (including the final EOS, if reached).
return tokens
57 changes: 57 additions & 0 deletions python/llaisys/models/qwen2_binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import ctypes
from ctypes import c_size_t, c_int, c_float, c_void_p, c_int64, POINTER, Structure, c_char_p
from ..libllaisys.llaisys_types import DataType, llaisysDataType_t, llaisysDeviceType_t

class Qwen2MetaCStruct(Structure):
    """ctypes mirror of ``struct LlaisysQwen2Meta``.

    Field order and types are the ABI contract with the native library --
    do not reorder. Values are filled from config.json by Qwen2.__init__.
    """

    _fields_ = [
        ("dtype", llaisysDataType_t),  # weight/compute data type (F32/F16/BF16)
        ("nlayer", c_size_t),          # num_hidden_layers
        ("hs", c_size_t),              # hidden_size
        ("nh", c_size_t),              # num_attention_heads
        ("nkvh", c_size_t),            # num_key_value_heads
        ("dh", c_size_t),              # per-head dim (hs // nh)
        ("di", c_size_t),              # intermediate_size (MLP)
        ("maxseq", c_size_t),          # max_position_embeddings
        ("voc", c_size_t),             # vocab_size
        ("epsilon", c_float),          # rms_norm_eps
        ("theta", c_float),            # rope_theta
        ("end_token", c_int64),        # EOS token id
    ]

# Opaque pointer handle to the native LlaisysQwen2Model.
LlaisysQwen2ModelHandle = c_void_p

def register_qwen2_lib(lib):
    """Attach ctypes restype/argtypes signatures for the Qwen2 C API.

    Each exported function is configured independently and only if the
    shared library actually provides it, so a build that is missing one
    symbol no longer silently skips the registration of all the others
    (the original guarded everything behind a single
    ``hasattr(lib, "llaisysQwen2ModelCreate")`` check).
    """

    def _bind(name, restype, argtypes):
        # Configure one exported function if the library provides it.
        fn = getattr(lib, name, None)
        if fn is not None:
            fn.restype = restype
            fn.argtypes = argtypes

    # Create: (meta*, device_type, device_ids*, ndev) -> model handle
    _bind(
        "llaisysQwen2ModelCreate",
        LlaisysQwen2ModelHandle,
        [POINTER(Qwen2MetaCStruct), llaisysDeviceType_t, POINTER(c_int), c_int],
    )

    # Destroy: (model) -> None
    _bind("llaisysQwen2ModelDestroy", None, [LlaisysQwen2ModelHandle])

    # Weights accessor (declared in qwen2.h but previously unregistered:
    # the default c_int restype would truncate the returned 64-bit pointer).
    # Treated as an opaque pointer here.
    _bind("llaisysQwen2ModelWeights", c_void_p, [LlaisysQwen2ModelHandle])

    # LoadWeight: (model, name, data, shape*, ndim, dtype) -> None
    _bind(
        "llaisysQwen2LoadWeight",
        None,
        [
            LlaisysQwen2ModelHandle,
            c_char_p,           # tensor name (checkpoint key)
            c_void_p,           # host data pointer
            POINTER(c_size_t),  # shape array
            c_size_t,           # ndim
            llaisysDataType_t,  # element dtype
        ],
    )

    # Infer: (model, token_ids*, ntoken, start_pos) -> next token id
    _bind(
        "llaisysQwen2ModelInfer",
        c_int64,
        [
            LlaisysQwen2ModelHandle,
            POINTER(c_int64),  # input token ids
            c_size_t,          # number of tokens fed this step
            c_size_t,          # KV-cache start position
        ],
    )
Loading