Skip to content
Open

1 #44

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions python/llaisys/libllaisys/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from ctypes import c_void_p, c_int64, c_size_t, c_float, c_int, POINTER, Structure
from . import LIB_LLAISYS

class Qwen2Meta(Structure):
    # Mirrors the C struct LlaisysQwen2Meta: field order and ctypes types
    # must match the C declaration exactly (this is the FFI ABI).
    _fields_ = [
        ("dtype", c_int),        # data-type enum for weights/activations
        ("nlayer", c_size_t),    # number of transformer layers
        ("hs", c_size_t),        # hidden size
        ("nh", c_size_t),        # number of attention heads
        ("nkvh", c_size_t),      # number of key/value heads (GQA)
        ("dh", c_size_t),        # per-head dimension (hs // nh)
        ("di", c_size_t),        # MLP intermediate size
        ("maxseq", c_size_t),    # max sequence length (KV-cache capacity)
        ("voc", c_size_t),       # vocabulary size
        ("epsilon", c_float),    # RMSNorm epsilon
        ("theta", c_float),      # RoPE base frequency
        ("end_token", c_int64),  # EOS token id used to stop generation
    ]

class Qwen2Weights(Structure):
    # Mirrors the C struct LlaisysQwen2Weights. Each c_void_p is an opaque
    # native tensor handle; each POINTER(c_void_p) field is a per-layer
    # array of handles indexed by layer (length nlayer).
    _fields_ = [
        ("in_embed", c_void_p),               # token embedding table
        ("out_embed", c_void_p),              # lm_head projection weight
        ("out_norm_w", c_void_p),             # final RMSNorm weight
        ("attn_norm_w", POINTER(c_void_p)),   # pre-attention RMSNorm, per layer
        ("attn_q_w", POINTER(c_void_p)),      # Q projection weight
        ("attn_q_b", POINTER(c_void_p)),      # Q projection bias
        ("attn_k_w", POINTER(c_void_p)),      # K projection weight
        ("attn_k_b", POINTER(c_void_p)),      # K projection bias
        ("attn_v_w", POINTER(c_void_p)),      # V projection weight
        ("attn_v_b", POINTER(c_void_p)),      # V projection bias
        ("attn_o_w", POINTER(c_void_p)),      # attention output projection
        ("mlp_norm_w", POINTER(c_void_p)),    # pre-MLP RMSNorm weight
        ("mlp_gate_w", POINTER(c_void_p)),    # SwiGLU gate projection
        ("mlp_up_w", POINTER(c_void_p)),      # SwiGLU up projection
        ("mlp_down_w", POINTER(c_void_p)),    # MLP down projection
    ]

# --- C-ABI signatures for the native Qwen2 model API -------------------------

# Create(meta, device_type, device_ids, ndevice) -> opaque model handle.
# ctypes auto-converts a Qwen2Meta instance to POINTER(Qwen2Meta) (byref).
LIB_LLAISYS.llaisysQwen2ModelCreate.argtypes = [POINTER(Qwen2Meta), c_int, POINTER(c_int), c_int]
LIB_LLAISYS.llaisysQwen2ModelCreate.restype = c_void_p

# Destroys the native model handle created above.
LIB_LLAISYS.llaisysQwen2ModelDestroy.argtypes = [c_void_p]
LIB_LLAISYS.llaisysQwen2ModelDestroy.restype = None

# Returns a pointer to a Qwen2Weights struct whose tensor handles are
# subsequently filled via tensorLoad during checkpoint loading.
LIB_LLAISYS.llaisysQwen2ModelWeights.argtypes = [c_void_p]
LIB_LLAISYS.llaisysQwen2ModelWeights.restype = POINTER(Qwen2Weights)

# Infer(model, token_ids, ntoken) -> next token id (greedy argmax).
LIB_LLAISYS.llaisysQwen2ModelInfer.argtypes = [c_void_p, POINTER(c_int64), c_size_t]
LIB_LLAISYS.llaisysQwen2ModelInfer.restype = c_int64
91 changes: 81 additions & 10 deletions python/llaisys/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,85 @@
from typing import Sequence
from ..libllaisys import LIB_LLAISYS
from ..libllaisys import DeviceType

from ..libllaisys import DeviceType, DataType
from ..libllaisys.models import Qwen2Meta, Qwen2Weights
from pathlib import Path
import safetensors

import json
from ctypes import c_int64, c_int

class Qwen2:

