Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,8 @@ htmlcov/
# Windows
Thumbs.db
ehthumbs.db
desktop.ini
desktop.ini

# Models
/models/
/checkpoints/
1 change: 1 addition & 0 deletions include/llaisys.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ typedef enum {
LLAISYS_DEVICE_CPU = 0,
//// TODO: Add more device types here. Numbers need to be consecutive.
LLAISYS_DEVICE_NVIDIA = 1,
LLAISYS_DEVICE_METAX = 2,
LLAISYS_DEVICE_TYPE_COUNT
} llaisysDeviceType_t;

Expand Down
2 changes: 2 additions & 0 deletions include/llaisys/models/qwen2.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@ __C {
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);

__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);

__export void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model * model);
}
#endif // LLAISYS_MODELS_QWEN2_H
37 changes: 37 additions & 0 deletions include/llaisys/models/qwen2_tp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef LLAISYS_MODELS_QWEN2_TP_H
#define LLAISYS_MODELS_QWEN2_TP_H

#include "qwen2.h"

__C {
// Tensor Parallel Qwen2 Model
struct LlaisysQwen2ModelTP;

// Create a TP model with multiple devices
// device_ids: array of device IDs (e.g., [0, 1, 2, 3] for 4-GPU TP)
// ndevice: number of devices (TP world size)
__export struct LlaisysQwen2ModelTP *llaisysQwen2ModelTPCreate(
const struct LlaisysQwen2Meta *meta,
const int *device_ids,
int world_size);

__export void llaisysQwen2ModelTPDestroy(struct LlaisysQwen2ModelTP *model);

// Get weights for each rank
// Returns array of weight pointers, one for each rank
__export struct LlaisysQwen2Weights *llaisysQwen2ModelTPWeights(
struct LlaisysQwen2ModelTP *model,
int rank);

__export int64_t llaisysQwen2ModelTPInfer(
struct LlaisysQwen2ModelTP *model,
const int64_t *token_ids,
size_t ntoken);

__export void llaisysQwen2ModelTPResetCache(struct LlaisysQwen2ModelTP *model);

// Get the number of ranks in the TP group
__export int llaisysQwen2ModelTPGetWorldSize(struct LlaisysQwen2ModelTP *model);
}

#endif // LLAISYS_MODELS_QWEN2_TP_H
12 changes: 9 additions & 3 deletions python/llaisys/libllaisys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from .tensor import llaisysTensor_t
from .tensor import load_tensor
from .ops import load_ops
from .models.qwen2 import load_qwen2, LlaisysQwen2Meta, LlaisysQwen2Weights
from .models.qwen2_tp import load_qwen2_tp


def load_shared_library():
lib_dir = Path(__file__).parent
lib_dir = Path(__file__).parent.resolve()

if sys.platform.startswith("linux"):
libname = "libllaisys.so"
Expand All @@ -26,9 +28,9 @@ def load_shared_library():
else:
raise RuntimeError("Unsupported platform")

lib_path = os.path.join(lib_dir, libname)
lib_path = lib_dir / libname

if not os.path.isfile(lib_path):
if not lib_path.is_file():
raise FileNotFoundError(f"Shared library not found: {lib_path}")

return ctypes.CDLL(str(lib_path))
Expand All @@ -38,6 +40,8 @@ def load_shared_library():
load_runtime(LIB_LLAISYS)
load_tensor(LIB_LLAISYS)
load_ops(LIB_LLAISYS)
load_qwen2(LIB_LLAISYS)
load_qwen2_tp(LIB_LLAISYS)


__all__ = [
Expand All @@ -52,4 +56,6 @@ def load_shared_library():
"llaisysMemcpyKind_t",
"MemcpyKind",
"llaisysStream_t",
"LlaisysQwen2Meta",
"LlaisysQwen2Weights",
]
3 changes: 2 additions & 1 deletion python/llaisys/libllaisys/llaisys_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
class DeviceType(IntEnum):
CPU = 0
NVIDIA = 1
COUNT = 2
METAX = 2
COUNT = 3


llaisysDeviceType_t = ctypes.c_int
Expand Down
3 changes: 3 additions & 0 deletions python/llaisys/libllaisys/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .qwen2 import load_qwen2, LlaisysQwen2Meta, LlaisysQwen2Weights

__all__ = ["load_qwen2", "LlaisysQwen2Meta", "LlaisysQwen2Weights"]
83 changes: 83 additions & 0 deletions python/llaisys/libllaisys/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from ctypes import (
Structure,
POINTER,
c_void_p,
c_size_t,
c_int,
c_int64,
c_float,
)
from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t
from ..tensor import llaisysTensor_t


