Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/llaisys/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LLAISYS_RUNTIME_H

#include "../llaisys.h"
#include "tensor.h"

__C {
// Runtime API Functions
Expand Down Expand Up @@ -42,6 +43,18 @@ __C {

// Llaisys API for switching device context
__export void llaisysSetContextRuntime(llaisysDeviceType_t, int);

// Distributed runtime APIs
// Lifecycle and state queries. The implementation is not visible from this
// header, so the notes below are limited to what the signatures show.
// NOTE(review): presumably must be called before any llaisysDist* call -- confirm.
__export void llaisysInitDistributed(int rank, int world_size);
__export void llaisysFinalizeDistributed();
// Returns a uint8_t flag (presumably non-zero once initialized -- confirm).
__export uint8_t llaisysDistributedIsInitialized();
__export int llaisysDistributedRank();
__export int llaisysDistributedWorldSize();

// Tensor-level communication entry points.
// AllReduce and Broadcast return void (presumably operating on `tensor`
// in place -- confirm); AllGather returns a tensor handle.
// NOTE(review): ownership of the handle returned by AllGather is not
// specified here -- confirm who is responsible for destroying it.
__export void llaisysDistAllReduce(llaisysTensor_t tensor);
__export llaisysTensor_t llaisysDistAllGather(llaisysTensor_t tensor);
__export void llaisysDistBroadcast(llaisysTensor_t tensor, int root);
__export void llaisysDistBarrier();
}

#endif // LLAISYS_RUNTIME_H
2 changes: 2 additions & 0 deletions python/llaisys/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .runtime import RuntimeAPI
from .runtime import DistributedContext
from .libllaisys import DeviceType
from .libllaisys import DataType
from .libllaisys import MemcpyKind
Expand All @@ -10,6 +11,7 @@

__all__ = [
"RuntimeAPI",
"DistributedContext",
"DeviceType",
"DataType",
"MemcpyKind",
Expand Down
7 changes: 7 additions & 0 deletions python/llaisys/libllaisys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import ctypes
from pathlib import Path
import torch

from .runtime import load_runtime
from .runtime import LlaisysRuntimeAPI
Expand All @@ -12,6 +13,8 @@
from .tensor import llaisysTensor_t
from .tensor import load_tensor
from .ops import load_ops
from .models import load_models
from .models import LlaisysQwen2Meta, LlaisysQwen2Weights, llaisysQwen2Model_t


def load_shared_library():
Expand All @@ -38,6 +41,7 @@ def load_shared_library():
load_runtime(LIB_LLAISYS)
load_tensor(LIB_LLAISYS)
load_ops(LIB_LLAISYS)
load_models(LIB_LLAISYS)


__all__ = [
Expand All @@ -52,4 +56,7 @@ def load_shared_library():
"llaisysMemcpyKind_t",
"MemcpyKind",
"llaisysStream_t",
"LlaisysQwen2Meta",
"LlaisysQwen2Weights",
"llaisysQwen2Model_t",
]
72 changes: 72 additions & 0 deletions python/llaisys/libllaisys/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ctypes
from ctypes import POINTER, c_void_p, c_size_t, c_int, c_int64, c_float, Structure
from .llaisys_types import llaisysDataType_t, llaisysDeviceType_t
from .tensor import llaisysTensor_t


# Model handle type
llaisysQwen2Model_t = c_void_p


class LlaisysQwen2Meta(Structure):
    """ctypes structure describing Qwen2 model metadata.

    A pointer to this structure is the first argument of
    ``llaisysQwen2ModelCreate``. The order and types of the fields fix the
    binary layout; presumably they mirror a native struct whose declaration
    is not visible from this module, so do not reorder them.
    """

    _fields_ = [
        ("dtype", llaisysDataType_t),
        # Eight consecutive size_t members, kept in their original order.
        *(
            (field_name, c_size_t)
            for field_name in (
                "nlayer",
                "hs",
                "nh",
                "nkvh",
                "dh",
                "di",
                "maxseq",
                "voc",
            )
        ),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_int64),
    ]


class LlaisysQwen2Weights(Structure):
    """ctypes structure holding the tensor handles of a Qwen2 model.

    ``llaisysQwen2ModelWeights`` returns a pointer to this structure. The
    first three members are single tensor handles; the remaining members are
    pointers to tensor handles (presumably one entry per layer -- confirm
    against the native implementation). Field order fixes the binary layout
    and must not change.
    """

    _fields_ = [
        # Single tensor handles.
        *(
            (field_name, llaisysTensor_t)
            for field_name in ("in_embed", "out_embed", "out_norm_w")
        ),
        # Pointers to tensor handles.
        *(
            (field_name, POINTER(llaisysTensor_t))
            for field_name in (
                "attn_norm_w",
                "attn_q_w",
                "attn_q_b",
                "attn_k_w",
                "attn_k_b",
                "attn_v_w",
                "attn_v_b",
                "attn_o_w",
                "mlp_norm_w",
                "mlp_gate_w",
                "mlp_up_w",
                "mlp_down_w",
            )
        ),
    ]


def load_models(lib):
    """Declare the ctypes signatures of the Qwen2 model functions on ``lib``.

    For each exported symbol the argument types are assigned first and the
    return type second, in the order Create, Destroy, Weights, Infer.

    Args:
        lib: Loaded shared-library object exposing the ``llaisysQwen2Model*``
            symbols as attributes.
    """
    # (symbol name, argtypes, restype) for every model entry point.
    signatures = (
        (
            "llaisysQwen2ModelCreate",
            [
                POINTER(LlaisysQwen2Meta),  # meta
                llaisysDeviceType_t,  # device
                POINTER(c_int),  # device_ids
                c_int,  # ndevice
            ],
            llaisysQwen2Model_t,
        ),
        (
            "llaisysQwen2ModelDestroy",
            [llaisysQwen2Model_t],
            None,
        ),
        (
            "llaisysQwen2ModelWeights",
            [llaisysQwen2Model_t],
            POINTER(LlaisysQwen2Weights),
        ),
        (
            "llaisysQwen2ModelInfer",
            [
                llaisysQwen2Model_t,  # model
                POINTER(c_int64),  # token_ids
                c_size_t,  # ntoken
            ],
            c_int64,
        ),
    )

    for symbol, arg_types, result_type in signatures:
        native_fn = getattr(lib, symbol)
        native_fn.argtypes = arg_types
        native_fn.restype = result_type
28 changes: 28 additions & 0 deletions python/llaisys/libllaisys/runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ctypes
from ctypes import c_void_p, c_size_t, c_int, Structure, CFUNCTYPE
from .llaisys_types import *
from .tensor import llaisysTensor_t

# Define function pointer types
get_device_count_api = CFUNCTYPE(c_int)
Expand Down Expand Up @@ -46,3 +47,30 @@ def load_runtime(lib):

lib.llaisysSetContextRuntime.argtypes = [llaisysDeviceType_t, c_int]
lib.llaisysSetContextRuntime.restype = None

lib.llaisysInitDistributed.argtypes = [c_int, c_int]
lib.llaisysInitDistributed.restype = None

lib.llaisysFinalizeDistributed.argtypes = []
lib.llaisysFinalizeDistributed.restype = None

lib.llaisysDistributedIsInitialized.argtypes = []
lib.llaisysDistributedIsInitialized.restype = ctypes.c_uint8

lib.llaisysDistributedRank.argtypes = []
lib.llaisysDistributedRank.restype = c_int

lib.llaisysDistributedWorldSize.argtypes = []
lib.llaisysDistributedWorldSize.restype = c_int

lib.llaisysDistAllReduce.argtypes = [llaisysTensor_t]
lib.llaisysDistAllReduce.restype = None

lib.llaisysDistAllGather.argtypes = [llaisysTensor_t]
lib.llaisysDistAllGather.restype = llaisysTensor_t

lib.llaisysDistBroadcast.argtypes = [llaisysTensor_t, c_int]
lib.llaisysDistBroadcast.restype = None

lib.llaisysDistBarrier.argtypes = []
lib.llaisysDistBarrier.restype = None
Loading