def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
    """Load a Qwen2 model from a HuggingFace-style checkpoint directory.

    Reads config.json into a Qwen2Meta, creates the native model, then
    streams every *.safetensors file into the native weight tensors.

    Args:
        model_path: directory containing config.json and *.safetensors files.
        device: target device for the native model (default: CPU).
    """
    model_path = Path(model_path)

    with open(model_path / "config.json") as f:
        config = json.load(f)

    meta = Qwen2Meta()
    # Native tensors are created as F32; weights are upcast on load below.
    meta.dtype = DataType.F32
    meta.nlayer = config["num_hidden_layers"]
    meta.hs = config["hidden_size"]
    meta.nh = config["num_attention_heads"]
    meta.nkvh = config["num_key_value_heads"]
    meta.dh = config["hidden_size"] // config["num_attention_heads"]
    meta.di = config["intermediate_size"]
    meta.maxseq = config.get("max_position_embeddings", 32768)
    meta.voc = config["vocab_size"]
    meta.epsilon = config["rms_norm_eps"]
    meta.theta = config.get("rope_theta", 10000.0)
    meta.end_token = config.get("eos_token_id", 151643)

    self.model = LIB_LLAISYS.llaisysQwen2ModelCreate(meta, device, None, 0)
    self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.model)
    self.weights = self.weights_ptr.contents
    self.nlayer = meta.nlayer
    self.end_token = meta.end_token

    # Sorted glob gives a deterministic load order across shard files.
    for file in sorted(model_path.glob("*.safetensors")):
        data_ = safetensors.safe_open(file, framework="pt", device="cpu")
        for name_ in data_.keys():
            tensor_data = data_.get_tensor(name_)
            # Upcast any floating dtype (e.g. bf16) to match the F32
            # native tensors; .float() is a no-op for float32 input, so
            # no extra dtype comparison (which previously materialized a
            # full converted tensor just to probe its dtype) is needed.
            if tensor_data.dtype.is_floating_point:
                tensor_data = tensor_data.float()
            self._load_weight(name_, tensor_data.numpy())

def _load_weight(self, name, data):
    """Copy one checkpoint tensor into the matching native weight tensor.

    Tensor names follow the HuggingFace Qwen2 layout (e.g.
    "model.layers.3.self_attn.q_proj.weight"). Names that match nothing
    are silently ignored.
    """
    ptr = data.ctypes.data

    # Weights that exist once per model.
    singletons = {
        "model.embed_tokens.weight": self.weights.in_embed,
        "lm_head.weight": self.weights.out_embed,
        "model.norm.weight": self.weights.out_norm_w,
    }
    handle = singletons.get(name)
    if handle is not None:
        LIB_LLAISYS.tensorLoad(handle, ptr)
        return

    if "layers" not in name:
        return

    # "model.layers.<idx>.<suffix>" -> per-layer handle arrays.
    layer_idx = int(name.split(".")[2])
    per_layer = (
        ("input_layernorm.weight", self.weights.attn_norm_w),
        ("self_attn.q_proj.weight", self.weights.attn_q_w),
        ("self_attn.q_proj.bias", self.weights.attn_q_b),
        ("self_attn.k_proj.weight", self.weights.attn_k_w),
        ("self_attn.k_proj.bias", self.weights.attn_k_b),
        ("self_attn.v_proj.weight", self.weights.attn_v_w),
        ("self_attn.v_proj.bias", self.weights.attn_v_b),
        ("self_attn.o_proj.weight", self.weights.attn_o_w),
        ("post_attention_layernorm.weight", self.weights.mlp_norm_w),
        ("mlp.gate_proj.weight", self.weights.mlp_gate_w),
        ("mlp.up_proj.weight", self.weights.mlp_up_w),
        ("mlp.down_proj.weight", self.weights.mlp_down_w),
    )
    for suffix, handles in per_layer:
        if suffix in name:
            LIB_LLAISYS.tensorLoad(handles[layer_idx], ptr)
            return

