16 changes: 13 additions & 3 deletions ktransformers/models/custom_cache.py
@@ -58,7 +58,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
         self.page_size = 64
         self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
-        latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+        from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
+        if multi_batch_enabled:
+            latent_shape = (max_batch_size, self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+        else:
+            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
         self.kv_lora_rank = config.kv_lora_rank
         self.qk_rope_head_dim = config.qk_rope_head_dim
         # TODO: support real page table
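With multi_batch_enabled, the latent cache above gains a leading batch dimension; the last dimension still packs the compressed KV (kv_lora_rank) and the rotary key (qk_rope_head_dim) side by side. A quick shape check with made-up sizes (all numbers below are illustrative, not DeepSeek's real config):

```python
import torch

# Hypothetical sizes, for illustration only.
max_batch_size, max_cache_len, page_size = 2, 256, 64
kv_lora_rank, qk_rope_head_dim = 512, 64
max_pages = (max_cache_len + page_size - 1) // page_size   # 4 pages

multi_batch_enabled = True
if multi_batch_enabled:
    # One cache slot per batch entry, each paged as before.
    latent_shape = (max_batch_size, max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
else:
    latent_shape = (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

cache = torch.zeros(latent_shape)
print(cache.shape)   # torch.Size([2, 4, 64, 1, 576]) with multi_batch_enabled
```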
@@ -143,8 +147,14 @@ def update(
             page_idx = cache_position // self.page_size
             page_offset = cache_position % self.page_size
             # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
-            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
+            if multi_batch_enabled:
+                batch_size = key_states.size(0)
+                k_out[:batch_size, page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[:batch_size, page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            else:
+                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
             return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
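The update() hunk writes each request's latent slice into its own batch slot while keeping the page_idx/page_offset addressing unchanged. A minimal standalone sketch of that indexing (shapes are hypothetical; the real key_states/value_states layouts come from the MLA attention code):

```python
import torch

# Hypothetical sizes, matching the sketch above.
batch_size, max_pages, page_size = 2, 4, 64
kv_lora_rank, qk_rope_head_dim = 512, 64
k_out = torch.zeros(batch_size, max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

# Write one decode token for every batch entry.
cache_position = torch.tensor([70])        # absolute token index being written
page_idx = cache_position // page_size     # page 1
page_offset = cache_position % page_size   # offset 6 inside that page
key_states = torch.randn(batch_size, 1, 1, kv_lora_rank)        # compressed-KV part
value_states = torch.randn(batch_size, 1, 1, qk_rope_head_dim)  # rope-key part

k_out[:batch_size, page_idx, page_offset, :, :kv_lora_rank] = key_states
k_out[:batch_size, page_idx, page_offset, :, kv_lora_rank:] = value_states
```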
4 changes: 3 additions & 1 deletion ktransformers/operators/attention.py
@@ -693,6 +693,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
         if torch.xpu.is_available():
             return self.forward_xpu(
                 hidden_states,
@@ -707,7 +708,8 @@
         elif (os.name == 'nt'
             or get_compute_capability() < 8
             or hidden_states.device.type == 'cpu'
-            or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+            or device_manager.gpu_vendor != GPUVendor.NVIDIA
+            or multi_batch_enabled):
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
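Adding multi_batch_enabled to this condition routes multi-batch requests to the masked forward_windows path instead of the single-sequence flash-attention path. A condensed, pure-function sketch of the dispatch after this change (the helper name and arguments are hypothetical and only illustrate the control flow):

```python
def should_use_masked_path(os_name: str,
                           compute_capability: int,
                           device_type: str,
                           gpu_vendor_is_nvidia: bool,
                           multi_batch_enabled: bool) -> bool:
    # Fall back to the masked (forward_windows) attention path whenever the
    # flash-attention kernel cannot be used, or when serving multiple batches.
    return (os_name == 'nt'
            or compute_capability < 8
            or device_type == 'cpu'
            or not gpu_vendor_is_nvidia
            or multi_batch_enabled)
```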
1 change: 1 addition & 0 deletions ktransformers/operators/linear.py
@@ -670,6 +670,7 @@ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> t
             padding_input[:,:self.orin_in_features] = x
             x = padding_input
         marlin_s = self.marlin_s.to(x.dtype)
+        x = x.contiguous()
         x = KTransformersOps.gptq_marlin_gemm(
             x,
             self.marlin_q_w,
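The added x = x.contiguous() guards the Marlin kernel against strided inputs: custom CUDA GEMMs typically assume a dense row-major buffer, and a sliced view (for example a per-request slice taken during multi-batch decoding) is not guaranteed to be one. A small illustration, independent of the actual call site:

```python
import torch

hidden = torch.randn(2, 8, 64)
x = hidden[:, -1, :]       # strided view (e.g. a last-token slice); shares storage with hidden
print(x.is_contiguous())   # False: the stride along dim 0 is 8*64, not 64

x = x.contiguous()         # materialize a dense row-major copy before the custom kernel
print(x.is_contiguous())   # True
```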
4 changes: 3 additions & 1 deletion ktransformers/operators/models.py
@@ -669,10 +669,12 @@ def forward(
         if per_layer_prefill_flag:
             causal_mask = None
         else:
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
             if (os.name == 'nt'
                 or get_compute_capability() < 8
                 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+                or device_manager.gpu_vendor != GPUVendor.NVIDIA
+                or multi_batch_enabled):
                 # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
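This mirrors the attention dispatch above: when the masked fallback path is taken (including the new multi_batch_enabled case), the model must build an explicit causal mask rather than rely on flash attention's implicit causality. A toy illustration of the kind of additive causal mask that path uses (not the actual _update_causal_mask implementation):

```python
import torch

def toy_causal_mask(seq_len: int, dtype=torch.float32) -> torch.Tensor:
    # Additive mask: future positions get -inf, current and past positions get 0.
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)

print(toy_causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```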