class LlaisysQwen2Meta(Structure):
_fields_ = [
("dtype", llaisysDataType_t),
("nlayer", c_size_t),
("hs", c_size_t), # hidden_size
("nh", c_size_t), # num_attention_heads
("nkvh", c_size_t), # num_key_value_heads
("dh", c_size_t), # head_dim
("di", c_size_t), # intermediate_size
("maxseq", c_size_t), # max_position_embeddings
("voc", c_size_t), # vocab_size
("epsilon", c_float), # rms_norm_eps
("theta", c_float), # rope_theta
("end_token", c_int64), # eos_token_id
]


class LlaisysQwen2Weights(Structure):
_fields_ = [
("in_embed", llaisysTensor_t),
("out_embed", llaisysTensor_t),
("out_norm_w", llaisysTensor_t),
("attn_norm_w", POINTER(llaisysTensor_t)),
("attn_q_w", POINTER(llaisysTensor_t)),
("attn_q_b", POINTER(llaisysTensor_t)),
("attn_k_w", POINTER(llaisysTensor_t)),
("attn_k_b", POINTER(llaisysTensor_t)),
("attn_v_w", POINTER(llaisysTensor_t)),
("attn_v_b", POINTER(llaisysTensor_t)),
("attn_o_w", POINTER(llaisysTensor_t)),
("mlp_norm_w", POINTER(llaisysTensor_t)),
("mlp_gate_w", POINTER(llaisysTensor_t)),
("mlp_up_w", POINTER(llaisysTensor_t)),
("mlp_down_w", POINTER(llaisysTensor_t)),
]


# Model handle type
llaisysQwen2Model_t = c_void_p


def load_qwen2(lib):
# llaisysQwen2ModelCreate
lib.llaisysQwen2ModelCreate.argtypes = [
POINTER(LlaisysQwen2Meta), # meta
llaisysDeviceType_t, # device
POINTER(c_int), # device_ids
c_int, # ndevice
]
lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t

# llaisysQwen2ModelDestroy
lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t]
lib.llaisysQwen2ModelDestroy.restype = None

# llaisysQwen2ModelWeights
lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t]
lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights)

# llaisysQwen2ModelInfer
lib.llaisysQwen2ModelInfer.argtypes = [
llaisysQwen2Model_t, # model
POINTER(c_int64), # token_ids
c_size_t, # ntoken
]
lib.llaisysQwen2ModelInfer.restype = c_int64

# llaisysQwen2ModelResetCache
lib.llaisysQwen2ModelResetCache.argtypes = [llaisysQwen2Model_t]
lib.llaisysQwen2ModelResetCache.restype = None
52 changes: 52 additions & 0 deletions python/llaisys/libllaisys/models/qwen2_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from ctypes import (
Structure,
POINTER,
c_void_p,
c_size_t,
c_int,
c_int64,
)
from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t
from ..tensor import llaisysTensor_t
from .qwen2 import LlaisysQwen2Meta, LlaisysQwen2Weights


# TP Model handle type
llaisysQwen2ModelTP_t = c_void_p


def load_qwen2_tp(lib):
# llaisysQwen2ModelTPCreate
lib.llaisysQwen2ModelTPCreate.argtypes = [
POINTER(LlaisysQwen2Meta), # meta
POINTER(c_int), # device_ids
c_int, # world_size
]
lib.llaisysQwen2ModelTPCreate.restype = llaisysQwen2ModelTP_t

# llaisysQwen2ModelTPDestroy
lib.llaisysQwen2ModelTPDestroy.argtypes = [llaisysQwen2ModelTP_t]
lib.llaisysQwen2ModelTPDestroy.restype = None

# llaisysQwen2ModelTPWeights
lib.llaisysQwen2ModelTPWeights.argtypes = [
llaisysQwen2ModelTP_t, # model
c_int, # rank
]
lib.llaisysQwen2ModelTPWeights.restype = POINTER(LlaisysQwen2Weights)

# llaisysQwen2ModelTPInfer
lib.llaisysQwen2ModelTPInfer.argtypes = [
llaisysQwen2ModelTP_t, # model
POINTER(c_int64), # token_ids
c_size_t, # ntoken
]
lib.llaisysQwen2ModelTPInfer.restype = c_int64

# llaisysQwen2ModelTPResetCache
lib.llaisysQwen2ModelTPResetCache.argtypes = [llaisysQwen2ModelTP_t]
lib.llaisysQwen2ModelTPResetCache.restype = None

# llaisysQwen2ModelTPGetWorldSize
lib.llaisysQwen2ModelTPGetWorldSize.argtypes = [llaisysQwen2ModelTP_t]
lib.llaisysQwen2ModelTPGetWorldSize.restype = c_int
3 changes: 3 additions & 0 deletions python/llaisys/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .qwen2 import Qwen2
from .qwen2_tp import Qwen2TP

__all__ = ["Qwen2", "Qwen2TP"]
Loading