def generate(
self,
Expand All @@ -27,7 +89,16 @@ def generate(
top_p: float = 0.8,
temperature: float = 0.8,
):

# TODO: Implement generate function

return []
tokens = list(inputs)
for _ in range(max_new_tokens or 128):
token_array = (c_int64 * len(tokens))(*tokens)
next_token = LIB_LLAISYS.llaisysQwen2ModelInfer(self.model, token_array, len(tokens))
tokens.append(next_token)
if next_token == self.end_token:
break
return tokens

def __del__(self):
    """Release the native model exactly once.

    Guards against a failed constructor (self.model unset, or NULL if
    llaisysQwen2ModelCreate returned nullptr — the old hasattr check would
    then hand a null handle to Destroy, which dereferences it) and makes
    repeated destruction a no-op.
    """
    model = getattr(self, "model", None)
    if model:
        LIB_LLAISYS.llaisysQwen2ModelDestroy(model)
        self.model = None
31 changes: 31 additions & 0 deletions src/llaisys/qwen2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "llaisys/models/qwen2.h"
#include "../models/qwen2/qwen2_model.hpp"

__C {

// Opaque C handle owning the underlying C++ model instance.
struct LlaisysQwen2Model {
    llaisys::models::Qwen2Model *model;
};

/// Create a model from metadata.
/// Only single-device execution is implemented: the first entry of
/// device_ids is used when provided, otherwise device 0.
__export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) {
    // Check ndevice too: a non-null pointer with ndevice == 0 must not be read.
    int dev_id = (device_ids && ndevice > 0) ? device_ids[0] : 0;
    auto handle = new LlaisysQwen2Model;
    handle->model = new llaisys::models::Qwen2Model(meta, device, dev_id);
    return handle;
}

/// Destroy a model handle. Safe to call with nullptr (previously this
/// dereferenced the handle unconditionally).
__export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) {
    if (!model) {
        return;
    }
    delete model->model;
    delete model;
}

/// Expose the model's weight tensor handles so the loader can fill them.
/// Returns nullptr for a null model.
/// NOTE(review): each call heap-allocates a fresh LlaisysQwen2Weights
/// (and getWeights() allocates wrapper objects) that this API never
/// frees — callers should invoke it once per model. Confirm intended.
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) {
    if (!model) {
        return nullptr;
    }
    auto weights = new LlaisysQwen2Weights;
    *weights = model->model->getWeights();
    return weights;
}

/// Run one forward pass over token_ids[0..ntoken) and return the
/// predicted next-token id (greedy argmax).
__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
    return model->model->infer(token_ids, ntoken);
}

}
184 changes: 184 additions & 0 deletions src/models/qwen2/qwen2_model.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#include "qwen2_model.hpp"
#include "../../llaisys/llaisys_tensor.hpp"
#include "../../ops/embedding/op.hpp"
#include "../../ops/linear/op.hpp"
#include "../../ops/rms_norm/op.hpp"
#include "../../ops/rope/op.hpp"
#include "../../ops/self_attention/op.hpp"
#include "../../ops/swiglu/op.hpp"
#include "../../ops/argmax/op.hpp"
#include "../../ops/add/op.hpp"
#include "../../ops/rearrange/op.hpp"
#include "../../utils.hpp"
#include <cmath>
#include <cstring>

namespace llaisys::models {

// Allocate every weight tensor and the per-layer KV cache up-front.
// Weight values are filled in later through getWeights() + tensorLoad
// from the Python loader; nothing is initialized to meaningful data here.
Qwen2Model::Qwen2Model(const LlaisysQwen2Meta *meta_, llaisysDeviceType_t device, int dev_id)
    : device_type(device), device_id(dev_id), cur_seq_len(0) {
    meta = *meta_;

    // Model-level tensors: token embedding, lm_head projection, final RMSNorm.
    in_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device, dev_id);
    out_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device, dev_id);
    out_norm_w = Tensor::create({meta.hs}, meta.dtype, device, dev_id);

    for (size_t i = 0; i < meta.nlayer; i++) {
        // Attention block: pre-norm, Q/K/V projections (Qwen2 has biases on
        // Q/K/V but not on the output projection), output projection.
        attn_norm_w.push_back(Tensor::create({meta.hs}, meta.dtype, device, dev_id));
        attn_q_w.push_back(Tensor::create({meta.nh * meta.dh, meta.hs}, meta.dtype, device, dev_id));
        attn_q_b.push_back(Tensor::create({meta.nh * meta.dh}, meta.dtype, device, dev_id));
        attn_k_w.push_back(Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device, dev_id));
        attn_k_b.push_back(Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device, dev_id));
        attn_v_w.push_back(Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device, dev_id));
        attn_v_b.push_back(Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device, dev_id));
        attn_o_w.push_back(Tensor::create({meta.hs, meta.nh * meta.dh}, meta.dtype, device, dev_id));

        // MLP block: pre-norm plus SwiGLU gate/up/down projections.
        mlp_norm_w.push_back(Tensor::create({meta.hs}, meta.dtype, device, dev_id));
        mlp_gate_w.push_back(Tensor::create({meta.di, meta.hs}, meta.dtype, device, dev_id));
        mlp_up_w.push_back(Tensor::create({meta.di, meta.hs}, meta.dtype, device, dev_id));
        mlp_down_w.push_back(Tensor::create({meta.hs, meta.di}, meta.dtype, device, dev_id));

        // KV cache pre-sized for the full maximum sequence length.
        k_cache.push_back(Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device, dev_id));
        v_cache.push_back(Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device, dev_id));
    }
}

