@@ -6,6 +6,7 @@
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
 from lightllm.distributed import all_reduce
+from lightllm.common.mem_manager import MemoryManager
 from typing import Tuple

@@ -51,8 +52,8 @@ def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight):
         self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager)
         return

-    def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
-        destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
+    def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager: MemoryManager):
+        destindex_copy_kv(buffer, mem_index, mem_manager.get_kv_buffer(self.layer_num_))
         return

     def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
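The hunk above replaces direct indexing into mem_manager.kv_buffer with the get_kv_buffer(layer_index) accessor that this PR adds to MemoryManager. The point of the indirection is that layer code no longer assumes how the KV cache is laid out; a subclass can change the storage and override only the accessor. A minimal sketch of that idea, using hypothetical classes rather than LightLLM code:

import torch


class PerLayerMem:
    """Baseline layout: one KV tensor per layer, like MemoryManager.kv_buffer."""

    def __init__(self, layer_num, size, head_num, head_dim):
        self.kv_buffer = [torch.empty(size, 2 * head_num, head_dim) for _ in range(layer_num)]

    def get_kv_buffer(self, layer_index):
        return self.kv_buffer[layer_index]


class FusedMem(PerLayerMem):
    """Alternative layout: one fused allocation. Callers that go through
    get_kv_buffer() keep working; callers indexing kv_buffer directly would break."""

    def __init__(self, layer_num, size, head_num, head_dim):
        self.fused = torch.empty(layer_num, size, 2 * head_num, head_dim)

    def get_kv_buffer(self, layer_index):
        return self.fused[layer_index]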
@@ -10,3 +10,4 @@
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight_tp import FusedMoeWeightTP
 from .fused_moe_weight_ep import FusedMoeWeightEP
+from .parameter_weight import ParameterWeight, TpParameterWeight
@@ -0,0 +1,44 @@
import torch
from typing import Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class ParameterWeight(BaseWeightTpl):
def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__()
self.weight_name = weight_name
self.bias_name = bias_name
self.data_type_ = data_type
self.weight = None
self.bias = None

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

def verify_load(self):
load_ok = True
# Verify weight. The weight must not be None.
load_ok = load_ok and self.weight is not None
# Verify bias. If bias_name is set, the bias must not be None.
if self.bias_name is not None:
load_ok = load_ok and self.bias is not None
return load_ok


class TpParameterWeight(ParameterWeight):
def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
super().__init__(weight_name, data_type, bias_name)
self.split_n_embed = split_n_embed

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
start = self.split_n_embed * self.tp_rank_
end = self.split_n_embed * (self.tp_rank_ + 1)

if self.weight_name in weights:
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
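A brief usage sketch for the new weight classes. It assumes the package path matches the meta_weights __init__.py edited above, a CUDA device, and the distributed setup LightLLM normally provides (which gives BaseWeightTpl its tp_rank_); the weight name is invented for illustration:

import torch
from lightllm.common.basemodel.layer_weights.meta_weights import (
    ParameterWeight,
    TpParameterWeight,
)

# Hypothetical checkpoint dict; real names come from the HF state dict.
weights = {"model.logit_scale": torch.ones(4096, dtype=torch.float32)}

w = ParameterWeight("model.logit_scale", data_type=torch.bfloat16)
w.load_hf_weights(weights)  # casts to bfloat16 and moves to the current device
assert w.verify_load()      # False if the weight was never seen

# Tensor-parallel variant: with split_n_embed=2048, rank 0 loads rows
# [0:2048) and rank 1 loads rows [2048:4096) of the same tensor.
tp_w = TpParameterWeight("model.logit_scale", torch.bfloat16, split_n_embed=2048)
tp_w.load_hf_weights(weights)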
lightllm/common/mem_manager.py (3 additions, 0 deletions)
@@ -62,6 +62,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

+    def get_kv_buffer(self, layer_index):
+        return self.kv_buffer[layer_index]
+
     def profile_size(self, mem_fraction):
         if self.size is not None:
             return
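For scale, get_cell_size() in the hunk above computes the bytes one cached token occupies across all layers: 2 (K and V) * head_num * head_dim * layer_num * element size. A worked example with hypothetical but typical 7B-class shapes:

import torch

# Assumed shapes for illustration: 32 KV heads, head_dim 128, 32 layers, fp16.
head_num, head_dim, layer_num = 32, 128, 32
elem = torch._utils._element_size(torch.float16)  # 2 bytes

cell_bytes = 2 * head_num * head_dim * layer_num * elem
print(cell_bytes)          # 524288 bytes, i.e. 512 KiB per cached token
print(cell_bytes / 2**20)  # 0.5 MiB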
@@ -0,0 +1,50 @@
{
"1": {
"BLK_HEADS": 8,
"num_warps": 1
},
"100": {
"BLK_HEADS": 8,
"num_warps": 1
},
"1024": {
"BLK_HEADS": 8,
"num_warps": 4
},
"128": {
"BLK_HEADS": 64,
"num_warps": 1
},
"16": {
"BLK_HEADS": 32,
"num_warps": 2
},
"2048": {
"BLK_HEADS": 8,
"num_warps": 1
},
"256": {
"BLK_HEADS": 32,
"num_warps": 1
},
"32": {
"BLK_HEADS": 64,
"num_warps": 1
},
"4096": {
"BLK_HEADS": 8,
"num_warps": 1
},
"64": {
"BLK_HEADS": 8,
"num_warps": 1
},
"8": {
"BLK_HEADS": 32,
"num_warps": 4
},
"8448": {
"BLK_HEADS": 8,
"num_warps": 4
}
}
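This file and the two that follow read as Triton kernel tuning tables: each key is a problem size, each value a set of launch parameters (block sizes, num_warps, and so on). One plausible way such a table is consumed is a nearest-key lookup; the sketch below is an assumption about that rule, and the file name is invented:

import json


def pick_config(path: str, size: int) -> dict:
    # Load the tuning table and choose the entry whose integer key is
    # closest to the runtime size. The exact rule LightLLM applies may
    # differ (e.g. rounding up instead of picking the nearest key).
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    nearest = min(table, key=lambda k: abs(k - size))
    return table[nearest]


# e.g. a runtime size of 1000 against the table above selects the "1024"
# entry: {"BLK_HEADS": 8, "num_warps": 4}.
cfg = pick_config("kernel_tuning_config.json", 1000)  # hypothetical path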
@@ -0,0 +1,50 @@
{
"1024": {
"BLOCK_N": 64,
"num_warps": 1
},
"128": {
"BLOCK_N": 256,
"num_warps": 1
},
"16384": {
"BLOCK_N": 128,
"num_warps": 2
},
"2048": {
"BLOCK_N": 64,
"num_warps": 1
},
"256": {
"BLOCK_N": 1024,
"num_warps": 2
},
"32768": {
"BLOCK_N": 128,
"num_warps": 2
},
"512": {
"BLOCK_N": 256,
"num_warps": 2
},
"64": {
"BLOCK_N": 512,
"num_warps": 1
},
"67584": {
"BLOCK_N": 64,
"num_warps": 1
},
"8": {
"BLOCK_N": 256,
"num_warps": 2
},
"800": {
"BLOCK_N": 1024,
"num_warps": 2
},
"8192": {
"BLOCK_N": 64,
"num_warps": 1
}
}
@@ -0,0 +1,110 @@
{
"10": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"1000": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"10240": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 5,
"num_warps": 4
},
"1280": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"160": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"20480": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"2560": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"320": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"40960": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 8
},
"640": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"80": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"84480": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 4,
"num_warps": 4
}
}