diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index ba567b30ce..d56c9e3599 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Tuple import torch @@ -183,6 +184,8 @@ def update_step_context(cls, step_context): def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig, device: torch.device): """Build graph runner.""" + from lmdeploy.pytorch import envs + from .graph_runner import CUDAGraphRunner from .warmup_manager import WarmupMeta, get_warmup_manager @@ -194,6 +197,10 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_ ) get_warmup_manager().warmup(warmup_meta) + # add custom triton cache manager + if envs.triton_custom_cache_mgr_enable: + os.environ['TRITON_CACHE_MANAGER'] = 'lmdeploy.pytorch.kernels.cuda.triton_utils:MPLockCacheManager' + # make graph runner. return CUDAGraphRunner(model, model_config, cache_config, backend_config, device) diff --git a/lmdeploy/pytorch/envs.py b/lmdeploy/pytorch/envs.py index 03fa830311..3f5b1c4a02 100644 --- a/lmdeploy/pytorch/envs.py +++ b/lmdeploy/pytorch/envs.py @@ -98,6 +98,9 @@ def _patched_get_env( # logging log_file = os.getenv('LMDEPLOY_LOG_FILE', None) + # triton + triton_custom_cache_mgr_enable = env_to_bool('LMDEPLOY_TRITON_CUSTOM_CACHE_MGR_ENABLE', False) + # dlblas # we don't need to read this, it would be passed to ray workers # If Ray is launched from outside, it may fail to access the environment variables. diff --git a/lmdeploy/pytorch/kernels/__init__.py b/lmdeploy/pytorch/kernels/__init__.py index 7e775aa59f..5706739fbf 100644 --- a/lmdeploy/pytorch/kernels/__init__.py +++ b/lmdeploy/pytorch/kernels/__init__.py @@ -4,7 +4,6 @@ from .apply_rotary_pos_emb import apply_rotary_pos_emb from .fill_kv_cache import fill_kv_cache from .fused_moe import fused_moe -from .fused_rotary_emb import fused_rotary_emb from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd from .rms_norm import rms_norm @@ -14,7 +13,6 @@ __all__ = [ 'apply_rotary_pos_emb', 'fused_moe', - 'fused_rotary_emb', 'paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache', diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py index f741a8053d..f4ae57714b 100644 --- a/lmdeploy/pytorch/kernels/cuda/__init__.py +++ b/lmdeploy/pytorch/kernels/cuda/__init__.py @@ -7,7 +7,6 @@ from .flashattention import flash_attention_fwd from .flatten_kv_cache import flatten_kv_cache from .fused_moe import fused_moe -from .fused_rotary_emb import fused_rotary_emb from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd from .rms_norm import rms_norm @@ -17,7 +16,6 @@ __all__ = [ 'apply_rotary_pos_emb', 'fused_moe', - 'fused_rotary_emb', 'paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache', diff --git a/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py b/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py index 7de6298323..b1e9ae7473 100644 --- a/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py @@ -8,8 +8,6 @@ import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func - assert triton.__version__ >= 
'2.1.0' LOG2 = tl.constexpr(math.log(2)) @@ -65,7 +63,6 @@ def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, BLOC return tl.load(offset_ptr + block_id) * BLOCK + offs_n -@wrap_jit_func @triton.jit def _fwd_split_kernel( Q, @@ -200,7 +197,6 @@ def _fwd_split_kernel( tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i) -@wrap_jit_func @triton.jit def _reduce_split_kernel( Acc, @@ -244,7 +240,6 @@ def _reduce_split_kernel( tl.store(Out + out_offs, acc) -@wrap_jit_func @triton.jit def _fwd_kernel( Q, @@ -375,7 +370,6 @@ def _fwd_kernel( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) -@wrap_jit_func @triton.jit def _fwd_split_kernel_quant( Q, @@ -561,7 +555,6 @@ def _fwd_split_kernel_quant( tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i) -@wrap_jit_func @triton.jit def _fwd_kernel_quant( Q, @@ -802,7 +795,6 @@ def alibi_paged_attention_fwd( grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, num_warps = 4 if Lq <= 64 else 8 - kernel_meta = get_kernel_meta(q) is_decoding = q.shape[-3] == b_seq_len.size(0) if not is_decoding: if quant_policy > 0: @@ -846,8 +838,7 @@ def alibi_paged_attention_fwd( BLOCK_DMODEL=Lq, BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=1, - **kernel_meta) + num_stages=1) else: _fwd_kernel[grid](q, k, @@ -880,8 +871,7 @@ def alibi_paged_attention_fwd( BLOCK_DMODEL=Lq, BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=1, - **kernel_meta) + num_stages=1) else: SPLIT_K = 4 grid = (batch, head, SPLIT_K) @@ -927,8 +917,7 @@ def alibi_paged_attention_fwd( BLOCK_DMODEL=Lq, BLOCK_N=BLOCK, num_warps=4, - num_stages=1, - **kernel_meta) + num_stages=1) else: _fwd_split_kernel[grid](q, @@ -961,8 +950,7 @@ def alibi_paged_attention_fwd( BLOCK_DMODEL=Lq, BLOCK_N=BLOCK, num_warps=4, - num_stages=1, - **kernel_meta) + num_stages=1) grid = (batch, head) _reduce_split_kernel[grid](acc, @@ -977,5 +965,4 @@ def alibi_paged_attention_fwd( SPLIT_K=SPLIT_K, BLOCK_DMODEL=Lq, num_warps=num_warps, - num_stages=1, - **kernel_meta) + num_stages=1) diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py index ab1bb213fe..6c5f0a6f53 100644 --- a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py +++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py @@ -233,9 +233,7 @@ def _gemm_fp8_tma_kernel( 'BLOCK_N': 64, }, num_stages=3, num_warps=4) ], - key=['N', 'K'], - warmup=5, - rep=10) + key=['N', 'K']) @triton.jit def _gemm_fp8_kernel( A, diff --git a/lmdeploy/pytorch/kernels/cuda/fused_lora.py b/lmdeploy/pytorch/kernels/cuda/fused_lora.py index 75c987d635..93b3829de5 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_lora.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_lora.py @@ -38,8 +38,6 @@ def _atomic_store(ptrs, val, mask): configs=get_autotune_config(), key=['N', 'K'], restore_value=['c_ptr'], - warmup=5, - rep=20, ) @triton.jit def _fused_lora_kernel( diff --git a/lmdeploy/pytorch/kernels/cuda/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/cuda/fused_rotary_emb.py deleted file mode 100644 index 27f4d3a8ab..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/fused_rotary_emb.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import torch -import triton -import triton.language as tl -from torch import Tensor - -from .triton_utils import get_kernel_meta, wrap_jit_func - - -@wrap_jit_func(type_hint=dict(Q=Tensor, - K=Tensor, - PostionIds=Tensor, - InvFreq=Tensor, - scaling_factor=float, - OutQ=Tensor, - OutK=Tensor, - stride_bq=int, - stride_sq=int, - stride_hq=int, - stride_dq=int, - stride_bk=int, - stride_sk=int, - stride_hk=int, - stride_dk=int, - stride_bp=int, - stride_sp=int, - max_seq_len=int, - BLOCK=torch.int32, - BLOCK_HQ=torch.int32, - BLOCK_HK=torch.int32, - BLOCK_F=torch.int32)) -@triton.jit -def _fused_rotary_emb_kernel(Q, K, PostionIds, InvFreq, scaling_factor, OutQ, OutK, stride_bq, stride_sq, - stride_hq: tl.constexpr, stride_dq: tl.constexpr, stride_bk, stride_sk, - stride_hk: tl.constexpr, stride_dk: tl.constexpr, stride_bp, stride_sp, max_seq_len, - BLOCK: tl.constexpr, BLOCK_HQ: tl.constexpr, BLOCK_HK: tl.constexpr, - BLOCK_F: tl.constexpr): - """Fused rotary emb kernel.""" - batch_id = tl.program_id(0) - seq_block_id = tl.program_id(1) - - s_off = seq_block_id * BLOCK + tl.arange(0, BLOCK)[:, None] - f_off = tl.arange(0, BLOCK_F)[None, :] - s_mask = s_off < max_seq_len - - bp_off = stride_bp * batch_id - p_off = bp_off + stride_sp * s_off - - sq_off = batch_id * stride_bq + s_off * stride_sq - q0_off = sq_off + f_off * stride_dq - q1_off = q0_off + BLOCK_F * stride_dq - - sk_off = batch_id * stride_bk + s_off * stride_sk - k0_off = sk_off + f_off * stride_dk - k1_off = k0_off + BLOCK_F * stride_dk - - inv_freq = tl.load(InvFreq + f_off).to(tl.float32) - position_ids = tl.load(PostionIds + p_off, mask=s_mask).to(tl.float32) - position_ids = position_ids / scaling_factor - - # pos_freq = tl.dot(position_ids, inv_freq) - pos_freq = position_ids * inv_freq - cos = tl.cos(pos_freq).to(Q.dtype.element_ty) - sin = tl.sin(pos_freq).to(Q.dtype.element_ty) - - for h in range(BLOCK_HQ): - q0 = tl.load(Q + q0_off + h * stride_hq, mask=s_mask) - q1 = tl.load(Q + q1_off + h * stride_hq, mask=s_mask) - q0_out = q0 * cos - q1 * sin - tl.store(OutQ + q0_off + h * stride_hq, q0_out, mask=s_mask) - q1_out = q1 * cos + q0 * sin - tl.store(OutQ + q1_off + h * stride_hq, q1_out, mask=s_mask) - - for h in range(BLOCK_HK): - k0 = tl.load(K + k0_off + h * stride_hk, mask=s_mask) - k1 = tl.load(K + k1_off + h * stride_hk, mask=s_mask) - k0_out = k0 * cos - k1 * sin - tl.store(OutK + k0_off + h * stride_hk, k0_out, mask=s_mask) - k1_out = k1 * cos + k0 * sin - tl.store(OutK + k1_off + h * stride_hk, k1_out, mask=s_mask) - - -def fused_rotary_emb(q: Tensor, - k: Tensor, - position_ids: torch.LongTensor, - inv_freq: Tensor, - scaling_factor: float, - out_q: Tensor = None, - out_k: Tensor = None): - """Fuse `rotary_embedding` and `apply_rotary_pos_emb`.""" - - if out_q is None: - out_q = torch.empty_like(q) - else: - assert q.stride() == out_q.stride() - if out_k is None: - out_k = torch.empty_like(k) - else: - assert k.stride() == out_k.stride() - - assert q.dim() == 4 - assert k.dim() == 4 - assert q.size(0) == position_ids.size(0) - - BLOCK = 32 - BLOCK_HQ = q.size(-2) - BLOCK_HK = k.size(-2) - BLOCK_F = q.size(-1) // 2 - batch_size = q.size(0) - max_seq_len = q.size(1) - kernel_meta = get_kernel_meta(q) - num_warps = 4 - - grid = (batch_size, triton.cdiv(max_seq_len, BLOCK)) - _fused_rotary_emb_kernel[grid](q, - k, - position_ids, - inv_freq, - scaling_factor, - out_q, - out_k, - stride_bq=q.stride(0), - stride_sq=q.stride(1), - stride_hq=q.stride(2), - stride_dq=q.stride(3), - stride_bk=k.stride(0), - 
stride_sk=k.stride(1), - stride_hk=k.stride(2), - stride_dk=k.stride(3), - stride_bp=position_ids.stride(0), - stride_sp=position_ids.stride(1), - max_seq_len=max_seq_len, - BLOCK=BLOCK, - BLOCK_HQ=BLOCK_HQ, - BLOCK_HK=BLOCK_HK, - BLOCK_F=BLOCK_F, - num_warps=num_warps, - num_stages=1, - **kernel_meta) - - return out_q, out_k diff --git a/lmdeploy/pytorch/kernels/cuda/triton_utils.py b/lmdeploy/pytorch/kernels/cuda/triton_utils.py index 2cc4dad25a..56713b06a0 100644 --- a/lmdeploy/pytorch/kernels/cuda/triton_utils.py +++ b/lmdeploy/pytorch/kernels/cuda/triton_utils.py @@ -1,459 +1,83 @@ # Copyright (c) OpenMMLab. All rights reserved. -import functools -import inspect -from typing import Callable, Dict, Sequence, Union, cast, overload +import os.path as osp +from typing import Dict, Sequence, Union -import torch -import triton -from packaging import version -from triton import JITFunction +from triton.runtime.cache import FileCacheManager -if version.parse(triton.__version__) <= version.parse('2.2.0'): +from lmdeploy.utils import get_logger - def get_kernel_meta(tensor: torch.Tensor): - """Kernel meta.""" - from triton.runtime.jit import get_cuda_stream - - device = tensor.device - device_idx = device.index - device_type = device.type - stream = get_cuda_stream(device_idx) - return dict(device=device, device_type=device_type, stream=stream) -else: - - KERNEL_META = dict() - - def get_kernel_meta(tensor: torch.Tensor): - """Kernel meta.""" - return KERNEL_META - - -TRITON_DIVIIBILITY = getattr(JITFunction, 'divisibility', 16) -TRITON_DIVIIBILITY_8 = getattr(JITFunction, 'divisibility_8', 8) +logger = get_logger('lmdeploy') TypeHintType = Union[Dict[str, type], Sequence[type], None] -def _check_type_hint(jit_func: JITFunction, type_hint: TypeHintType): - """Check type hint.""" - params = jit_func.params - arg_key = tuple(p.name for p in params) - - if isinstance(type_hint, Dict): - for key in arg_key: - if key not in type_hint: - type_hint[key] = None - return type_hint - elif type_hint is None: - return dict((key, None) for key in arg_key) - elif isinstance(type_hint, Sequence): - assert len(arg_key) == len(type_hint) - return dict(zip(arg_key, type_hint)) - else: - raise RuntimeError(f'Unknown type_hint: {type_hint}') - - -class JitFunction220Wrapper: - - def __init__(self, jit_func: JITFunction, type_hint: TypeHintType = None): - """Jit func.""" - self.jit_func = jit_func - self.type_hint = _check_type_hint(jit_func, type_hint) - self.run = self._make_launcher(jit_func) - self.arg_names = jit_func.arg_names - - self.__doc__ = jit_func.__doc__ - self.__name__ = jit_func.__name__ - self.__globals__ = jit_func.__globals__ - self.__module__ = jit_func.__module__ - - @staticmethod - def _specialization_key(value): - if isinstance(value, int): - # bool is a subclass of int, so we don't check explicitly above. 
- return ( - value % TRITON_DIVIIBILITY == 0, - value % TRITON_DIVIIBILITY_8 == 0, - value == 1, - ) - - if hasattr(value, 'data_ptr'): - return (value.data_ptr() % TRITON_DIVIIBILITY == 0, ) - - return (False, ) - - def _make_launcher(self, jit_func: triton.JITFunction): - """Make input builder.""" - - from triton.common.backend import get_backend, get_cuda_version_key - from triton.compiler import CompiledKernel, get_arch_default_num_stages, get_arch_default_num_warps - - def _make_spec_key_str(key): - anno = self.type_hint[key] - if anno == torch.Tensor: - return f'({key}.data_ptr() % {TRITON_DIVIIBILITY} == 0, )' - elif anno in [int, bool, torch.int32, torch.int64]: - return (f'({key} % {TRITON_DIVIIBILITY} == 0, ' - f'{key} % {TRITON_DIVIIBILITY_8} == 0, ' - f'{key} == 1, )') - elif anno is not None: - return '(False,)' - return f'_specialization_key({key})' - - def _make_sig_key_str(key): - anno = self.type_hint[key] - if anno == bool: - return '"i1"' - elif anno == float: - return '"fp32"' - elif anno == torch.int32: - return '"i32"' - elif anno == torch.int64: - return '"i64"' - elif anno == torch.Tensor: - return f'{key}.dtype' - return f'_key_of({key})' - - fn = jit_func.fn - params = jit_func.params - - # arg key - arg_key = tuple(p.name for p in params) - arg_key_str = ', '.join(arg_key) - grid_args = ','.join([f'{arg}={arg}' for arg in arg_key]) - args_signature = ', '.join(p.name if p.default == inspect._empty else f'{p.name} == {p.default}' - for p in params) - - # constexpr key - constexpr_key = tuple(p.name for p in params if p.is_constexpr) - constexpr_key_str = ', '.join(constexpr_key) - - # sig key - sig_key = tuple(p.name for p in params if not p.is_constexpr) - sig_name_str = ', '.join(key for key in sig_key) - sig_key_str = ', '.join(_make_sig_key_str(key) for key in sig_key) - - # spec key - spec_key = tuple(p.name for p in params if not p.do_not_specialize) - spec_key_str = ', '.join(_make_spec_key_str(key) for key in spec_key) - - # options - cuda_opt_fields = dict( - num_warps=None, - num_ctas=1, - num_stages=None, - enable_warp_specialization=False, - enable_fp_fusion=True, - extern_libs=None, - stream=None, - device=None, - device_type=None, - ) - cuda_opt_signature = ', '.join(f'{k} = {v}' for k, v in cuda_opt_fields.items()) - cuda_opt_args = ', '.join(f'{k}={k}' for k in cuda_opt_fields) - src = f""" -def _{fn.__name__}_launcher({args_signature}, grid=None, {cuda_opt_signature}, warmup=False, **kwargs): - debug=jit_func.debug - device_backend = None - - if device_type not in ["cuda"]: - device_backend = get_backend(device_type) - if device_backend is None: - raise ValueError("Cannot find backend for " + device_type) - - if num_warps is None: - num_warps = get_arch_default_num_warps(device_type) - if num_stages is None: - num_stages = get_arch_default_num_stages(device_type) - - if device_type in ["cuda"]: - version_key = get_cuda_version_key() - else: - version_key = device_backend.get_version_key() - - sig_key = ({sig_key_str}, ) - spec_key = ({spec_key_str}, ) - constexpr_key = ({constexpr_key_str}, ) - key = ( - version_key, - sig_key, - constexpr_key, - spec_key, - num_warps, - num_ctas, - num_stages, - enable_warp_specialization, - enable_fp_fusion, - debug, - ) - if extern_libs is not None: - key = (key, tuple(extern_libs.items())) - - bin = kernel_cache[device].get(key, None) - if bin is None: - return jit_func[grid]({arg_key_str}, {cuda_opt_args}, **kwargs) - - non_constexpr_arg_values = ({sig_name_str}) - if callable(grid): - grid = 
grid(dict({grid_args})) - grid_size = len(grid) - grid_0 = grid[0] - grid_1 = grid[1] if grid_size > 1 else 1 - grid_2 = grid[2] if grid_size > 2 else 1 - if not hasattr(bin, 'tensormaps_info'): - bin.c_wrapper( - grid_0, - grid_1, - grid_2, - bin.num_warps, - bin.num_ctas, - *bin.clusterDims, - bin.shared, - stream, - bin.cu_function, - launch_enter_hook, - launch_exit_hook, - bin, - {sig_name_str}, - ) - else: - bin.c_wrapper( - grid_0, - grid_1, - grid_2, - bin.num_warps, - bin.num_ctas, - *bin.clusterDims, - bin.shared, - stream, - bin.cu_function, - launch_enter_hook, - launch_exit_hook, - bin, - *bin.assemble_tensormap_to_arg(non_constexpr_arg_values), - ) - - return bin -""" # noqa: E501 - scope = dict( - get_backend=get_backend, - get_arch_default_num_stages=get_arch_default_num_stages, - get_arch_default_num_warps=get_arch_default_num_warps, - _specialization_key=self._specialization_key, - get_cuda_version_key=get_cuda_version_key, - jit_func=jit_func, - _key_of=JITFunction._key_of, - kernel_cache=jit_func.cache, - launch_enter_hook=CompiledKernel.launch_enter_hook, - launch_exit_hook=CompiledKernel.launch_exit_hook, - ) - exec(src, scope) - return scope[f'_{fn.__name__}_launcher'] - - def __getitem__(self, grid): - """Get item.""" - return functools.partial(cast(Callable, self.run), grid=grid) - - -class JitFunction230Wrapper: - - def __init__(self, jit_func: JITFunction, type_hint: TypeHintType = None): - """Jit func.""" - self.jit_func = jit_func - self.type_hint = _check_type_hint(jit_func, type_hint) - self.run = self._make_launcher(jit_func) - self.arg_names = jit_func.arg_names - - self.__doc__ = jit_func.__doc__ - self.__name__ = jit_func.__name__ - self.__globals__ = jit_func.__globals__ - self.__module__ = jit_func.__module__ - - @staticmethod - @functools.lru_cache - def build_cuda_options(*args, **kwargs): - from triton.compiler.backends.cuda import CUDAOptions - return CUDAOptions(*args, **kwargs) - - @staticmethod - def _specialization_key(value): - if hasattr(value, 'data_ptr'): - return (value.data_ptr() % TRITON_DIVIIBILITY == 0, ) - - if isinstance(value, int): - # bool is a subclass of int, so we don't check explicitly above. 
- return ( - value % TRITON_DIVIIBILITY == 0, - value % TRITON_DIVIIBILITY_8 == 0, - value == 1, - ) - - return (False, ) - - def _make_launcher(self, jit_func: JITFunction): - """Make input builder.""" - from dataclasses import fields - - from triton.common.backend import get_cuda_version_key - from triton.compiler import CompiledKernel - from triton.compiler.backends.cuda import CUDABackend, CUDAOptions - from triton.runtime.driver import driver - - def _make_spec_key_str(key): - anno = self.type_hint[key] - if anno == torch.Tensor: - return f'({key}.data_ptr() % {TRITON_DIVIIBILITY} == 0, )' - elif anno in [int, bool, torch.int32, torch.int64]: - return (f'({key} % {TRITON_DIVIIBILITY} == 0, ' - f'{key} % {TRITON_DIVIIBILITY_8} == 0, ' - f'{key} == 1, )') - elif anno is not None: - return '(False,)' - return f'_specialization_key({key})' - - def _make_sig_key_str(key): - anno = self.type_hint[key] - if anno == bool: - return '"i1"' - elif anno == float: - return '"fp32"' - elif anno == torch.int32: - return '"i32"' - elif anno == torch.int64: - return '"i64"' - elif anno == torch.Tensor: - return f'{key}.dtype' - return f'_key_of({key})' - - fn = jit_func.fn - params = jit_func.params - - # arg key - arg_key = tuple(p.name for p in params) - arg_key_str = ', '.join(arg_key) - grid_args = ','.join([f'{arg}={arg}' for arg in arg_key]) - args_signature = ', '.join(p.name if p.default == inspect._empty else f'{p.name} == {p.default}' - for p in params) - - # constexpr key - constexpr_key = tuple(p.name for p in params if p.is_constexpr) - constexpr_key_str = ', '.join(constexpr_key) - - # sig key - sig_key = tuple(p.name for p in params if not p.is_constexpr) - sig_name_str = ', '.join(key for key in sig_key) - sig_key_str = ', '.join(_make_sig_key_str(key) for key in sig_key) - - # spec key - spec_key = tuple(p.name for p in params if not p.do_not_specialize) - spec_key_str = ', '.join(_make_spec_key_str(key) for key in spec_key) - - # cuda opt key/default - cuda_opt_fields = dict((f.name, f.default) for f in fields(CUDAOptions)) - cuda_opt_fields['debug'] = jit_func.debug - cuda_opt_signature = ', '.join(f'{k} = {v}' for k, v in cuda_opt_fields.items()) - cuda_opt_args = ', '.join(f'{k}={k}' for k in cuda_opt_fields) - - triton_version = version.parse(triton.__version__) - if triton_version == version.parse('2.3.0'): - mni_acc_default = '0 if target[1] >= 89 else None' +class MPLockCacheManager(FileCacheManager): + """A cache manager that uses a lock to ensure thread safety.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + logger.debug(f'Create MPLockCacheManager with key={self.key}') + self._lock_map = dict() + + def _acquire_lock(self, lock_path, timeout=5): + """Acquire an exclusive lock on the file.""" + import filelock + logger.debug(f'Acquiring lock for {lock_path}') + full_lock_path = osp.join(self.cache_dir, f'{lock_path}.lock') + lock = filelock.FileLock(full_lock_path) + + lock.acquire(timeout=timeout) + self._lock_map[lock_path] = lock + + def _release_lock(self, lock_path): + """Release the lock.""" + if lock_path not in self._lock_map: + return + logger.debug(f'Release lock for {lock_path}') + lock_file = self._lock_map.pop(lock_path) + lock_file.release() + + def _group_is_ready(self, filename: str, group: dict) -> bool: + """Check if the group is ready.""" + if not isinstance(group, dict): + return False + return filename in group + + def get_group(self, filename: str) -> Dict[str, str]: + out = super().get_group(filename) + if 
self._group_is_ready(filename, out): + return out + + # lock if group is not ready + self._acquire_lock(filename) + out = super().get_group(filename) + + if self._group_is_ready(filename, out): + self._release_lock(filename) + return out + + def get_file(self, filename) -> str: + out = super().get_file(filename) + if out is not None: + return out + + # lock if file is not ready + self._acquire_lock(filename) + # try get file again if other process has put the file + out = super().get_file(filename) + + # release lock if file exists + if out is not None: + self._release_lock(filename) + return out + + def put(self, data, filename, binary=True) -> str: + out = super().put(data, filename, binary) + logger.debug(f'Put file {filename}.') + if filename.startswith('__grp__'): + # release group + self._release_lock(filename[7:]) else: - mni_acc_default = '2**30 if target[1] == 90 else 0' - - src = f""" -def _{fn.__name__}_launcher({args_signature}, grid=None, {cuda_opt_signature}, warmup=False, **kwargs): - device = get_current_device() - stream = get_current_stream(device) - target = get_current_target() - allow_fp8e4nv = target[1] >= 89 - max_num_imprecise_acc_default = {mni_acc_default} - options = build_cuda_options({cuda_opt_args}, ) - sig_key = ({sig_key_str}, ) - spec_key = ({spec_key_str}, ) - constexpr_key = ({constexpr_key_str}, ) - key = (get_cuda_version_key(), sig_key, constexpr_key, spec_key, options) - - kernel = kernel_cache[device].get(key, None) - if kernel is None: - return jit_func[grid]({arg_key_str}, {cuda_opt_args}, **kwargs) - - args = ({sig_name_str}) - if callable(grid): - grid = grid(dict({grid_args})) - grid_size = len(grid) - grid_0 = grid[0] - grid_1 = grid[1] if grid_size > 1 else 1 - grid_2 = grid[2] if grid_size > 2 else 1 - tensormaps_info = kernel.metadata["tensormaps_info"] - if len(tensormaps_info) > 0: - kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas, - *kernel.cluster_dims, - kernel.shared, stream, kernel.function, launch_enter_hook, - launch_exit_hook, kernel, - *assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args)) - else: - kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas, - *kernel.cluster_dims, - kernel.shared, stream, kernel.function, launch_enter_hook, - launch_exit_hook, kernel, - {sig_name_str}) - - return kernel -""" # noqa: E501 - scope = dict( - get_current_device=driver.get_current_device, - get_current_stream=driver.get_current_stream, - get_current_target=driver.get_current_target, - assemble_tensormap_to_arg=driver.assemble_tensormap_to_arg, - _specialization_key=self._specialization_key, - get_cuda_version_key=get_cuda_version_key, - CUDABackend=CUDABackend, - build_cuda_options=self.build_cuda_options, - jit_func=jit_func, - _key_of=JITFunction._key_of, - kernel_cache=jit_func.cache, - launch_enter_hook=CompiledKernel.launch_enter_hook, - launch_exit_hook=CompiledKernel.launch_exit_hook, - ) - exec(src, scope) - return scope[f'_{fn.__name__}_launcher'] - - def __getitem__(self, grid): - """Get item.""" - return functools.partial(cast(Callable, self.run), grid=grid) - - -@overload -def wrap_jit_func(func: JITFunction): - ... - - -@overload -def wrap_jit_func( - *, - type_hint: TypeHintType = None, -): - ... 
- - -def wrap_jit_func( - func: JITFunction = None, - *, - type_hint: TypeHintType = None, -): - """Wrap jit func.""" - - def decorator(func: JITFunction): - triton_version = version.parse(triton.__version__) - - if triton_version == version.parse('2.2.0'): - return JitFunction220Wrapper(func, type_hint) - if version.parse('2.2.0') < triton_version <= version.parse('2.3.1'): - return JitFunction230Wrapper(func, type_hint) - return func - - if func is not None: - return decorator(func) - else: - return decorator + self._release_lock(filename) + return out diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py index 4b5fb0babb..69da32d08d 100644 --- a/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py @@ -6,7 +6,6 @@ from .activation import silu_and_mul from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize -from .triton_utils import get_kernel_meta from .w8a8_triton_kernels import per_token_quant_int8 @@ -50,8 +49,6 @@ def get_cuda_autotune_config(): @triton.autotune( configs=get_cuda_autotune_config(), key=['N', 'K', 'M_NP2'], - warmup=10, - rep=25, ) @triton.jit def fused_moe_w8a8_kernel( @@ -198,7 +195,6 @@ def _grid_fn(META): C = C.flatten(0, -2) grid = _grid_fn - kernel_meta = get_kernel_meta(A) fused_moe_w8a8_kernel[grid]( A, A_scale, @@ -226,7 +222,6 @@ def _grid_fn(META): reindex_c=reindex_c, M_NP2=M_NP2, ACCUMULATOR_DTYPE=accumulator_dtype, - **kernel_meta, ) diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py index 0ddc8d1126..bdff352823 100644 --- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py @@ -6,7 +6,6 @@ from packaging import version from ..default.w8a8_kernels import per_channel_quant -from .triton_utils import get_kernel_meta TRITON_VERSION = version.parse(triton.__version__) if TRITON_VERSION >= version.parse('3.0.0'): @@ -29,8 +28,6 @@ }, num_stages=3, num_warps=8) ], key=['N', 'K'], - warmup=5, - rep=20, ) @triton.jit(do_not_specialize=['M']) def _linear( @@ -110,8 +107,6 @@ def _linear( }, num_stages=3, num_warps=8) ], key=['N', 'K'], - warmup=5, - rep=20, ) @triton.jit(do_not_specialize=['M']) def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, @@ -185,7 +180,6 @@ def matmul_kernel_dynamic_quant(a, b, rms_scale, linear_scale, residual=None, bi def grid(META): return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) - kernel_meta = get_kernel_meta(a) if residual is not None: _linear_add[grid](a, b, @@ -203,8 +197,7 @@ def grid(META): GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, - ACCUMULATOR_DTYPE=accumulator_dtype, - **kernel_meta) + ACCUMULATOR_DTYPE=accumulator_dtype) else: _linear[grid](a, b, @@ -221,8 +214,7 @@ def grid(META): GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, linear_scale_ptr=linear_scale, - ACCUMULATOR_DTYPE=accumulator_dtype, - **kernel_meta) + ACCUMULATOR_DTYPE=accumulator_dtype) if bias is not None: c += bias @@ -288,7 +280,6 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8): x = x.flatten(0, -2) assert x.stride(-1) == 1 # enqueue kernel - kernel_meta = get_kernel_meta(x) _per_token_quant_int8[(M, )](x, x_q, x_s, @@ -299,8 +290,7 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8): BLOCK=BLOCK, Q_MAX=q_max, IS_FLOATING_POINT=quant_dtype.is_floating_point, - 
num_warps=num_warps, - **kernel_meta) + num_warps=num_warps) return x_q, x_s diff --git a/lmdeploy/pytorch/kernels/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/fused_rotary_emb.py deleted file mode 100644 index 1d6fcf1b9a..0000000000 --- a/lmdeploy/pytorch/kernels/fused_rotary_emb.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .dispatcher import FunctionDispatcher - -fused_rotary_emb = FunctionDispatcher('fused_rotary_emb').make_caller() diff --git a/tests/pytorch/kernel/test_apply_rotary.py b/tests/pytorch/kernel/test_apply_rotary.py index a5058767ce..f978d7bd47 100644 --- a/tests/pytorch/kernel/test_apply_rotary.py +++ b/tests/pytorch/kernel/test_apply_rotary.py @@ -94,5 +94,8 @@ def test_apply_rotary(self, q_states, k_states, cos, sin, gt): if q_states.dtype == torch.float16: rtol = 1e-5 atol = 1e-3 + elif q_states.dtype == torch.bfloat16: + rtol = 1e-3 + atol = 1e-2 torch.testing.assert_close(q_embed, q_gt, rtol=rtol, atol=atol) torch.testing.assert_close(k_embed, k_gt, rtol=rtol, atol=atol) diff --git a/tests/pytorch/kernel/test_fused_rotary_emb.py b/tests/pytorch/kernel/test_fused_rotary_emb.py deleted file mode 100644 index 83179000de..0000000000 --- a/tests/pytorch/kernel/test_fused_rotary_emb.py +++ /dev/null @@ -1,114 +0,0 @@ -import pytest -import torch -from torch import nn - -from lmdeploy.pytorch.kernels.fused_rotary_emb import fused_rotary_emb - - -class DummyRotaryEmbedding(nn.Module): - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - def forward(self, x, position_ids, seq_len=None): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=x.dtype) - sin = emb.sin().to(dtype=x.dtype) - # backwards compatibility - return cos, sin - - -class DummyLinearScalingRotaryEmbedding(DummyRotaryEmbedding): - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def forward(self, x, position_ids, seq_len=None): - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids, seq_len) - return cos, sin - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=2): - """Applies Rotary Position Embedding to the query and key tensors.""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class TestFusedRotaryEmb: - - @pytest.fixture - def dtype(self): - yield torch.float16 - - @pytest.fixture - def batch_size(self): - yield 2 - - @pytest.fixture - def head_dim(self): - yield 64 - - @pytest.fixture - def q_num_heads(self): - yield 4 - - @pytest.fixture - def 
k_num_heads(self): - yield 2 - - @pytest.fixture - def seq_len(self): - yield 100 - - @pytest.fixture - def q(self, batch_size, seq_len, q_num_heads, head_dim, dtype): - yield torch.rand(batch_size, seq_len, q_num_heads, head_dim, dtype=dtype).to('cuda') - - @pytest.fixture - def k(self, batch_size, seq_len, k_num_heads, head_dim, dtype): - yield torch.rand(batch_size, seq_len, k_num_heads, head_dim, dtype=dtype).to('cuda') - - @pytest.fixture - def position_ids(self, batch_size, seq_len): - yield torch.randint(0, seq_len + 100, (batch_size, seq_len)).cuda() - - @pytest.fixture - def rotary_emb(self, head_dim): - yield DummyLinearScalingRotaryEmbedding(head_dim, scaling_factor=1.0).to('cuda') - - @pytest.fixture - def gt(self, q, k, position_ids, rotary_emb): - with torch.inference_mode(): - cos, sin = rotary_emb(q, position_ids) - yield apply_rotary_pos_emb(q, k, cos, sin, position_ids=position_ids) - - def test_fused_rotary_emb(self, q, k, position_ids, rotary_emb, gt): - inv_freq = rotary_emb.inv_freq - scaling_factor = rotary_emb.scaling_factor - - with torch.inference_mode(): - outq, outk = fused_rotary_emb(q, k, position_ids, inv_freq, scaling_factor=scaling_factor) - - gtq, gtk = gt - torch.testing.assert_close(outq, gtq, atol=1e-3, rtol=1e-5) - torch.testing.assert_close(outk, gtk, atol=1e-3, rtol=1e-5)
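
How the new flag is meant to be used, as a minimal sketch: the environment variable and the `TRITON_CACHE_MANAGER` target string come from this patch, while the model id below is hypothetical and `PytorchEngineConfig` is only there to make sure the PyTorch engine (where `build_graph_runner` lives) is the one selected. Because `envs.py` reads the variable once at module import time, it has to be visible before `lmdeploy.pytorch.envs` is imported.

```python
# Sketch of opting in to the custom Triton cache manager (assumptions noted above).
import os

# Set before lmdeploy.pytorch.envs is imported: envs.py reads this once at
# import time into envs.triton_custom_cache_mgr_enable.
os.environ['LMDEPLOY_TRITON_CUSTOM_CACHE_MGR_ENABLE'] = '1'

from lmdeploy import PytorchEngineConfig, pipeline  # noqa: E402

# With the flag on, build_graph_runner() exports
#   TRITON_CACHE_MANAGER=lmdeploy.pytorch.kernels.cuda.triton_utils:MPLockCacheManager
# so Triton instantiates MPLockCacheManager (module:ClassName form) instead of
# its default FileCacheManager when it reads and writes compiled kernels.
pipe = pipeline('internlm/internlm2_5-7b-chat',  # hypothetical model id
                backend_config=PytorchEngineConfig())
print(pipe('Hello'))
```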
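
The coordination that `MPLockCacheManager` adds boils down to: on a cache miss, take a per-file `filelock`, re-check, build, then release once the artifact is in place, so concurrent processes compile each kernel at most once. The real manager spreads the acquire across `get_file`/`get_group` and the release across `put` (the lock is held while Triton compiles); the self-contained sketch below collapses that into one helper, and its paths and build step are invented purely for illustration.

```python
# Standalone sketch of the check / lock / re-check / publish pattern.
import os
import os.path as osp
import tempfile

import filelock


def get_or_build(cache_dir: str, filename: str, build) -> str:
    """Return the cached file path, building it at most once across processes."""
    path = osp.join(cache_dir, filename)
    if osp.exists(path):          # fast path: another process already built it
        return path

    lock = filelock.FileLock(osp.join(cache_dir, f'{filename}.lock'))
    with lock.acquire(timeout=5):
        if not osp.exists(path):  # re-check after winning the lock
            data = build()
            # write-then-rename so readers never observe a partial file
            fd, tmp = tempfile.mkstemp(dir=cache_dir)
            with os.fdopen(fd, 'wb') as f:
                f.write(data)
            os.replace(tmp, path)
    return path


if __name__ == '__main__':
    cache_dir = tempfile.mkdtemp()
    out = get_or_build(cache_dir, 'kernel.cubin', lambda: b'compiled kernel bytes')
    print(out)
```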
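
On the removal of `fused_rotary_emb`: the Triton kernel, its dispatcher, and its test are deleted, while `apply_rotary_pos_emb` stays exported (and its test gains bfloat16 tolerances). For reference, the eager math the deleted test checked against is sketched below; it only illustrates what the fused kernel computed and is not a proposed replacement API.

```python
# Eager reference for the removed fused rotary embedding, mirroring the
# rotate_half / cos-sin construction used in the deleted test.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotary_ref(q, k, position_ids, inv_freq, scaling_factor=1.0, unsqueeze_dim=2):
    """Build cos/sin from position ids and inv_freq, then rotate q and k."""
    pos = position_ids.float() / scaling_factor            # (batch, seq)
    freqs = pos[..., None] * inv_freq.float()              # (batch, seq, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                # (batch, seq, dim)
    cos = emb.cos().to(q.dtype).unsqueeze(unsqueeze_dim)   # broadcast over heads
    sin = emb.sin().to(q.dtype).unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


if __name__ == '__main__':
    b, s, h, d = 2, 16, 4, 64
    q = torch.rand(b, s, h, d)
    k = torch.rand(b, s, 2, d)
    position_ids = torch.arange(s).repeat(b, 1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    q_embed, k_embed = rotary_ref(q, k, position_ids, inv_freq)
    print(q_embed.shape, k_embed.shape)
```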