// Build a C-ABI view of this model's weight tensors so the loader can fill
// them via tensorLoad. Each LlaisysTensor wraps one of the tensors owned by
// this model; the per-layer fields are heap arrays of nlayer handles.
//
// NOTE(review): every call allocates fresh LlaisysTensor wrappers and
// per-layer arrays, and nothing in this class frees them — fine if called
// once per model (as the Python wrapper does), a leak if called repeatedly.
LlaisysQwen2Weights Qwen2Model::getWeights() {
    LlaisysQwen2Weights weights;
    weights.in_embed = new LlaisysTensor{in_embed};
    weights.out_embed = new LlaisysTensor{out_embed};
    weights.out_norm_w = new LlaisysTensor{out_norm_w};

    // Per-layer handle arrays (one entry per transformer layer).
    weights.attn_norm_w = new llaisysTensor_t[meta.nlayer];
    weights.attn_q_w = new llaisysTensor_t[meta.nlayer];
    weights.attn_q_b = new llaisysTensor_t[meta.nlayer];
    weights.attn_k_w = new llaisysTensor_t[meta.nlayer];
    weights.attn_k_b = new llaisysTensor_t[meta.nlayer];
    weights.attn_v_w = new llaisysTensor_t[meta.nlayer];
    weights.attn_v_b = new llaisysTensor_t[meta.nlayer];
    weights.attn_o_w = new llaisysTensor_t[meta.nlayer];
    weights.mlp_norm_w = new llaisysTensor_t[meta.nlayer];
    weights.mlp_gate_w = new llaisysTensor_t[meta.nlayer];
    weights.mlp_up_w = new llaisysTensor_t[meta.nlayer];
    weights.mlp_down_w = new llaisysTensor_t[meta.nlayer];

    for (size_t i = 0; i < meta.nlayer; i++) {
        weights.attn_norm_w[i] = new LlaisysTensor{attn_norm_w[i]};
        weights.attn_q_w[i] = new LlaisysTensor{attn_q_w[i]};
        weights.attn_q_b[i] = new LlaisysTensor{attn_q_b[i]};
        weights.attn_k_w[i] = new LlaisysTensor{attn_k_w[i]};
        weights.attn_k_b[i] = new LlaisysTensor{attn_k_b[i]};
        weights.attn_v_w[i] = new LlaisysTensor{attn_v_w[i]};
        weights.attn_v_b[i] = new LlaisysTensor{attn_v_b[i]};
        weights.attn_o_w[i] = new LlaisysTensor{attn_o_w[i]};
        weights.mlp_norm_w[i] = new LlaisysTensor{mlp_norm_w[i]};
        weights.mlp_gate_w[i] = new LlaisysTensor{mlp_gate_w[i]};
        weights.mlp_up_w[i] = new LlaisysTensor{mlp_up_w[i]};
        weights.mlp_down_w[i] = new LlaisysTensor{mlp_down_w[i]};
    }

    return weights;
}

