7 changes: 7 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import os
from typing import Tuple

import torch
@@ -183,6 +184,8 @@ def update_step_context(cls, step_context):
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
backend_config: BackendConfig, device: torch.device):
"""Build graph runner."""
+from lmdeploy.pytorch import envs
+
from .graph_runner import CUDAGraphRunner
from .warmup_manager import WarmupMeta, get_warmup_manager

@@ -194,6 +197,10 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
)
get_warmup_manager().warmup(warmup_meta)

+# add custom triton cache manager
+if envs.triton_custom_cache_mgr_enable:
+    os.environ['TRITON_CACHE_MANAGER'] = 'lmdeploy.pytorch.kernels.cuda.triton_utils:MPLockCacheManager'
+
# make graph runner.
return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)

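The MPLockCacheManager referenced by this environment variable lives in lmdeploy/pytorch/kernels/cuda/triton_utils.py and is not shown in this diff. For orientation: Triton resolves TRITON_CACHE_MANAGER as a "module.path:ClassName" string and instantiates that class instead of its default file-based cache. A hypothetical sketch of such a multi-process-safe manager follows; the class name, the filelock dependency, and the locking strategy are assumptions, not the actual implementation.

# --- hypothetical sketch, not part of this PR ---
import os

from filelock import FileLock  # assumed third-party dependency for cross-process locking
from triton.runtime.cache import FileCacheManager, default_cache_dir


class ExampleLockedCacheManager(FileCacheManager):
    """File cache manager that serializes cache writes across worker processes."""

    def put(self, data, filename, binary=True):
        # Multiple workers may compile the same kernel at once; a per-key file
        # lock keeps them from clobbering each other's cache entries.
        lock_dir = default_cache_dir()
        os.makedirs(lock_dir, exist_ok=True)
        with FileLock(os.path.join(lock_dir, f'{self.key}.lock')):
            return super().put(data, filename, binary=binary)

# Triton would pick such a class up via:
#   os.environ['TRITON_CACHE_MANAGER'] = 'my_pkg.cache_utils:ExampleLockedCacheManager'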
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/envs.py
@@ -98,6 +98,9 @@ def _patched_get_env(
# logging
log_file = os.getenv('LMDEPLOY_LOG_FILE', None)

+# triton
+triton_custom_cache_mgr_enable = env_to_bool('LMDEPLOY_TRITON_CUSTOM_CACHE_MGR_ENABLE', False)
+
# dlblas
# we don't need to read this, it would be passed to ray workers
# If Ray is launched from outside, it may fail to access the environment variables.
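Because this is a module-level assignment, the flag is read once when lmdeploy.pytorch.envs is imported, so it must be set in the environment before the engine starts. A usage sketch, assuming env_to_bool accepts the usual truthy strings such as '1' (the model name and tp value are placeholders):

# --- usage sketch, not part of this PR ---
import os

os.environ['LMDEPLOY_TRITON_CUSTOM_CACHE_MGR_ENABLE'] = '1'

from lmdeploy import pipeline, PytorchEngineConfig  # noqa: E402

# Any PyTorch-engine model works the same way; the custom Triton cache manager
# is installed when the CUDA backend builds its graph runner.
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=PytorchEngineConfig(tp=2))
print(pipe('Hello, world'))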
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/__init__.py
@@ -4,7 +4,6 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .fused_moe import fused_moe
-from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -14,7 +13,6 @@
__all__ = [
'apply_rotary_pos_emb',
'fused_moe',
-'fused_rotary_emb',
'paged_attention_fwd',
'alibi_paged_attention_fwd',
'fill_kv_cache',
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -7,7 +7,6 @@
from .flashattention import flash_attention_fwd
from .flatten_kv_cache import flatten_kv_cache
from .fused_moe import fused_moe
-from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -17,7 +16,6 @@
__all__ = [
'apply_rotary_pos_emb',
'fused_moe',
-'fused_rotary_emb',
'paged_attention_fwd',
'alibi_paged_attention_fwd',
'fill_kv_cache',
23 changes: 5 additions & 18 deletions lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
@@ -8,8 +8,6 @@
import triton.language as tl
from torch import Tensor

-from .triton_utils import get_kernel_meta, wrap_jit_func
-
assert triton.__version__ >= '2.1.0'

LOG2 = tl.constexpr(math.log(2))
@@ -65,7 +63,6 @@ def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, BLOC
return tl.load(offset_ptr + block_id) * BLOCK + offs_n


-@wrap_jit_func
@triton.jit
def _fwd_split_kernel(
Q,
@@ -200,7 +197,6 @@ def _fwd_split_kernel(
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


-@wrap_jit_func
@triton.jit
def _reduce_split_kernel(
Acc,
@@ -244,7 +240,6 @@ def _reduce_split_kernel(
tl.store(Out + out_offs, acc)


-@wrap_jit_func
@triton.jit
def _fwd_kernel(
Q,
@@ -375,7 +370,6 @@ def _fwd_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


-@wrap_jit_func
@triton.jit
def _fwd_split_kernel_quant(
Q,
@@ -561,7 +555,6 @@ def _fwd_split_kernel_quant(
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


-@wrap_jit_func
@triton.jit
def _fwd_kernel_quant(
Q,
@@ -802,7 +795,6 @@ def alibi_paged_attention_fwd(
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

num_warps = 4 if Lq <= 64 else 8
-kernel_meta = get_kernel_meta(q)
is_decoding = q.shape[-3] == b_seq_len.size(0)
if not is_decoding:
if quant_policy > 0:
@@ -846,8 +838,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=num_warps,
-num_stages=1,
-**kernel_meta)
+num_stages=1)
else:
_fwd_kernel[grid](q,
k,
@@ -880,8 +871,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=num_warps,
-num_stages=1,
-**kernel_meta)
+num_stages=1)
else:
SPLIT_K = 4
grid = (batch, head, SPLIT_K)
@@ -927,8 +917,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=4,
-num_stages=1,
-**kernel_meta)
+num_stages=1)

else:
_fwd_split_kernel[grid](q,
@@ -961,8 +950,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=4,
-num_stages=1,
-**kernel_meta)
+num_stages=1)

grid = (batch, head)
_reduce_split_kernel[grid](acc,
@@ -977,5 +965,4 @@
SPLIT_K=SPLIT_K,
BLOCK_DMODEL=Lq,
num_warps=num_warps,
-num_stages=1,
-**kernel_meta)
+num_stages=1)
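The pattern is the same throughout this file: the wrap_jit_func decorator, the kernel_meta = get_kernel_meta(q) call, and the trailing **kernel_meta kwargs are removed, leaving plain @triton.jit kernels launched with only Triton-native options (grid, num_warps, num_stages). Below is a minimal, self-contained sketch of that launch style; the kernel is illustrative only and is not one of the attention kernels above.

# --- illustrative sketch, not part of this PR ---
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(X, Y, scale, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    tl.store(Y + offs, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    """Launch with grid plus num_warps/num_stages only, no extra **kernel_meta."""
    y = torch.empty_like(x)
    BLOCK = 1024
    grid = (triton.cdiv(x.numel(), BLOCK), )
    _scale_kernel[grid](x, y, s, x.numel(), BLOCK=BLOCK, num_warps=4, num_stages=1)
    return y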
4 changes: 1 addition & 3 deletions lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
@@ -233,9 +233,7 @@ def _gemm_fp8_tma_kernel(
'BLOCK_N': 64,
}, num_stages=3, num_warps=4)
],
-key=['N', 'K'],
-warmup=5,
-rep=10)
+key=['N', 'K'])
@triton.jit
def _gemm_fp8_kernel(
A,
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/fused_lora.py
@@ -38,8 +38,6 @@ def _atomic_store(ptrs, val, mask):
configs=get_autotune_config(),
key=['N', 'K'],
restore_value=['c_ptr'],
-warmup=5,
-rep=20,
)
@triton.jit
def _fused_lora_kernel(
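Both autotuned kernels (_gemm_fp8_kernel above and _fused_lora_kernel here) drop the warmup/rep overrides and keep only configs, key, and any existing restore_value, so autotuning benchmarks with Triton's defaults, presumably for compatibility with Triton releases that no longer accept those arguments. A generic sketch of the trimmed decorator usage follows; the kernel and its configs are placeholders, not the actual ones.

# --- illustrative sketch, not part of this PR ---
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
    ],
    key=['N', 'K'],  # re-tune only when these shape keys change
)
@triton.jit
def _example_autotuned_kernel(A, B, C, M, N, K,
                              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Body elided; the point is the decorator without warmup/rep overrides.
    pid = tl.program_id(0)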
141 changes: 0 additions & 141 deletions lmdeploy/pytorch/kernels/cuda/fused_rotary_emb.py

This file was deleted.