// One greedy decoding step. token_ids holds the FULL sequence so far; only
// the suffix not yet in the KV cache ([cur_seq_len, ntoken)) is processed.
// Returns the argmax next-token id and advances cur_seq_len.
//
// NOTE(review): assumes cur_seq_len < ntoken <= meta.maxseq. Otherwise
// seqlen underflows (size_t wraps) or the cache slices below run past the
// cache capacity — verify callers enforce this.
int64_t Qwen2Model::infer(int64_t *token_ids, size_t ntoken) {
    // Forward pass
    size_t seqlen = ntoken - cur_seq_len;

    // Embedding lookup for the new tokens: x has shape (seqlen, hs).
    auto idx_tensor = Tensor::create({seqlen}, LLAISYS_DTYPE_I64, device_type, device_id);
    idx_tensor->load(token_ids + cur_seq_len);
    auto x = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
    ops::embedding(x, idx_tensor, in_embed);

    // Absolute positions of the new tokens, used by RoPE below.
    std::vector<int64_t> pos_ids_vec(seqlen);
    for (size_t i = 0; i < seqlen; i++) pos_ids_vec[i] = cur_seq_len + i;
    auto pos_ids = Tensor::create({seqlen}, LLAISYS_DTYPE_I64, device_type, device_id);
    pos_ids->load(pos_ids_vec.data());

    // Transformer layers.
    for (size_t layer = 0; layer < meta.nlayer; layer++) {

        // Pre-attention RMSNorm.
        auto x_norm = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
        ops::rms_norm(x_norm, x, attn_norm_w[layer], meta.epsilon);

        // Q/K/V projections (with biases, per Qwen2).
        auto q = Tensor::create({seqlen, meta.nh * meta.dh}, meta.dtype, device_type, device_id);
        auto k = Tensor::create({seqlen, meta.nkvh * meta.dh}, meta.dtype, device_type, device_id);
        auto v = Tensor::create({seqlen, meta.nkvh * meta.dh}, meta.dtype, device_type, device_id);
        ops::linear(q, x_norm, attn_q_w[layer], attn_q_b[layer]);
        ops::linear(k, x_norm, attn_k_w[layer], attn_k_b[layer]);
        ops::linear(v, x_norm, attn_v_w[layer], attn_v_b[layer]);

        // Reshape to (seqlen, heads, head_dim).
        q = q->view({seqlen, meta.nh, meta.dh});
        k = k->view({seqlen, meta.nkvh, meta.dh});
        v = v->view({seqlen, meta.nkvh, meta.dh});

        // Rotary position embedding on Q and K.
        auto q_rope = Tensor::create({seqlen, meta.nh, meta.dh}, meta.dtype, device_type, device_id);
        auto k_rope = Tensor::create({seqlen, meta.nkvh, meta.dh}, meta.dtype, device_type, device_id);
        ops::rope(q_rope, q, pos_ids, meta.theta);
        ops::rope(k_rope, k, pos_ids, meta.theta);

        // Append this step's K/V into the persistent cache (copy via rearrange
        // into the [cur_seq_len, cur_seq_len+seqlen) slice).
        auto k_cache_slice = k_cache[layer]->slice(0, cur_seq_len, cur_seq_len + seqlen);
        auto v_cache_slice = v_cache[layer]->slice(0, cur_seq_len, cur_seq_len + seqlen);
        ops::rearrange(k_cache_slice, k_rope);
        ops::rearrange(v_cache_slice, v);

        // Full K/V context: cached prefix plus the tokens just appended.
        auto k_full = k_cache[layer]->slice(0, 0, cur_seq_len + seqlen);
        auto v_full = v_cache[layer]->slice(0, 0, cur_seq_len + seqlen);

        // Scaled self-attention over the full context.
        // NOTE(review): masking semantics (causal or not) live inside
        // ops::self_attention and are not visible here — confirm.
        auto attn_out = Tensor::create({seqlen, meta.nh, meta.dh}, meta.dtype, device_type, device_id);
        float scale = 1.0f / std::sqrt(static_cast<float>(meta.dh));
        ops::self_attention(attn_out, q_rope, k_full, v_full, scale);

        // Output projection (no bias) back to hidden size.
        attn_out = attn_out->view({seqlen, meta.nh * meta.dh});
        auto attn_proj = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
        ops::linear(attn_proj, attn_out, attn_o_w[layer], nullptr);

        // Residual connection around attention.
        ops::add(x, x, attn_proj);

        // MLP block: pre-norm, then SwiGLU(gate, up) -> down projection.
        auto x_mlp = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
        ops::rms_norm(x_mlp, x, mlp_norm_w[layer], meta.epsilon);

        auto gate = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id);
        auto up = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id);
        ops::linear(gate, x_mlp, mlp_gate_w[layer], nullptr);
        ops::linear(up, x_mlp, mlp_up_w[layer], nullptr);

        auto mlp_out = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id);
        ops::swiglu(mlp_out, gate, up);

        auto mlp_proj = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
        ops::linear(mlp_proj, mlp_out, mlp_down_w[layer], nullptr);

        // Residual connection around the MLP.
        ops::add(x, x, mlp_proj);
    }

    // Final RMSNorm.
    auto x_final = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id);
    ops::rms_norm(x_final, x, out_norm_w, meta.epsilon);

    // Predict from the last position only.
    auto last_hidden = x_final->slice(0, seqlen - 1, seqlen);
    auto logits = Tensor::create({1, meta.voc}, meta.dtype, device_type, device_id);
    ops::linear(logits, last_hidden, out_embed, nullptr);

    // Greedy selection: argmax over the vocabulary.
    auto max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, device_type, device_id);
    auto max_val = Tensor::create({1}, meta.dtype, device_type, device_id);
    ops::argmax(max_idx, max_val, logits->view({meta.voc}));

    // Copy the winning index back to host memory.
    // NOTE(review): memcpy from tensor data assumes it is host-accessible;
    // verify for non-CPU devices.
    int64_t result;
    std::byte *data = max_idx->data();
    std::memcpy(&result, data, sizeof(int64_t));

    cur_seq_len += seqlen;
    return result;
}

}
Loading