From 7e61a00ece1e8de897d298b956bff6e91c497cf8 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Wed, 6 Nov 2024 17:06:13 +0800 Subject: [PATCH 01/21] Support medusa inference --- lmdeploy/api.py | 3 + lmdeploy/pytorch/config.py | 23 ++- lmdeploy/pytorch/configurations/medusa.py | 42 +++++ lmdeploy/pytorch/engine/engine.py | 79 ++++++++- lmdeploy/pytorch/engine/logits_process.py | 8 +- lmdeploy/pytorch/engine/model_agent.py | 83 ++++++++-- .../pytorch/engine/speculative_decoding.py | 37 +++++ lmdeploy/pytorch/model_inputs.py | 5 + lmdeploy/pytorch/models/medusa.py | 152 ++++++++++++++++++ lmdeploy/pytorch/models/module_map.py | 5 + lmdeploy/pytorch/nn/rejection_sampling.py | 138 ++++++++++++++++ .../weight_loader/model_weight_loader.py | 7 + lmdeploy/serve/async_engine.py | 10 +- 13 files changed, 570 insertions(+), 22 deletions(-) create mode 100644 lmdeploy/pytorch/configurations/medusa.py create mode 100644 lmdeploy/pytorch/engine/speculative_decoding.py create mode 100644 lmdeploy/pytorch/models/medusa.py create mode 100644 lmdeploy/pytorch/nn/rejection_sampling.py diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 939065622a..e2db4f0b31 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -11,6 +11,7 @@ def pipeline(model_path: str, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, + speculative_model: str = None, log_level: str = 'ERROR', max_log_len: int = None, **kwargs): @@ -33,6 +34,7 @@ def pipeline(model_path: str, config instance. Default to None. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. + speculative_model (str): the path of the draft model. log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] max_log_len(int): Max number of prompt characters or prompt tokens @@ -82,6 +84,7 @@ def pipeline(model_path: str, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, + speculative_model=speculative_model, max_log_len=max_log_len, **kwargs) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index c350f4b4cf..a9381890ee 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -108,6 +108,9 @@ class ModelConfig: hf_config: Any = None cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None + # medusa config + medusa_num_heads: int = None + medusa_num_layers: int = None def get_head_size(self): """get head size.""" @@ -129,12 +132,22 @@ def from_pretrained(cls, activations. Refer to `PyTorchEngineConfig` for details """ from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - if getattr(hf_config, 'model_type', None) in ['phi3']: - # phi3 + trust_remote_code leads to error when tp. + try: hf_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path) + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) + if getattr(hf_config, 'model_type', None) in ['phi3']: + # phi3 + trust_remote_code leads to error when tp. 
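# A minimal usage sketch for the `speculative_model` argument that
# lmdeploy/api.py gains above. The model paths are placeholders (the Medusa
# head checkpoint must match its base model), and in this series the draft
# model is only honored by the PyTorch backend.
from lmdeploy import pipeline, PytorchEngineConfig

pipe = pipeline('lmsys/vicuna-7b-v1.3',
                backend_config=PytorchEngineConfig(),
                speculative_model='FasterDecoding/medusa-vicuna-7b-v1.3')
print(pipe(['Hi, please introduce yourself.']))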
+ hf_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path) + except Exception as e: # noqa + from transformers import PretrainedConfig + hf_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) + # medusa model config + if getattr(hf_config, 'medusa_num_heads', None) is not None: + setattr(hf_config, 'architectures', ['MedusaModel']) return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype) diff --git a/lmdeploy/pytorch/configurations/medusa.py b/lmdeploy/pytorch/configurations/medusa.py new file mode 100644 index 0000000000..4935bc0e25 --- /dev/null +++ b/lmdeploy/pytorch/configurations/medusa.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.config import ModelConfig + +from .builder import AutoModelConfigBuilder + + +class MedusaModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.architectures[0] == 'MedusaModel' + + @classmethod + def build(cls, hf_config, model_path: str = None): + """build.""" + from transformers import AutoConfig + base_config = AutoConfig.from_pretrained( + hf_config.base_model_name_or_path) + head_dim = base_config.hidden_size // base_config.num_attention_heads + # config is wrong + # https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3/blob/main/config.json#L3 + hf_config.medusa_num_heads = 5 + medusa_num_heads = hf_config.medusa_num_heads + medusa_num_layers = hf_config.medusa_num_layers + if getattr(hf_config, 'hidden_size', None) is None: + setattr(hf_config, 'hidden_size', base_config.hidden_size) + if getattr(hf_config, 'vocab_size', None) is None: + setattr(hf_config, 'vocab_size', base_config.vocab_size) + return ModelConfig( + hidden_size=base_config.hidden_size, + num_attention_heads=base_config.num_attention_heads, + num_layers=base_config.num_hidden_layers, + num_key_value_heads=base_config.num_key_value_heads, + bos_token_id=base_config.bos_token_id, + eos_token_id=base_config.eos_token_id, + head_dim=head_dim, + vocab_size=base_config.vocab_size, + hf_config=hf_config, + medusa_num_heads=medusa_num_heads, + medusa_num_layers=medusa_num_layers, + ) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index cffe13bbdb..9867e68a7a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -3,7 +3,7 @@ import copy import os from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -84,12 +84,14 @@ class Engine: Args: model_path (str): The hugging face model path. engine_config (PytorchEngineConfig): The config of the Engine. + speculative_model (str): The path of the speculative model. trust_remote_code (bool): Trust remote code. 
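# The MedusaModel detection above keys off a few fields in the draft
# checkpoint's config.json. An abbreviated, illustrative example of the fields
# this series actually reads (values assumed from the vicuna-7b Medusa
# checkpoint referenced in configurations/medusa.py, not verified here):
medusa_config = {
    'base_model_name_or_path': 'lmsys/vicuna-7b-v1.3',
    'medusa_num_heads': 5,   # the builder overrides this to 5 regardless
    'medusa_num_layers': 1,  # number of ResBlocks stacked in front of each head
}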
""" def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, + speculative_model: Optional[str] = None, trust_remote_code: bool = True) -> None: if engine_config is None: engine_config = PytorchEngineConfig() @@ -150,6 +152,7 @@ def __init__(self, model_path, cache_config=cache_config, backend_config=backend_config, + speculative_model=speculative_model, trust_remote_code=trust_remote_code, adapters=adapters, tp=self.tp, @@ -176,6 +179,7 @@ def __init__(self, def from_pretrained(cls, pretrained_model_name_or_path: str, engine_config: PytorchEngineConfig = None, + speculative_model: Optional[str] = None, trust_remote_code: bool = True, **kwargs): """lmdeploy python inference engine. @@ -192,12 +196,14 @@ def from_pretrained(cls, "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. engine_config (PytorchEngineConfig): Pytorch engine config. + speculative_model (str): The path of the speculative model. trust_remote_code (bool): Trust remote code """ if len(kwargs) > 0: logger.debug(f'Get unexpected kwargs: {kwargs}') return cls(model_path=pretrained_model_name_or_path, engine_config=engine_config, + speculative_model=speculative_model, trust_remote_code=trust_remote_code) @property @@ -519,7 +525,10 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor, # one more step to cache last token(stop word) stopped = num_appendable_ids < 0 if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) + if len(stop_words.shape) == 1: + token_ids = token_ids[:, None] + # TODO speculative model supports multiple stop word + sw_stopped = (token_ids == stop_words).any(1) one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) @@ -550,6 +559,17 @@ def __get_last_logits(): return next_token_ids + def extract_tokens(self, token_ids, eos_token_ids): + """Token list containing eos.""" + if not isinstance(token_ids, list): + return [token_ids], token_ids in eos_token_ids + for i, token_id in enumerate(token_ids): + if token_id in eos_token_ids: + return token_ids[:i + 1], True + if token_id == -1: + break + return token_ids[:i], False + @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped: torch.Tensor): @@ -559,12 +579,12 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, for token, msg, stop in zip(next_token_ids, running, stopped): if msg.status != MessageStatus.RUNNING: continue - update_token = token - stop = stop or token in eos_token_id + update_token, eos_stop = self.extract_tokens(token, eos_token_id) + stop = stop or eos_stop if stop: update_token = _EMPTY_TOKEN else: - msg.num_new_tokens += 1 + msg.num_new_tokens += len(update_token) msg.update_token_ids(update_token) if stop: msg.status = MessageStatus.STOPPED @@ -647,14 +667,24 @@ async def __long_context_single_forward(inputs): if not return_logits and not inputs.is_decoding: last_token_loc = inputs.seq_length.cumsum(0) - 1 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] + if 'spec_hidden_states' in ret: + ret['spec_hidden_states'] = ret[ + 'spec_hidden_states'][:, last_token_loc] else: ret = await __long_context_single_forward(inputs) if not return_logits and not inputs.is_decoding: last_token_loc = [-1] ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] + if 'spec_hidden_states' in ret: + ret['spec_hidden_states'] = ret[ + 'spec_hidden_states'][:, last_token_loc] hidden_states = 
ret.pop('hidden_states') + spec_hidden_states = ret.pop('spec_hidden_states', None) logits = self.model_agent.get_logits(hidden_states) + if spec_hidden_states is not None: + spec_logits = self.model_agent.get_spec_logits(spec_hidden_states) + ret['spec_logits'] = spec_logits ret['logits'] = logits return ret @@ -669,7 +699,11 @@ def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, return [] if token in msg.sampling_param.stop_words: return [] - return [token] + if isinstance(token, list): + token = [t for t in token if t != -1] + else: + token = [token] + return token def __get_q_start_loc(): inputs = self._inputs @@ -765,6 +799,37 @@ def __update_inputs(next_token_ids): logits, all_ids, guided_input_ids, sampling_inputs, inputs, num_ignore_eos > 0) num_ignore_eos = num_ignore_eos - 1 + if 'spec_logits' in output: + spec_logits = output['spec_logits'] + proposal_token_ids = self.async_sampling_logits( + spec_logits, all_ids, guided_input_ids, sampling_inputs, + inputs, num_ignore_eos > 0) + # score the proposals with the target model + spec_inputs = copy.deepcopy(inputs) + _, num_speculative_tokens = proposal_token_ids.shape + target_proposal_ids = torch.cat( + [next_token_ids.unsqueeze(-1), proposal_token_ids], -1) + spec_inputs.input_ids = target_proposal_ids.flatten( + ).unsqueeze(0) + spec_inputs.history_lengths += spec_inputs.seq_length + spec_inputs.seq_length = torch.ones_like( + spec_inputs.seq_length) * (num_speculative_tokens + 1) + spec_inputs.is_decoding = False + score_output = await self.model_agent.score_proposal( + spec_inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, + num_speculative_tokens=num_speculative_tokens) + from ..nn.rejection_sampling import RejectionSampler + rejection_sampler = RejectionSampler() + score_output = score_output.softmax(-1) + spec_logits = spec_logits.softmax(-1) + target_output = rejection_sampler.forward( + score_output, + draft_probs=spec_logits, + draft_token_ids=proposal_token_ids) + next_token_ids = torch.cat( + [next_token_ids[:, None], target_output], -1) # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( @@ -928,6 +993,8 @@ async def __step(): schedule_output = self.scheduler.schedule( is_prefill=prefill, prealloc_size=prefill_interval) + if self.model_agent.speculative_model is not None: + prefill = True in_que.put_nowait((prefill, schedule_output)) finish = False while not finish: diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 2ee2eaced2..9bf7589756 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -15,7 +15,11 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor): """process temperature.""" temperature = temperature.to(scores.dtype) - scores.div_(temperature[:, None]) + if len(scores.shape) == 3: + temperature = temperature[:, None, None] + else: # len==2 + temperature = temperature[:, None] + scores.div_(temperature) return scores @@ -23,6 +27,8 @@ def _process_bad_words_(scores: torch.Tensor, bad_words: torch.LongTensor, filter_value: float = -float('inf')): """process bad words.""" + if len(scores.shape) == 3: + bad_words = bad_words[:, :, None] mask = bad_words >= 0 bad_words = bad_words.where(mask, 0) filtered_scores = scores.gather(1, bad_words) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 918e64e782..b74dd425a0 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ 
b/lmdeploy/pytorch/engine/model_agent.py @@ -212,6 +212,8 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig, + speculative_model: str = None, + speculative_model_config: ModelConfig = None, adapters: Dict[str, str] = None, trust_remote_code: bool = True): super().__init__(model_config=model_config, cache_config=cache_config) @@ -220,6 +222,7 @@ def __init__(self, self._adapters = adapters self.patched_model = self._build_model(model_path, + self.model_config, adapters, device=device) @@ -232,6 +235,13 @@ def __init__(self, cache_config=cache_config, backend_config=backend_config, device=device) + self.speculative_model = None + if speculative_model is not None: + self.speculative_model_config = speculative_model_config + self.speculative_model = self._build_model( + speculative_model, + self.speculative_model_config, + device=device) self.cache_engine = CacheEngine(cache_config, model_config) @@ -239,21 +249,22 @@ def __init__(self, def _build_model(self, model_path: str, + model_config: ModelConfig, adapters: Dict[str, str] = None, device: torch.device = 'cuda'): """build patched model.""" - custom_module_map = self.model_config.custom_module_map + custom_module_map = model_config.custom_module_map if custom_module_map is not None: update_custom_module_map(custom_module_map) logger.info('build model.') - patched_model = build_patched_model(self.model_config, device=device) + patched_model = build_patched_model(model_config, device=device) logger.info('loading weights.') load_model_weights(patched_model, model_path, device=device) logger.info('loading adapters.') if adapters is not None: add_adapters(patched_model, adapters, - dtype=self.model_config.dtype, + dtype=model_config.dtype, device=device) return patched_model @@ -274,6 +285,16 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, world_size=1, stream=self.stream, ) + if self.speculative_model is not None: + inputs.last_hidden_states = output['hidden_states'] + spec_outputs = model_forward( + self.speculative_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + output['spec_hidden_states'] = spec_outputs['hidden_states'] return output def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, @@ -307,10 +328,42 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, self.stream.synchronize) return output + async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, num_speculative_tokens): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + num_speculative_tokens (int): The number of the proposal tokens. 
+ """ + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + await asyncio.get_event_loop().run_in_executor(None, + self.stream.synchronize) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape( + [-1, num_speculative_tokens + 1, hidden_states.shape[-1]]) + logits = self.get_logits(hidden_states) + return logits + def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): + """get logits of model output.""" + return self.speculative_model.get_logits(hidden_states_list) + @torch.inference_mode() def _tp_build_model( @@ -753,6 +806,7 @@ def build_model_agent(model_path: str, cache_config: CacheConfig, backend_config: BackendConfig, trust_remote_code: bool, + speculative_model: str = None, adapters: Dict[str, str] = None, tp: int = 1, dtype: str = 'auto', @@ -764,6 +818,7 @@ def build_model_agent(model_path: str, cache_config (CacheConfig): config of kv cache backend_config (BackendConfig): config of backend devices trust_remote_code (bool): To use the remote modeling code or not + speculative_model (str): The path of the speculative model adapters (Dict): lora adapters tp (int): the number of devices to be used in tensor parallelism dtype (str): the data type of model weights and activations @@ -772,18 +827,28 @@ def build_model_agent(model_path: str, model_config = ModelConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code, dtype=dtype) model_config.custom_module_map = custom_module_map + speculative_model_config = None + if speculative_model is not None: + speculative_model_config = ModelConfig.from_pretrained( + speculative_model, + trust_remote_code=trust_remote_code, + dtype=dtype) if tp == 1: - model_agent = BaseModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - adapters=adapters, - trust_remote_code=trust_remote_code) + model_agent = BaseModelAgent( + model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + speculative_model=speculative_model, + speculative_model_config=speculative_model_config, + adapters=adapters, + trust_remote_code=trust_remote_code) else: model_agent = TPModelAgent(model_path, model_config=model_config, cache_config=cache_config, backend_config=backend_config, + speculative_model=speculative_model, world_size=tp, adapters=adapters, trust_remote_code=trust_remote_code) diff --git a/lmdeploy/pytorch/engine/speculative_decoding.py b/lmdeploy/pytorch/engine/speculative_decoding.py new file mode 100644 index 0000000000..97a9547b5a --- /dev/null +++ b/lmdeploy/pytorch/engine/speculative_decoding.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os + +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.utils import get_logger, get_model + +from ..devices import get_device_manager +from .engine import Engine +from .model_agent import build_model_agent + +logger = get_logger('lmdeploy') + + +class SpeculativeDecodingEngine(Engine): + + def __init__(self, + model_path: str, + speculative_model: str = None, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True) -> None: + super().__init__(model_path, engine_config, trust_remote_code) + + if not os.path.exists(speculative_model): + speculative_model = get_model(speculative_model, + engine_config.download_dir, + engine_config.revision) + self.speculative_model = speculative_model + + with get_device_manager().context(self.device_context): + self.speculative_model_agent = build_model_agent( + speculative_model, + cache_config=self.cache_config, + backend_config=self.backend_config, + trust_remote_code=trust_remote_code, + tp=self.tp, + dtype=engine_config.dtype, + custom_module_map=engine_config.custom_module_map) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 669625d43d..5fa21fa6d6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -122,6 +122,7 @@ class ModelInputs: mrope_inputs: MRopeModelInputs = None cross_attention_states: torch.Tensor = None history_cross_kv_seqlens: torch.LongTensor = None + last_hidden_states: torch.Tensor = None def update(self, input_ids: torch.LongTensor): """update input ids.""" @@ -217,6 +218,7 @@ class StepContext: cross_attention_states: torch.Tensor = None cross_kv_seqlens: torch.LongTensor = None kv_quant_policy: Literal[0, 4, 8] = 0 + last_hidden_states: torch.Tensor = None _outputs: Dict = field(default_factory=dict) @@ -251,6 +253,8 @@ def new( mrope_position_ids = inputs.mrope_inputs.get_inputs( history_seqlens, q_seqlens) + # for speculative decoding + last_hidden_states = inputs.last_hidden_states # kv_seqlens cross_attention_states = inputs.cross_attention_states if inputs.is_decoding: @@ -288,6 +292,7 @@ def new( vision_inputs=inputs.vision_inputs, mrope_position_ids=mrope_position_ids, cross_attention_states=cross_attention_states, + last_hidden_states=last_hidden_states, cross_kv_seqlens=inputs.history_cross_kv_seqlens, kv_quant_policy=kv_quant_policy, ) diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py new file mode 100644 index 0000000000..8fef4fc392 --- /dev/null +++ b/lmdeploy/pytorch/models/medusa.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.models.llama import LlamaConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class ResBlock(nn.Module): + """A Residual Block module. + + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual + connection. + + Args: + hidden_size (int): The size of the hidden layers in the block. 
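# For reference, each entry of `medusa_head` built by the MedusaModel later in
# this file is structurally equivalent to the plain-PyTorch stack below (one
# head per extra lookahead token; the sizes are illustrative, e.g. vicuna-7b):
import torch
from torch import nn

hidden_size, vocab_size, medusa_num_layers = 4096, 32000, 1

class PlainResBlock(nn.Module):
    """x + SiLU(Linear(x)), the residual block used by Medusa heads."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

one_head = nn.Sequential(
    *[PlainResBlock(hidden_size) for _ in range(medusa_num_layers)],
    nn.Linear(hidden_size, vocab_size, bias=False),  # projection to draft logits
)
head_logits = one_head(torch.randn(1, hidden_size))  # -> (1, vocab_size)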
+ """ + + def __init__(self, + hidden_size, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + dtype=dtype, + device=device) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Llama model + self.act = nn.SiLU() + + def forward(self, x): + """Forward pass of the ResBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output after the residual connection and activation. + """ + return x + self.act(self.linear(x)) + + +class MedusaModel(nn.Module, CudaGraphMixin): + """The medusa model architecture.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: LlamaConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build medusa + self.medusa_head = nn.ModuleList([ + nn.Sequential( + *([ + ResBlock( + self.config.hidden_size, device=device, dtype=dtype) + ] * self.config.medusa_num_layers), + build_rowwise_linear(self.config.hidden_size, + self.config.vocab_size, + bias=False, + dtype=dtype, + device=device), + ) for _ in range(self.config.medusa_num_heads) + ]) + + def forward(self, last_hidden_states: torch.Tensor, + **kwargs) -> List[torch.Tensor]: + outputs = [head[0](last_hidden_states) for head in self.medusa_head] + outputs = torch.cat(outputs, 0) + return outputs + + def get_logits(self, hidden_states: List[torch.Tensor]): + """compute logits of the model output.""" + outputs = [] + for lm_head, hidden_state in zip(self.medusa_head, hidden_states): + outputs.append(lm_head(hidden_state)) + outputs = torch.stack(outputs, 1) + return outputs + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + **kwargs, + ): + """support cudagraph.""" + seq_lens = input_ids.size(1) + if seq_lens <= 512: + return True + + # prevent oom on llama-3 70b + if self.config.num_hidden_layers >= 40: + return False + + return False + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + **kwargs, + ): + """prepare input.""" + return dict(last_hidden_states=context.last_hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + name = 'medusa_head.' 
+ name + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index bc6385d8b2..4da26627ef 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -173,4 +173,9 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mllama.MllamaForConditionalGeneration', }) +# medusa +MODULE_MAP.update({ + 'MedusaModel': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.medusa.MedusaModel', +}) CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/nn/rejection_sampling.py b/lmdeploy/pytorch/nn/rejection_sampling.py new file mode 100644 index 0000000000..34e076b1ca --- /dev/null +++ b/lmdeploy/pytorch/nn/rejection_sampling.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + + +class RejectionSampler(nn.Module): + """apply rejection sampling according to "Accelerating Large Language Model + Decoding with Speculative Sampling". + + https://arxiv.org/pdf/2302.01318 + """ + + def __init__(self): + super().__init__() + + @staticmethod + def create_uniform_random(batch_size: int, num_speculative_tokens: int, + device: torch.device, dtype: torch.dtype): + """Generate a batch of random uniform samples. + + Args: + batch_size (int): The batch size. + num_speculative_tokens (int): The number of speculative tokens. + device (torch.device): The device of the output tensor. + dtype (torch.dtype): The dtype of the output tensor. + + Returns: + random_uniform (Tensor): The uniform tensor, shape + (batch_size, num_speculative_tokens). + """ + return torch.rand(batch_size, + num_speculative_tokens, + device=device, + dtype=dtype) + + @staticmethod + def adjusted_distribution(target_probs_without_bonus: torch.Tensor, + draft_probs: torch.Tensor): + """Adjust the distribution from the draft_probs if needed. + + Args: + target_probs_without_bonus (Tensor): The probability distribution + over token ids from the target model, shape + (batch_size, num_speculative_tokens, vocab_size). + + draft_probs (Tensor): The probability distribution over token ids + from the draft model, shape (batch_size, + num_speculative_tokens, vocab_size). + Returns: + adjusted_probs (Tensor): The adjusted probability distribution, + shape (batch_size, num_speculative_tokens, vocab_size). + """ + adjusted_probs = (target_probs_without_bonus - + draft_probs).clamp_min_(0) + adjusted_probs = adjusted_probs / adjusted_probs.sum( + -1, keepdim=True).clamp_min_(1e-5) # clamp to avoid div zero + return adjusted_probs + + # TODO fuse into a triton kernel + # TODO add seed + def forward(self, target_probs: torch.Tensor, draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor) -> torch.Tensor: + """Reject sampling probs and return token_ids. + + Args: + target_probs (Tensor): The probability distribution + over token ids from the target model, shape + (batch_size, num_speculative_tokens + 1, vocab_size). + + draft_probs (Tensor): The probability distribution over token ids + from the draft model, shape (batch_size, + num_speculative_tokens, vocab_size). + + draft_token_ids (Tensor): The proposal token id from the draft + model, shape (batch_size, num_speculative_tokens). 
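# The acceptance rule this forward() implements, restated on a single toy
# position (p = target probability, q = draft probability over a 4-token
# vocabulary; the numbers are made up for illustration):
import torch

p = torch.tensor([0.50, 0.20, 0.20, 0.10])  # target model distribution
q = torch.tensor([0.10, 0.60, 0.20, 0.10])  # draft model distribution
draft_token = 1                             # token proposed by the draft model
accept_prob = torch.clamp_max(p[draft_token] / q[draft_token], 1.0)  # min(1, p/q)
if torch.rand(()) < accept_prob:
    token = draft_token                     # keep the draft token
else:
    residual = (p - q).clamp_min(0)         # adjusted distribution
    residual = residual / residual.sum()
    token = torch.multinomial(residual, 1).item()
# In the batched implementation below, every position after the first rejection
# stays at -1, and the engine later strips tokens from the first -1 onwards.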
+ + Returns: + output_token_ids (Tensor): Token ids sampled through rejection + sampling. shape (batch)size, num_speculative_tokens + 1). + """ + target_probs_without_bonus = target_probs[:, :-1] + batch_size, num_speculative_tokens, _ = draft_probs.shape + device = draft_probs.device + batch_indices = torch.arange(batch_size, device=device) + probs_indicies = torch.arange(num_speculative_tokens, device=device) + draft_token_probs = draft_probs[batch_indices[:, None], probs_indicies, + draft_token_ids] + target_token_probs = target_probs_without_bonus[batch_indices[:, None], + probs_indicies, + draft_token_ids] + # target model scores draft token ids + scores = target_token_probs / draft_token_probs + random_uniform = self.create_uniform_random(batch_size, + num_speculative_tokens, + device=device, + dtype=scores.dtype) + rejected = scores < random_uniform + rejected_mask = rejected.cumsum(-1) > 0 + accepted_mask = ~rejected_mask + rejected_mask = torch.cat( + [rejected_mask, + rejected_mask.new_ones(batch_size, 1)], -1) + reject_idx = rejected_mask.float().argmax(-1, False) + # compute adjusted token ids + adjusted_probs = self.adjusted_distribution(target_probs_without_bonus, + draft_probs) + adjusted_probs = torch.cat([adjusted_probs, target_probs[:, -1:]], 1) + adjusted_probs = adjusted_probs[batch_indices, reject_idx] + adjusted_token_ids = torch.multinomial(adjusted_probs, + num_samples=1, + replacement=True).squeeze(-1) + output_token_ids = draft_token_ids.new_full( + (batch_size, num_speculative_tokens + 1), -1) + output_token_ids[~rejected_mask] = draft_token_ids[accepted_mask] + output_token_ids[batch_indices, reject_idx] = adjusted_token_ids + return output_token_ids + + +def test_rejection_sampler(): + batch_size = 4 + num_speculative_tokens = 5 + vocab_size = 1024 + dtype = torch.float32 + device = torch.device('cuda') + target_logits_with_bonus = torch.rand( + (batch_size, num_speculative_tokens + 1, vocab_size), + dtype=dtype, + device=device) + draft_logits = torch.rand((batch_size, num_speculative_tokens, vocab_size), + dtype=dtype, + device=device) + draft_token_ids = torch.randint(0, + vocab_size, + (batch_size, num_speculative_tokens), + device=device) + rejection_sampler = RejectionSampler() + rejection_sampler.forward(target_logits_with_bonus, draft_logits, + draft_token_ids) diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index cb548614c7..26169432db 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -12,6 +12,8 @@ logger = get_logger('lmdeploy') +MEDUSA_WEIGHT_NAME = 'medusa_lm_head.pt' + def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs): @@ -55,6 +57,9 @@ def _get_weight_type(model_path: str, use_safetensors: bool = None): # Load from a sharded PyTorch checkpoint weight_type = 'pytorch' is_sharded = True + elif osp.isfile(osp.join(model_path, MEDUSA_WEIGHT_NAME)): + # Load from a medusa head + weight_type = 'medusa' else: raise RuntimeError('Unknown weight type.') @@ -83,6 +88,8 @@ def _get_weight_path(model_path: str, weight_type: str): weight_name = SAFE_WEIGHTS_NAME elif weight_type == 'pytorch': weight_name = WEIGHTS_NAME + elif weight_type == 'medusa': + weight_name = MEDUSA_WEIGHT_NAME else: raise RuntimeError('Unknown weight type.') diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 598977747c..140205b67f 100644 --- 
a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -126,6 +126,8 @@ class AsyncEngine(LogitsMixin): config instance. Default to none. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. + speculative_model (str): The path of the draft model. Only can be used + with pytorch backend for speculative decoding. max_log_len (int): Max number of prompt characters or prompt tokens being printed in log. Default: Unlimited """ @@ -137,6 +139,7 @@ def __init__(self, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, + speculative_model: Optional[str] = None, max_log_len: int = None, **kwargs) -> None: logger.info( @@ -155,12 +158,15 @@ def __init__(self, # build backend engine if backend == 'turbomind': + assert speculative_model is None, 'plese use '\ + '--backend pytorch to use speculative decoding' self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs) elif backend == 'pytorch': self._build_pytorch(model_path=model_path, backend_config=backend_config, + speculative_model=speculative_model, **kwargs) else: raise ValueError(f'unsupported backend {backend}') @@ -204,11 +210,13 @@ def _build_pytorch( model_path: str, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, + speculative_model: Optional[str] = None, **kwargs): """Innter build method for pytorch backend.""" from lmdeploy.pytorch.engine import Engine self.engine = Engine(model_path=model_path, - engine_config=backend_config) + engine_config=backend_config, + speculative_model=speculative_model) self.backend_config = self.engine.engine_config self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None) From 7f46c3c24d58eaab6d198c4884a3f6858f146e13 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Wed, 6 Nov 2024 17:26:20 +0800 Subject: [PATCH 02/21] fix --- lmdeploy/pytorch/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9867e68a7a..badfb88509 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -561,7 +561,7 @@ def __get_last_logits(): def extract_tokens(self, token_ids, eos_token_ids): """Token list containing eos.""" - if not isinstance(token_ids, list): + if not isinstance(token_ids, np.ndarray): return [token_ids], token_ids in eos_token_ids for i, token_id in enumerate(token_ids): if token_id in eos_token_ids: From 45d250367188df4598dba120d34d9addeb91cf43 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 7 Nov 2024 11:44:29 +0800 Subject: [PATCH 03/21] fix medusa inference and vicuna template --- lmdeploy/model.py | 2 +- lmdeploy/pytorch/models/medusa.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/lmdeploy/model.py b/lmdeploy/model.py index f251ca18d2..c2cd614f44 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -296,7 +296,7 @@ def __init__( eosys=' ', user='USER: ', eoh=' ', - assistant='ASSISTANT: ', + assistant='ASSISTANT:', eoa='', stop_words=[''], **kwargs): diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index 8fef4fc392..931bb4e8b9 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -33,8 +33,6 @@ def __init__(self, bias=True, dtype=dtype, device=device) - # Initialize as an identity mapping - torch.nn.init.zeros_(self.linear.weight) # Use SiLU activation to 
keep consistent with the Llama model self.act = nn.SiLU() @@ -97,8 +95,8 @@ def forward(self, last_hidden_states: torch.Tensor, def get_logits(self, hidden_states: List[torch.Tensor]): """compute logits of the model output.""" outputs = [] - for lm_head, hidden_state in zip(self.medusa_head, hidden_states): - outputs.append(lm_head(hidden_state)) + for medusa_head, hidden_state in zip(self.medusa_head, hidden_states): + outputs.append(medusa_head[-1](hidden_state)) outputs = torch.stack(outputs, 1) return outputs From 1459af82015591147a5e0bfd2f1ce0ced3892620 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 8 Nov 2024 14:47:10 +0800 Subject: [PATCH 04/21] fix finish --- lmdeploy/pytorch/engine/engine.py | 26 ++++++++++++++++------- lmdeploy/pytorch/nn/rejection_sampling.py | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index badfb88509..0b2384063d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -695,13 +695,18 @@ def _make_infer_outputs(self, next_token_ids: torch.LongTensor, def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, stopped: bool): """check if output is necessary.""" - if stopped: - return [] - if token in msg.sampling_param.stop_words: - return [] if isinstance(token, list): - token = [t for t in token if t != -1] + idx = len(token) + for i, t in enumerate(token): + if t == -1: + idx = i + break + token = token[:idx] else: + if stopped: + return [] + if token in msg.sampling_param.stop_words: + return [] token = [token] return token @@ -806,7 +811,7 @@ def __update_inputs(next_token_ids): inputs, num_ignore_eos > 0) # score the proposals with the target model spec_inputs = copy.deepcopy(inputs) - _, num_speculative_tokens = proposal_token_ids.shape + batch_size, num_speculative_tokens = proposal_token_ids.shape target_proposal_ids = torch.cat( [next_token_ids.unsqueeze(-1), proposal_token_ids], -1) spec_inputs.input_ids = target_proposal_ids.flatten( @@ -814,7 +819,6 @@ def __update_inputs(next_token_ids): spec_inputs.history_lengths += spec_inputs.seq_length spec_inputs.seq_length = torch.ones_like( spec_inputs.seq_length) * (num_speculative_tokens + 1) - spec_inputs.is_decoding = False score_output = await self.model_agent.score_proposal( spec_inputs, swap_in_map=swap_in_map, @@ -824,12 +828,18 @@ def __update_inputs(next_token_ids): rejection_sampler = RejectionSampler() score_output = score_output.softmax(-1) spec_logits = spec_logits.softmax(-1) - target_output = rejection_sampler.forward( + target_output, last_accpet = rejection_sampler.forward( score_output, draft_probs=spec_logits, draft_token_ids=proposal_token_ids) next_token_ids = torch.cat( [next_token_ids[:, None], target_output], -1) + # truncate final outputs to appendable length + batch_indices = torch.arange(batch_size, + device=score_output.device) + num_appendable_ids = num_appendable_ids - last_accpet - 1 + max_len_ids = (num_appendable_ids - 1).clamp_max(-1) + next_token_ids[batch_indices, max_len_ids] = -1 # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( diff --git a/lmdeploy/pytorch/nn/rejection_sampling.py b/lmdeploy/pytorch/nn/rejection_sampling.py index 34e076b1ca..a89697cf5a 100644 --- a/lmdeploy/pytorch/nn/rejection_sampling.py +++ b/lmdeploy/pytorch/nn/rejection_sampling.py @@ -113,7 +113,7 @@ def forward(self, target_probs: torch.Tensor, draft_probs: torch.Tensor, (batch_size, num_speculative_tokens + 1), -1) 
output_token_ids[~rejected_mask] = draft_token_ids[accepted_mask] output_token_ids[batch_indices, reject_idx] = adjusted_token_ids - return output_token_ids + return output_token_ids, reject_idx def test_rejection_sampler(): From 69940bd9ffa7dc8bc1232717308dd6b0b008c250 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 8 Nov 2024 15:32:54 +0800 Subject: [PATCH 05/21] Fix chat --- lmdeploy/pytorch/paging/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 8879863092..1b43615b96 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -115,7 +115,7 @@ def add_sequence(self, seq: SchedulerSequence): self._set_message_status(seq, MessageStatus.WAITING) @logging_timer('SchedulePrefilling', logger) - def _schedule_prefill(self): + def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" current_running = self.running @@ -140,7 +140,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, 0) + return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) def _reorder_waiting(): """reorder waiting.""" @@ -164,7 +164,7 @@ def _reorder_waiting(): break # allocate session memory - self.block_manager.allocate(seq) + self.block_manager.allocate(seq, prealloc_size) _to_running(seq) return running, swap_in_map, swap_out_map, copy_map @@ -215,7 +215,7 @@ def __evict_for_seq(seq: SchedulerSequence): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill() + output = self._schedule_prefill(prealloc_size) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output From 731d70ad0d05450944f3eaa0a0f7e1b18dba36aa Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 8 Nov 2024 15:57:58 +0800 Subject: [PATCH 06/21] fix and remove cuda_graph func of medusa --- lmdeploy/pytorch/engine/engine.py | 2 +- lmdeploy/pytorch/models/medusa.py | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 0b2384063d..3600418e60 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -525,7 +525,7 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor, # one more step to cache last token(stop word) stopped = num_appendable_ids < 0 if stop_words is not None: - if len(stop_words.shape) == 1: + if len(token_ids.shape) == 1: token_ids = token_ids[:, None] # TODO speculative model supports multiple stop word sw_stopped = (token_ids == stop_words).any(1) diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index 931bb4e8b9..4a6c62e5d4 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -100,22 +100,6 @@ def get_logits(self, hidden_states: List[torch.Tensor]): outputs = torch.stack(outputs, 1) return outputs - def support_cuda_graph( - self, - input_ids: torch.Tensor, - **kwargs, - ): - """support cudagraph.""" - seq_lens = input_ids.size(1) - if seq_lens <= 512: - return True - - # prevent oom on llama-3 70b - if self.config.num_hidden_layers >= 40: - return False - - return False - def get_input_embeddings(self): """get input embeddings.""" 
return self.model.get_input_embeddings() From bcec0afd9abc42626bc1c83aae32d070bb49e04f Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 8 Nov 2024 18:46:23 +0800 Subject: [PATCH 07/21] fix --- lmdeploy/pytorch/engine/engine.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 3600418e60..e93a744001 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -701,6 +701,10 @@ def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, if t == -1: idx = i break + if stopped: + idx = min( + idx, + msg.sampling_param.max_new_tokens - msg.num_new_tokens) token = token[:idx] else: if stopped: @@ -811,7 +815,7 @@ def __update_inputs(next_token_ids): inputs, num_ignore_eos > 0) # score the proposals with the target model spec_inputs = copy.deepcopy(inputs) - batch_size, num_speculative_tokens = proposal_token_ids.shape + _, num_speculative_tokens = proposal_token_ids.shape target_proposal_ids = torch.cat( [next_token_ids.unsqueeze(-1), proposal_token_ids], -1) spec_inputs.input_ids = target_proposal_ids.flatten( @@ -834,12 +838,7 @@ def __update_inputs(next_token_ids): draft_token_ids=proposal_token_ids) next_token_ids = torch.cat( [next_token_ids[:, None], target_output], -1) - # truncate final outputs to appendable length - batch_indices = torch.arange(batch_size, - device=score_output.device) num_appendable_ids = num_appendable_ids - last_accpet - 1 - max_len_ids = (num_appendable_ids - 1).clamp_max(-1) - next_token_ids[batch_indices, max_len_ids] = -1 # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( From 1898213546b226001011b62e61f1c4e917ce738d Mon Sep 17 00:00:00 2001 From: AllentDan Date: Mon, 11 Nov 2024 15:39:32 +0800 Subject: [PATCH 08/21] support tp --- lmdeploy/pytorch/engine/model_agent.py | 56 ++++++++++++++++--- .../pytorch/engine/speculative_decoding.py | 37 ------------ lmdeploy/pytorch/models/medusa.py | 8 +++ 3 files changed, 56 insertions(+), 45 deletions(-) delete mode 100644 lmdeploy/pytorch/engine/speculative_decoding.py diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index af4db6b47f..9240d28ecd 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -604,6 +604,8 @@ def __init__(self, backend_config: BackendConfig, world_size: int, adapters: Dict[str, str] = None, + speculative_model: str = None, + speculative_model_config: ModelConfig = None, trust_remote_code: bool = True) -> None: import signal @@ -636,6 +638,13 @@ def __signal_term_handler(sig, frame): world_size=world_size, barrier=self.mp_bar) + self.speculative_model = None + if speculative_model is not None: + self.speculative_model_config = speculative_model_config + self.speculative_model = self._build_speculative_model( + speculative_model, + self.speculative_model_config, + world_size=world_size) model, cache_engine, cache_config = self._build_model( model_path=model_path, model_config=model_config, @@ -743,6 +752,35 @@ def _build_model( return model, cache_engine, cache_config + @torch.inference_mode() + def _build_speculative_model( + self, + model_path: str, + model_config: ModelConfig, + world_size: int, + cache_config: CacheConfig = None, + backend_config: BackendConfig = None, + ): + """build model. + + Currently, cache engine and backend config not used. 
+ """ + with get_dist_manager().context(self._dist_ctx): + rank = 0 + device_map = torch.device('cuda') + + custom_module_map = model_config.custom_module_map + if custom_module_map is not None: + update_custom_module_map(custom_module_map) + if rank == 0: + logger.info('build model.') + patched_model = build_patched_model(model_config, + device=device_map) + if rank == 0: + logger.info('loading weights.') + load_model_weights(patched_model, model_path, device=device_map) + return patched_model + def get_block_numel(self): """get block nelement.""" k_cache = self.cache_engine.local_gpu_cache[0][0] @@ -851,12 +889,14 @@ def build_model_agent(model_path: str, adapters=adapters, trust_remote_code=trust_remote_code) else: - model_agent = TPModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - speculative_model=speculative_model, - world_size=tp, - adapters=adapters, - trust_remote_code=trust_remote_code) + model_agent = TPModelAgent( + model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + speculative_model=speculative_model, + speculative_model_config=speculative_model_config, + world_size=tp, + adapters=adapters, + trust_remote_code=trust_remote_code) return model_agent diff --git a/lmdeploy/pytorch/engine/speculative_decoding.py b/lmdeploy/pytorch/engine/speculative_decoding.py deleted file mode 100644 index 97a9547b5a..0000000000 --- a/lmdeploy/pytorch/engine/speculative_decoding.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os - -from lmdeploy.messages import PytorchEngineConfig -from lmdeploy.utils import get_logger, get_model - -from ..devices import get_device_manager -from .engine import Engine -from .model_agent import build_model_agent - -logger = get_logger('lmdeploy') - - -class SpeculativeDecodingEngine(Engine): - - def __init__(self, - model_path: str, - speculative_model: str = None, - engine_config: PytorchEngineConfig = None, - trust_remote_code: bool = True) -> None: - super().__init__(model_path, engine_config, trust_remote_code) - - if not os.path.exists(speculative_model): - speculative_model = get_model(speculative_model, - engine_config.download_dir, - engine_config.revision) - self.speculative_model = speculative_model - - with get_device_manager().context(self.device_context): - self.speculative_model_agent = build_model_agent( - speculative_model, - cache_config=self.cache_config, - backend_config=self.backend_config, - trust_remote_code=trust_remote_code, - tp=self.tp, - dtype=engine_config.dtype, - custom_module_map=engine_config.custom_module_map) diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index 4a6c62e5d4..4e73876fb9 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -86,6 +86,14 @@ def __init__(self, ) for _ in range(self.config.medusa_num_heads) ]) + def support_cuda_graph( + self, + *args, + **kwargs, + ): + """support cudagraph.""" + return True + def forward(self, last_hidden_states: torch.Tensor, **kwargs) -> List[torch.Tensor]: outputs = [head[0](last_hidden_states) for head in self.medusa_head] From 5131a05c81ae47745ff7a95221b68f91a4f39323 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Mon, 11 Nov 2024 18:45:34 +0800 Subject: [PATCH 09/21] fix tp --- lmdeploy/pytorch/engine/model_agent.py | 45 ++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/lmdeploy/pytorch/engine/model_agent.py 
b/lmdeploy/pytorch/engine/model_agent.py index 9240d28ecd..82ba02f4a7 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -804,8 +804,49 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, world_size=1, stream=self.stream, ) + if self.speculative_model is not None: + inputs.last_hidden_states = output['hidden_states'] + spec_outputs = model_forward( + self.speculative_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + output['spec_hidden_states'] = spec_outputs['hidden_states'] return output + async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, num_speculative_tokens): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + num_speculative_tokens (int): The number of the proposal tokens. + """ + with get_dist_manager().context(self._dist_ctx): + self.mp_bar.wait() + rank = 0 + _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], + self.stream) + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape( + [-1, num_speculative_tokens + 1, hidden_states.shape[-1]]) + logits = self.get_logits(hidden_states) + return logits + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -841,6 +882,10 @@ def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): + """get logits of model output.""" + return self.speculative_model.get_logits(hidden_states_list) + def _exit_handler(agent: TPModelAgent): if hasattr(agent, 'patched_model'): From 8c0a76d8f8232a6ad4427b08d5e450f6c12d4d6a Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 12 Nov 2024 14:52:26 +0800 Subject: [PATCH 10/21] add cli --- lmdeploy/cli/cli.py | 3 +++ lmdeploy/cli/serve.py | 2 ++ lmdeploy/cli/utils.py | 9 +++++++++ lmdeploy/pytorch/chat.py | 3 +++ lmdeploy/pytorch/engine/engine.py | 5 ++--- lmdeploy/pytorch/engine/logits_process.py | 8 +------- lmdeploy/serve/openai/api_server.py | 3 +++ 7 files changed, 23 insertions(+), 10 deletions(-) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index a2083c6e64..48ee6b9978 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -43,6 +43,7 @@ def add_parser_convert(): help='The directory path of the model') ArgumentHelper.model_format(parser) ArgumentHelper.tp(parser) + ArgumentHelper.speculative_model(parser) # other args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) @@ -121,6 +122,7 @@ def add_parser_chat(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) + ArgumentHelper.speculative_model(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) # common engine args @@ -270,6 +272,7 @@ def chat(args): quant_policy=args.quant_policy) run_chat(args.model_path, engine_config, + speculative_model=args.speculative_model, chat_template_config=chat_template_config) else: from lmdeploy.turbomind.chat import main as run_chat diff --git 
a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 68f9de8c15..d4a0e54b1b 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -159,6 +159,7 @@ def add_parser_api_server(): pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) + ArgumentHelper.speculative_model(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) @@ -337,6 +338,7 @@ def api_server(args): vision_config = VisionConfig(args.vision_max_batch_size) run_api_server(args.model_path, model_name=args.model_name, + speculative_model=args.speculative_model, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index ad7a058c8f..d45def3c89 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -100,6 +100,15 @@ def model_name(parser): 'by the RESTful API `/v1/models`. If it is not specified, ' '`model_path` will be adopted') + @staticmethod + def speculative_model(parser): + """Add argument speculative_model to parser.""" + + return parser.add_argument('--speculative-model', + type=str, + default=None, + help='The path of the spculative model.') + @staticmethod def dtype(parser, default: str = 'auto'): return parser.add_argument( diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py index 2b5ee85edc..3783b8255e 100644 --- a/lmdeploy/pytorch/chat.py +++ b/lmdeploy/pytorch/chat.py @@ -52,6 +52,7 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer): def run_chat(model_path: str, engine_config: PytorchEngineConfig, + speculative_model: Optional[str] = None, gen_config: GenerationConfig = None, session_id: int = 1, trust_remote_code: bool = True, @@ -62,12 +63,14 @@ def run_chat(model_path: str, Args: model_path (str): the huggingface model path. engine_config (PytorchEngineConfig): Config of engine. + speculative_model (str): the path of the speculative model. gen_config (GenerationConfig): Config of generation. session_id (int): the identical id of a session. trust_remote_code (bool): trust remote code. 
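# With the flags registered above, the draft model can also be supplied from
# the command line; for example (paths are placeholders, and `--backend
# pytorch` is required because only the PyTorch engine reads the flag):
#
#   lmdeploy chat lmsys/vicuna-7b-v1.3 --backend pytorch \
#       --speculative-model FasterDecoding/medusa-vicuna-7b-v1.3
#   lmdeploy serve api_server lmsys/vicuna-7b-v1.3 --backend pytorch \
#       --speculative-model FasterDecoding/medusa-vicuna-7b-v1.3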
""" from lmdeploy.pytorch.engine import Engine tm_model = Engine.from_pretrained(model_path, + speculative_model=speculative_model, engine_config=engine_config, trust_remote_code=trust_remote_code) tokenizer = tm_model.tokenizer diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index e93a744001..fdd9eb0377 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -810,9 +810,8 @@ def __update_inputs(next_token_ids): num_ignore_eos = num_ignore_eos - 1 if 'spec_logits' in output: spec_logits = output['spec_logits'] - proposal_token_ids = self.async_sampling_logits( - spec_logits, all_ids, guided_input_ids, sampling_inputs, - inputs, num_ignore_eos > 0) + # TODO add tree decoding + proposal_token_ids = spec_logits.argmax(-1) # score the proposals with the target model spec_inputs = copy.deepcopy(inputs) _, num_speculative_tokens = proposal_token_ids.shape diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 2281964b2f..54740a4fb3 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -15,11 +15,7 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor): """process temperature.""" temperature = temperature.to(scores.dtype) - if len(scores.shape) == 3: - temperature = temperature[:, None, None] - else: # len==2 - temperature = temperature[:, None] - scores.div_(temperature) + scores.div_(temperature[:, None]) return scores @@ -27,8 +23,6 @@ def _process_bad_words_(scores: torch.Tensor, bad_words: torch.LongTensor, filter_value: float = -float('inf')): """process bad words.""" - if len(scores.shape) == 3: - bad_words = bad_words[:, :, None] mask = bad_words >= 0 bad_words = bad_words.where(mask, 0) filtered_scores = scores.gather(1, bad_words) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index a12cadaa7d..cd26e30426 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -973,6 +973,7 @@ async def startup_event(): def serve(model_path: str, + speculative_model: Optional[str] = None, model_name: Optional[str] = None, backend: Literal['turbomind', 'pytorch'] = 'turbomind', backend_config: Optional[Union[PytorchEngineConfig, @@ -1008,6 +1009,7 @@ def serve(model_path: str, on huggingface.co, such as "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. + speculative_model (str): the path of the speculative model. model_name (str): the name of the served model. It can be accessed by the RESTful API `/v1/models`. 
If it is not specified, `model_path` will be adopted @@ -1072,6 +1074,7 @@ def serve(model_path: str, VariableInterface.async_engine = pipeline_class( model_path=model_path, + speculative_model=speculative_model, model_name=model_name, backend=backend, backend_config=backend_config, From 9930e613e978841bc9e361621fd1e1aba479afe8 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Wed, 4 Dec 2024 17:29:50 +0800 Subject: [PATCH 11/21] Add tree decoding --- lmdeploy/pytorch/backends/cuda/attention.py | 5 + lmdeploy/pytorch/backends/cuda/op_backend.py | 15 + lmdeploy/pytorch/engine/engine.py | 54 ++- lmdeploy/pytorch/engine/model_agent.py | 37 +- .../pytorch/kernels/cuda/flashattention.py | 428 ++++++++++++++++-- lmdeploy/pytorch/model_inputs.py | 7 + lmdeploy/pytorch/models/medusa.py | 194 ++++++++ lmdeploy/pytorch/paging/scheduler.py | 1 + tests/pytorch/kernel/test_flash_attention.py | 71 ++- 9 files changed, 720 insertions(+), 92 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index d01d6fe9b4..fae60976c1 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -21,6 +21,7 @@ class TritonAttentionMetadata(AttentionMetadata): fill_seqlens: torch.Tensor = None quant_policy: Literal[0, 4, 8] = 0 kv_flatten_size: int = None + medusa_attn_mask: torch.Tensor = None def _cdiv(a, b): @@ -100,6 +101,9 @@ def forward( fill_seqlens = attn_metadata.fill_seqlens fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2)) fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens + attention_mask = None + if attn_metadata.medusa_attn_mask is not None: + attention_mask = attn_metadata.medusa_attn_mask # fill kv cache if key is not None and value is not None: @@ -161,6 +165,7 @@ def forward( flatten_k, flatten_v, attn_output, + attention_mask=attention_mask, q_start_loc=q_start_loc, q_seqlens=q_seqlens, kv_start_loc=kv_start_loc, diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 3e7fc23728..5307ba2e62 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -110,6 +110,20 @@ def update_step_context(cls, step_context): if not step_context.is_decoding: kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens kv_flatten_size = kv_seqlens.sum().item() + if step_context.medusa_attn_mask is not None: + max_q_seqlen = q_seqlens.max() + max_kv_seqlen = kv_seqlens.max() + bs = q_seqlens.shape[0] + medusa_len = step_context.medusa_attn_mask.shape[-1] + dtype = step_context.medusa_attn_mask.dtype + device = step_context.medusa_attn_mask.device + attention_mask = torch.zeros((bs, max_q_seqlen, max_kv_seqlen), + dtype=dtype, + device=device) + attention_mask[:, -medusa_len:, -medusa_len:] = ( + 1 - step_context.medusa_attn_mask) * (-1e30) + step_context.medusa_attn_mask = attention_mask + attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, @@ -119,6 +133,7 @@ def update_step_context(cls, step_context): kv_seqlens=kv_seqlens, kv_flatten_size=kv_flatten_size, quant_policy=step_context.kv_quant_policy, + medusa_attn_mask=step_context.medusa_attn_mask, ) cross_attn_metadata = None diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index fdd9eb0377..64381b8748 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -565,10 +565,10 @@ def extract_tokens(self, token_ids, eos_token_ids): return [token_ids], 
token_ids in eos_token_ids for i, token_id in enumerate(token_ids): if token_id in eos_token_ids: - return token_ids[:i + 1], True + return token_ids[:i], True if token_id == -1: - break - return token_ids[:i], False + return token_ids[:i], False + return token_ids, False @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, @@ -810,34 +810,38 @@ def __update_inputs(next_token_ids): num_ignore_eos = num_ignore_eos - 1 if 'spec_logits' in output: spec_logits = output['spec_logits'] - # TODO add tree decoding - proposal_token_ids = spec_logits.argmax(-1) - # score the proposals with the target model + cart_candidates, tree_candidates, medusa_attn_mask, medusa_position_ids, retrieve_indices = self.model_agent.generate_candidates( + spec_logits, next_token_ids) + bs, _, tree_decode_len = tree_candidates.shape spec_inputs = copy.deepcopy(inputs) - _, num_speculative_tokens = proposal_token_ids.shape - target_proposal_ids = torch.cat( - [next_token_ids.unsqueeze(-1), proposal_token_ids], -1) - spec_inputs.input_ids = target_proposal_ids.flatten( - ).unsqueeze(0) + spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0) spec_inputs.history_lengths += spec_inputs.seq_length spec_inputs.seq_length = torch.ones_like( - spec_inputs.seq_length) * (num_speculative_tokens + 1) - score_output = await self.model_agent.score_proposal( + spec_inputs.seq_length) * tree_decode_len + spec_inputs.medusa_attn_mask = medusa_attn_mask + spec_inputs.medusa_position_ids = medusa_position_ids + logits = await self.model_agent.tree_decoding( spec_inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map, - num_speculative_tokens=num_speculative_tokens) - from ..nn.rejection_sampling import RejectionSampler - rejection_sampler = RejectionSampler() - score_output = score_output.softmax(-1) - spec_logits = spec_logits.softmax(-1) - target_output, last_accpet = rejection_sampler.forward( - score_output, - draft_probs=spec_logits, - draft_token_ids=proposal_token_ids) - next_token_ids = torch.cat( - [next_token_ids[:, None], target_output], -1) - num_appendable_ids = num_appendable_ids - last_accpet - 1 + retrieve_indices=retrieve_indices) + # NOTE currently only greedy sampling supported + proposal_len = cart_candidates.shape[-1] + greedy_token_ids = logits.argmax(-1) + posterior_mask = cart_candidates[..., 1:] == greedy_token_ids[ + ..., :-1] + accept_len, best_idx = torch.cumprod(posterior_mask, + dim=-1).sum(-1).max(-1) + # accept_len = torch.where(accept_len==proposal_len-1, proposal_len, accept_len) + next_token_ids = cart_candidates[torch.arange(bs), best_idx] + # bonus_token_ids = greedy_token_ids[torch.arange(bs),best_idx,-1:] + # next_token_ids = torch.cat([best_candidates, bonus_token_ids], -1) + mask_idx = torch.arange( + proposal_len, + device=next_token_ids.device).expand_as(next_token_ids) + next_token_ids[mask_idx > accept_len[:, None]] = -1 + # next_token_ids = next_token_ids[...,:-1] # to be removed + num_appendable_ids = num_appendable_ids - accept_len - 1 # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 82ba02f4a7..eff8d190a7 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -330,7 +330,7 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap, 
num_speculative_tokens): - """model forward. + """score the proposal. Args: inputs (Dict): The input data comes from _make_inputs. @@ -356,6 +356,36 @@ async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, logits = self.get_logits(hidden_states) return logits + async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, + retrieve_indices: torch.Tensor): + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + bs = inputs.history_lengths.shape[0] + inputs.medusa_position_ids = inputs.medusa_position_ids.repeat( + inputs.history_lengths.shape[0], 1) + inputs.medusa_position_ids = inputs.medusa_position_ids.to( + inputs.history_lengths.device) + inputs.history_lengths[:, None] + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + await asyncio.get_event_loop().run_in_executor(None, + self.stream.synchronize) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape(bs, -1, hidden_states.shape[-1]) + logits = self.get_logits(hidden_states)[:, retrieve_indices] + return logits + + def generate_candidates(self, draft_logits: torch.Tensor, + base_token_id: torch.Tensor): + return self.speculative_model.generate_candidates( + draft_logits, base_token_id) + def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) @@ -878,6 +908,11 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, self.stream.synchronize) return output + def generate_candidates(self, draft_logits: torch.Tensor, + base_token_id: torch.Tensor): + return self.speculative_model.generate_candidates( + draft_logits, base_token_id) + def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 7521a3e2bb..0bff92e9c6 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -47,6 +47,17 @@ def softcapping(qk, logit_softcapping: tl.constexpr): return qk +@triton.jit +def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr): + """load kv.""" + if causal_mask: + return tl.load(ptrs, + boundary_check=boundary_check, + padding_option='zero') + else: + return tl.load(ptrs) + + @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, qk_scale, history_mask, @@ -63,11 +74,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load(k_ptrs) + k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) qk = tl.dot(q, k) if BLOCK_DK1 != 0: - k1 = tl.load(k1_ptrs) + k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, )) qk += tl.dot(q1, k1) if causal_mask: @@ -113,7 +124,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, acc = acc * alpha[:, None] # update acc - v = tl.load(v_ptrs) + v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, )) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i @@ -127,6 +138,90 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, return acc, l_i, m_i +@triton.jit +def _prefill_fwd_inner_with_mask( + acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, 
k1_ptrs, loop_start, loop_end, + qk_scale, history_mask, kv_min_loc, attn_mask_ptr, + causal_mask: tl.constexpr, window_size: tl.constexpr, + logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_DK1: tl.constexpr): + k_ptrs = tl.advance(k_ptrs, (0, loop_start)) + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, loop_start)) + v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) + if BLOCK_DK1: + k1_ptrs = tl.advance(k1_ptrs, (0, loop_start)) + + offs_n = tl.arange(0, BLOCK_N) + for start_n in range(loop_start, loop_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) + attn_mask = tl.load(attn_mask_ptr) + qk = tl.dot(q, k) + + if BLOCK_DK1 != 0: + k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, )) + qk += tl.dot(q1, k1) + + if causal_mask: + qk *= qk_scale + qk = softcapping(qk, logit_softcapping) + qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) + if window_size > 0: + qk_mask = qk_mask and ( + (start_n + offs_n[None, :]) >= kv_min_loc[:, None]) + qk = tl.where( + qk_mask, + qk, + float(-1e30), + ) + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_i_new[:, None] + elif window_size > 0: + qk *= qk_scale + qk = softcapping(qk, logit_softcapping) + qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) + qk = tl.where( + qk_mask, + qk, + float(-1e30), + ) + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_i_new[:, None] + elif logit_softcapping > 0: + qk *= qk_scale + qk = softcapping(qk, logit_softcapping) + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_i_new[:, None] + else: + m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_i_new[:, None] + + qk = qk + attn_mask + # -- compute p, m_i and l_i + p = tl_exp2(qk) + alpha = tl_exp2(m_i - m_i_new) + l_i = alpha * l_i + tl.sum(p, 1) + # -- update output accumulator -- + # scale acc + acc = acc * alpha[:, None] + + # update acc + v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, )) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_i_new + + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N)) + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, BLOCK_N)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0)) + if BLOCK_DK1: + k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N)) + + return acc, l_i, m_i + + # # FOR DEBUG, DON'T REMOVE # import itertools # configs = [ @@ -168,6 +263,7 @@ def _flash_prefill_fwd_kernel( kv_group_num, head_dim_k, head_dim_v, + causal: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_M: tl.constexpr, @@ -257,9 +353,13 @@ def _flash_prefill_fwd_kernel( acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) qk_scale = sm_scale * tl_log2(math.e) - history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + if causal: + history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N + else: + history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32) + loop_end = kv_seqlen // BLOCK_N * BLOCK_N - loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -280,7 +380,10 @@ def _flash_prefill_fwd_kernel( BLOCK_DK1=BLOCK_DK1) loop_start = loop_end - loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + if causal: + loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + else: + loop_end = kv_seqlen acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -314,6 +417,203 @@ def 
_flash_prefill_fwd_kernel( mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :]) +@triton.jit +def _flash_prefill_fwd_kernel_with_mask( + q_ptr, + k_ptr, + v_ptr, + o_ptr, + q_start_loc_ptr, + q_seqlens_ptr, + kv_start_loc_ptr, + kv_seqlens_ptr, + sm_scale, + attention_mask, + stride_qs: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_ks: tl.constexpr, + stride_kh, + stride_kd: tl.constexpr, + stride_vs: tl.constexpr, + stride_vh, + stride_vd: tl.constexpr, + stride_os: tl.constexpr, + stride_oh: tl.constexpr, + stride_od: tl.constexpr, + stride_amb: tl.constexpr, + stride_amqs: tl.constexpr, + stride_amkvs: tl.constexpr, + kv_group_num, + head_dim_k, + head_dim_v, + causal: tl.constexpr, + window_size: tl.constexpr, + logit_softcapping: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DK: tl.constexpr, + BLOCK_DK1: tl.constexpr, + BLOCK_DV: tl.constexpr, +): + """flash attention kernel.""" + start_m = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + q_seqlen = tl.load(q_seqlens_ptr + batch_id) + + if BLOCK_M * start_m >= q_seqlen: + return + + kv_head_id = head_id // kv_group_num + q_seqlen = q_seqlen.to(tl.int32) + kv_seqlen = tl.load(kv_seqlens_ptr + batch_id).to(tl.int32) + q_start_loc = tl.load(q_start_loc_ptr + batch_id).to(tl.int32) + kv_start_loc = tl.load(kv_start_loc_ptr + batch_id).to(tl.int32) + + history_len = kv_seqlen - q_seqlen + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + loop_start = 0 + kv_min_loc = tl.zeros([BLOCK_M], dtype=tl.int32) + if window_size > 0: + start_block_id = tl.maximum( + history_len + start_m * BLOCK_M - window_size, 0) // BLOCK_N + kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0) + loop_start = start_block_id * BLOCK_N + + offs_dk = tl.arange(0, BLOCK_DK) + mask_dk = offs_dk < head_dim_k + offs_dk = tl.multiple_of(tl.max_contiguous(offs_dk % head_dim_k, BLOCK_DK), + BLOCK_DK) + off_q = ((q_start_loc + offs_m[:, None]) * stride_qs + + head_id * stride_qh + offs_dk[None, :] * stride_qd) + q_ptrs = q_ptr + off_q + q = tl.load(q_ptrs, mask=(offs_m[:, None] < q_seqlen and mask_dk[None, :])) + + k_ptrs = tl.make_block_ptr( + base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh, + shape=(head_dim_k, kv_seqlen), + strides=(stride_kd, stride_ks), + offsets=(0, 0), + block_shape=(BLOCK_DK, BLOCK_N), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + kv_start_loc * stride_vs + kv_head_id * stride_vh, + shape=(kv_seqlen, head_dim_v), + strides=(stride_vs, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DV), + order=(1, 0), + ) + attn_mask_ptrs = tl.make_block_ptr( + base=attention_mask + batch_id * stride_amb + + start_m * BLOCK_M * stride_amqs, + shape=(q_seqlen, kv_seqlen), + strides=(stride_amqs, stride_amkvs), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(0, 1), + ) + + if BLOCK_DK1 != 0: + offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1) + mask_dk1 = offs_dk1 < head_dim_k + offs_dk1 = tl.multiple_of( + tl.max_contiguous(offs_dk1 % head_dim_k, BLOCK_DK1), BLOCK_DK1) + offs_q1 = ((q_start_loc + offs_m[:, None]) * stride_qs + + head_id * stride_qh + offs_dk1[None, :] * stride_qd) + q1_ptrs = q_ptr + offs_q1 + q1 = tl.load(q1_ptrs, + mask=(offs_m[:, None] < q_seqlen and mask_dk1[None, :])) + k1_ptrs = tl.make_block_ptr( + base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh, + shape=(head_dim_k, kv_seqlen), + strides=(stride_kd, stride_ks), + offsets=(BLOCK_DK, 0), + block_shape=(BLOCK_DK1, 
BLOCK_N), + order=(0, 1), + ) + else: + q1 = q + k1_ptrs = k_ptrs + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) + + qk_scale = sm_scale * tl_log2(math.e) + if causal: + history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N + else: + history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32) + loop_end = kv_seqlen // BLOCK_N * BLOCK_N + + acc, l_i, m_i = _prefill_fwd_inner_with_mask( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + q1, + k1_ptrs, + loop_start, + loop_end, + qk_scale, + history_mask, + kv_min_loc, + attn_mask_ptrs, + causal_mask=False, + window_size=window_size, + logit_softcapping=logit_softcapping, + BLOCK_N=BLOCK_N, + BLOCK_DK1=BLOCK_DK1) + + loop_start = loop_end + if causal: + loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + else: + loop_end = kv_seqlen + acc, l_i, m_i = _prefill_fwd_inner_with_mask( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + q1, + k1_ptrs, + loop_start, + loop_end, + qk_scale, + history_mask, + kv_min_loc, + attn_mask_ptrs, + causal_mask=True, + window_size=window_size, + logit_softcapping=logit_softcapping, + BLOCK_N=BLOCK_N, + BLOCK_DK1=BLOCK_DK1) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # initialize pointers to output + offs_dv = tl.arange(0, BLOCK_DV) + mask_dv = offs_dv < head_dim_v + off_o = ((q_start_loc + offs_m[:, None]) * stride_os + + head_id * stride_oh + offs_dv[None, :] * stride_od) + out_ptrs = o_ptr + off_o + tl.store(out_ptrs, + acc, + mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :]) + + _nv_cap = None @@ -326,10 +626,12 @@ def flash_attention_fwd( q_seqlens: Tensor, kv_start_loc: Tensor, kv_seqlens: Tensor, + attention_mask: Tensor = None, max_seqlen: int = None, window_size: int = None, sm_scale: float = None, logit_softcapping: float = None, + causal: bool = True, kv_layout: str = 'hsd', ): """varlen flash Attention forward. 
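The hunk above gives `flash_attention_fwd` an optional additive `attention_mask` and a `causal` switch so that Medusa tree decoding can reuse the same prefill kernel. As a minimal sketch, the snippet below builds such a mask from the 0/1 tree mask that `generate_medusa_buffers` produces; the `(bs, max_q_seqlen, max_kv_seqlen)` shape and the `-1e30` fill value mirror `update_step_context` in `backends/cuda/op_backend.py` from this series, while the concrete sizes and the simplified tree mask are illustrative assumptions only.

import torch

# Illustrative sizes only (assumptions, not values taken from the patch).
bs, medusa_len, max_q_seqlen, max_kv_seqlen = 2, 64, 64, 192

# Simplified stand-in for the 0/1 tree mask built by generate_medusa_buffers:
# entry (i, j) is 1 when tree token i may attend to tree token j (itself, the
# root and, in the real buffer, all of its ancestors).
medusa_attn_mask = torch.eye(medusa_len).expand(bs, -1, -1).clone()
medusa_attn_mask[:, :, 0] = 1

# Additive mask consumed by the kernel: 0 keeps a position visible,
# -1e30 effectively removes it after the softmax.
attention_mask = torch.zeros((bs, max_q_seqlen, max_kv_seqlen))
attention_mask[:, -medusa_len:, -medusa_len:] = (1 - medusa_attn_mask) * (-1e30)

# The mask then reaches the kernel via TritonAttentionMetadata, e.g.
# flash_attention_fwd(q, k, v, out, attention_mask=attention_mask,
#                     q_start_loc=..., q_seqlens=..., kv_start_loc=...,
#                     kv_seqlens=..., max_seqlen=..., causal=causal)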
@@ -380,6 +682,7 @@ def grid(args): BLOCK_M = max(16, 8192 // BLOCK_DK) else: BLOCK_M = max(16, 16384 // BLOCK_DK) + BLOCK_M = min(128, BLOCK_M) num_warps = 4 num_stages = min(4, max(2, 1024 // BLOCK_DK)) if BLOCK_DK >= 512: @@ -388,40 +691,83 @@ def grid(args): num_stages = 3 else: num_stages = 4 - _flash_prefill_fwd_kernel[grid]( - q_states, - k_states, - v_states, - o_states, - q_start_loc, - q_seqlens, - kv_start_loc, - kv_seqlens, - sm_scale=sm_scale, - stride_qs=q_states.stride(0), - stride_qh=q_states.stride(1), - stride_qd=q_states.stride(2), - stride_ks=k_states.stride(s_dim), - stride_kh=k_states.stride(h_dim), - stride_kd=k_states.stride(d_dim), - stride_vs=v_states.stride(s_dim), - stride_vh=v_states.stride(h_dim), - stride_vd=v_states.stride(d_dim), - stride_os=o_states.stride(0), - stride_oh=o_states.stride(1), - stride_od=o_states.stride(2), - kv_group_num=kv_group_num, - head_dim_k=head_dim_k, - head_dim_v=head_dim_v, - window_size=window_size, - logit_softcapping=logit_softcapping, - BLOCK_DK=BLOCK_DK, - BLOCK_DK1=BLOCK_DK1, - BLOCK_DV=BLOCK_DV, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) + if attention_mask is None: + _flash_prefill_fwd_kernel[grid]( + q_states, + k_states, + v_states, + o_states, + q_start_loc, + q_seqlens, + kv_start_loc, + kv_seqlens, + sm_scale=sm_scale, + stride_qs=q_states.stride(0), + stride_qh=q_states.stride(1), + stride_qd=q_states.stride(2), + stride_ks=k_states.stride(s_dim), + stride_kh=k_states.stride(h_dim), + stride_kd=k_states.stride(d_dim), + stride_vs=v_states.stride(s_dim), + stride_vh=v_states.stride(h_dim), + stride_vd=v_states.stride(d_dim), + stride_os=o_states.stride(0), + stride_oh=o_states.stride(1), + stride_od=o_states.stride(2), + kv_group_num=kv_group_num, + head_dim_k=head_dim_k, + head_dim_v=head_dim_v, + causal=causal, + window_size=window_size, + logit_softcapping=logit_softcapping, + BLOCK_DK=BLOCK_DK, + BLOCK_DK1=BLOCK_DK1, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _flash_prefill_fwd_kernel_with_mask[grid]( + q_states, + k_states, + v_states, + o_states, + q_start_loc, + q_seqlens, + kv_start_loc, + kv_seqlens, + sm_scale=sm_scale, + attention_mask=attention_mask, + stride_qs=q_states.stride(0), + stride_qh=q_states.stride(1), + stride_qd=q_states.stride(2), + stride_ks=k_states.stride(s_dim), + stride_kh=k_states.stride(h_dim), + stride_kd=k_states.stride(d_dim), + stride_vs=v_states.stride(s_dim), + stride_vh=v_states.stride(h_dim), + stride_vd=v_states.stride(d_dim), + stride_os=o_states.stride(0), + stride_oh=o_states.stride(1), + stride_od=o_states.stride(2), + stride_amb=attention_mask.stride(0), + stride_amqs=attention_mask.stride(1), + stride_amkvs=attention_mask.stride(2), + kv_group_num=kv_group_num, + head_dim_k=head_dim_k, + head_dim_v=head_dim_v, + causal=causal, + window_size=window_size, + logit_softcapping=logit_softcapping, + BLOCK_DK=BLOCK_DK, + BLOCK_DK1=BLOCK_DK1, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) return o_states diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 5fa21fa6d6..0ae5dd7986 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -123,6 +123,8 @@ class ModelInputs: cross_attention_states: torch.Tensor = None history_cross_kv_seqlens: torch.LongTensor = None last_hidden_states: torch.Tensor = None + medusa_attn_mask: 
torch.Tensor = None + medusa_position_ids: torch.Tensor = None def update(self, input_ids: torch.LongTensor): """update input ids.""" @@ -219,6 +221,7 @@ class StepContext: cross_kv_seqlens: torch.LongTensor = None kv_quant_policy: Literal[0, 4, 8] = 0 last_hidden_states: torch.Tensor = None + medusa_attn_mask: torch.Tensor = None _outputs: Dict = field(default_factory=dict) @@ -274,6 +277,9 @@ def new( # seq_len + history_length kv_seqlens = q_seqlens + history_seqlens kv_seqlens -= inputs.num_ignored_history + # medusa + if inputs.medusa_position_ids is not None: + position_ids = inputs.medusa_position_ids.reshape(1, -1) ret = StepContext( input_ids=inputs.input_ids, @@ -293,6 +299,7 @@ def new( mrope_position_ids=mrope_position_ids, cross_attention_states=cross_attention_states, last_hidden_states=last_hidden_states, + medusa_attn_mask=inputs.medusa_attn_mask, cross_kv_seqlens=inputs.history_cross_kv_seqlens, kv_quant_policy=kv_quant_policy, ) diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index 4e73876fb9..e7ed5c8aae 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -11,6 +11,45 @@ from .utils.cudagraph import CudaGraphMixin +vicuna_7b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 3), (3, ), (0, 1, 0), (2, 0), (4, ), + (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5, ), + (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), + (6, ), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7, ), + (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), + (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), + (0, 1, 2), (8, ), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), + (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), + (9, ), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), + (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)] # noqa +TOPK = 10 + + +def pad_path(path, length, pad_value=-2): + """Pad the given path list with a specific value up to a specified length. + + Parameters: + - path (list): The original list that needs padding. + - length (int): The desired length of the padded list. + - pad_value (optional, default=-2): The value to use for padding. + + Returns: + - list: A new list based on the original path but padded to the desired length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + class ResBlock(nn.Module): """A Residual Block module. @@ -85,6 +124,161 @@ def __init__(self, device=device), ) for _ in range(self.config.medusa_num_heads) ]) + self.medusa_choices = None + if 'vicuna-7b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_7b_stage2 + self.generate_medusa_buffers(device=device) + + def generate_medusa_buffers(self, device: torch.dtype = None): + """Generate buffers for the Medusa structure based on the provided + choices. + + Args: + medusa_choices (list): A nested list representing tree in the + Medusa structure. + device (str): Device to which the tensors should be moved. + Default is "cuda". + + Returns: + dict: A dictionary containing buffers related to the + Medusa structure. 
+ """ + if self.medusa_choices is None: + self.medusa_attn_mask = None + self.tree_indices = None + self.medusa_position_ids = None + self.retrieve_indices = None + return + + # Sort the medusa_choices based on their lengths and then their values + sorted_medusa_choices = sorted(self.medusa_choices, + key=lambda x: (len(x), x)) + medusa_len = len(sorted_medusa_choices) + 1 + + # Initialize depth_counts to keep track of how many choices have a particular depth + depth_counts = [] + prev_depth = 0 + for path in sorted_medusa_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + + # Create the attention mask for Medusa + medusa_attn_mask = torch.eye(medusa_len, medusa_len) + medusa_attn_mask[:, 0] = 1 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + # retrieve ancestor position + if len(cur_medusa_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_medusa_choice) - 1): + ancestor_idx.append( + sorted_medusa_choices.index(cur_medusa_choice[:c + + 1]) + 1) + medusa_attn_mask[j + start + 1, ancestor_idx] = 1 + start += depth_counts[i] + + # Generate tree indices for the Medusa structure + medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) + medusa_tree_indices[0] = 0 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + medusa_tree_indices[start + j + + 1] = cur_medusa_choice[-1] + TOPK * i + 1 + start += depth_counts[i] + + # Generate position IDs for the Medusa structure + medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) + start = 0 + for i in range(len(depth_counts)): + medusa_position_ids[start + 1:start + depth_counts[i] + 1] = i + 1 + start += depth_counts[i] + + # Generate retrieval indices for Medusa structure verification + retrieve_indices_nest = [] + retrieve_paths = [] + for i in range(len(sorted_medusa_choices)): + cur_medusa_choice = sorted_medusa_choices[-i - 1] + retrieve_indice = [] + if cur_medusa_choice in retrieve_paths: + continue + else: + for c in range(len(cur_medusa_choice)): + retrieve_indice.append( + sorted_medusa_choices.index(cur_medusa_choice[:c + 1])) + retrieve_paths.append(cur_medusa_choice[:c + 1]) + retrieve_indices_nest.append(retrieve_indice) + max_length = max([len(x) for x in retrieve_indices_nest]) + retrieve_indices = [ + pad_path(path, max_length) for path in retrieve_indices_nest + ] + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + retrieve_indices = retrieve_indices + 1 + retrieve_indices = torch.cat([ + torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), + retrieve_indices + ], + dim=1) + self.medusa_attn_mask = medusa_attn_mask.unsqueeze(0).unsqueeze(0).to( + device) + self.tree_indices = medusa_tree_indices.to(device) + self.medusa_position_ids = medusa_position_ids.to(device) + self.retrieve_indices = retrieve_indices.to(device) + + def generate_candidates(self, medusa_logits: torch.Tensor, + base_token_id: torch.Tensor): + """Generate candidates based on provided logits and indices. + + Args: + medusa_logits (torch.Tensor): Logits from a specialized Medusa + structure, aiding in candidate selection. Shape + [bs, speculative_num, vocab_size] + base_token_id (torch.Tensor): Standard logits from a language + model. 
Shape [bs] + + Returns: + tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates: + 1. Cartesian candidates derived from the combined original and Medusa logits. + 2. Tree candidates mapped from the Cartesian candidates using tree indices. + """ # noqa + # Greedy decoding: Select the most probable candidate from the original logits. + # here we only implement greedy decoding + bs = medusa_logits.shape[0] + candidates_logit = base_token_id.unsqueeze(-1) + # Extract the TOPK candidates from the medusa logits. + candidates_medusa_logits = torch.topk(medusa_logits, TOPK, + dim=-1).indices + + # Combine the selected candidate from the original logits with the topk medusa logits. + candidates = torch.cat( + [candidates_logit, + candidates_medusa_logits.view(bs, -1)], dim=-1) + + # Map the combined candidates to the tree indices to get tree candidates. + tree_candidates = candidates[:, self.tree_indices] + + # Extend the tree candidates by appending a zero. + tree_candidates_ext = torch.cat([ + tree_candidates, + torch.zeros( + (bs, 1), dtype=torch.long, device=tree_candidates.device) + ], + dim=-1) + + # Retrieve the cartesian candidates using the retrieve indices. + cart_candidates = tree_candidates_ext[:, self.retrieve_indices] + + # Unsqueeze the tree candidates for dimension consistency. + tree_candidates = tree_candidates.unsqueeze( + 1) # bs, 1, len(self.medusa_choices) + return cart_candidates, tree_candidates, self.medusa_attn_mask, self.medusa_position_ids, self.retrieve_indices def support_cuda_graph( self, diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 1b43615b96..e28e375965 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -214,6 +214,7 @@ def __evict_for_seq(seq: SchedulerSequence): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" + prealloc_size = max(prealloc_size, 64) # 64 for medusa tree decode if is_prefill: output = self._schedule_prefill(prealloc_size) else: diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index 7d4b7a7f3a..ae4bf798f4 100644 --- a/tests/pytorch/kernel/test_flash_attention.py +++ b/tests/pytorch/kernel/test_flash_attention.py @@ -10,23 +10,29 @@ def _conti_input(data, q_seqlens): return data -def _make_bias(q_seqlens, history_lens, neg_val): - full_seq_lens = q_seqlens + history_lens +def _make_bias(q_seqlens, history_lens, neg_val, causal): + kv_seqlens = q_seqlens + history_lens max_seq_len = q_seqlens.max().item() - max_full_len = full_seq_lens.max().item() - seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] - for r, l in zip(seq_ranges, q_seqlens): - r[l:] = -max_full_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() - mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, - None, - None] - return mask.float() * neg_val - - -def _naive_attention(batched_q, batched_kv, bias): + max_kv_len = kv_seqlens.max().item() + if causal: + seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] + for r, l in zip(seq_ranges, q_seqlens): + r[l:] = -max_kv_len + seq_ranges = torch.stack(seq_ranges, dim=0).cuda() + kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens] + kv_ranges = torch.stack(kv_ranges, 0).cuda() + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > + history_lens[:, None, 
None]) + return mask.float() * neg_val + else: + q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None] + k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None] + mask = q_mask[:, :, None] & k_mask[:, None, :] + + return (~mask).float() * neg_val + + +def _naive_attention(batched_q, batched_kv, bias, rand_mask): batched_k, batched_v = batched_kv num_heads_q = batched_q.shape[2] @@ -43,7 +49,7 @@ def _naive_attention(batched_q, batched_kv, bias): v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2) qk = torch.matmul(q, k) / math.sqrt(head_dim) - attn_weight = qk + bias[:, None] + attn_weight = qk + bias[:, None] + rand_mask[:, None] attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) attn_weight = attn_weight.to(q.dtype) attn_output = torch.matmul(attn_weight, v) @@ -100,6 +106,10 @@ def num_heads_q(self, request): def num_heads_k(self, request): yield request.param + @pytest.fixture + def causal(self, request): + yield request.param + @pytest.fixture def q_seqlens(self, request): yield torch.tensor(request.param, device='cuda') @@ -138,8 +148,8 @@ def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k, head_dim_v, dtype): torch.manual_seed(123) batch_size = len(q_seqlens) - full_seq_lens = q_seqlens + history_lens - max_seq_len = full_seq_lens.max().item() + kv_seqlens = q_seqlens + history_lens + max_seq_len = kv_seqlens.max().item() k = torch.rand(batch_size, max_seq_len, num_heads_k, @@ -167,13 +177,18 @@ def conti_kv(self, kv_seqlens, batched_kv): yield (conti_k, conti_v) @pytest.fixture - def mask(self, q_seqlens, history_lens): + def mask(self, q_seqlens, history_lens, causal): neg_val = -1e30 - yield _make_bias(q_seqlens, history_lens, neg_val) + yield _make_bias(q_seqlens, history_lens, neg_val, causal) @pytest.fixture - def gt(self, batched_q, batched_kv, mask): - yield _naive_attention(batched_q, batched_kv, mask) + def rand_mask(self, mask): + neg_val = -1e30 + yield torch.rand_like(mask).round() * neg_val + + @pytest.fixture + def gt(self, batched_q, batched_kv, mask, rand_mask): + yield _naive_attention(batched_q, batched_kv, mask, rand_mask) @pytest.fixture def conti_gt(self, gt, q_seqlens): @@ -183,26 +198,32 @@ def conti_gt(self, gt, q_seqlens): @pytest.mark.parametrize('head_dim_v', [32], indirect=True) @pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True) @pytest.mark.parametrize('num_heads_k', [2], indirect=True) + @pytest.mark.parametrize('causal', [True, False], indirect=True) @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True) + @pytest.mark.parametrize('with_attention_mask', [True]) def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, - kv_start_loc, kv_seqlens, head_dim_v, conti_gt): + kv_start_loc, kv_seqlens, head_dim_v, causal, + conti_gt, rand_mask, with_attention_mask): from lmdeploy.pytorch.kernels.cuda.flashattention import \ flash_attention_fwd max_seq_len = q_seqlens.max().item() conti_k, conti_v = conti_kv out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v) + rand_mask = rand_mask if with_attention_mask else rand_mask flash_attention_fwd(conti_q, conti_k, conti_v, out, + attention_mask=rand_mask, q_start_loc=q_start_loc, q_seqlens=q_seqlens, kv_start_loc=kv_start_loc, kv_seqlens=kv_seqlens, - max_seqlen=max_seq_len) + max_seqlen=max_seq_len, + causal=causal) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) @pytest.fixture From 37afbb638671afb095cc24531da632f91bda9cff Mon Sep 17 
00:00:00 2001 From: AllentDan Date: Thu, 5 Dec 2024 11:35:43 +0800 Subject: [PATCH 12/21] support bonus token id --- lmdeploy/pytorch/engine/engine.py | 24 +++++---- lmdeploy/pytorch/models/medusa.py | 87 ++++++++++++++++++++++++++----- 2 files changed, 88 insertions(+), 23 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 64381b8748..5bcff97442 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -810,10 +810,12 @@ def __update_inputs(next_token_ids): num_ignore_eos = num_ignore_eos - 1 if 'spec_logits' in output: spec_logits = output['spec_logits'] - cart_candidates, tree_candidates, medusa_attn_mask, medusa_position_ids, retrieve_indices = self.model_agent.generate_candidates( - spec_logits, next_token_ids) + (cart_candidates, tree_candidates, medusa_attn_mask, + medusa_position_ids, + retrieve_indices) = self.model_agent.generate_candidates( + spec_logits, next_token_ids) bs, _, tree_decode_len = tree_candidates.shape - spec_inputs = copy.deepcopy(inputs) + spec_inputs = inputs spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0) spec_inputs.history_lengths += spec_inputs.seq_length spec_inputs.seq_length = torch.ones_like( @@ -826,22 +828,22 @@ def __update_inputs(next_token_ids): swap_out_map=swap_out_map, retrieve_indices=retrieve_indices) # NOTE currently only greedy sampling supported + # besides, we used the bonus token id predicted during + # tree decoding while original Medusa did not proposal_len = cart_candidates.shape[-1] greedy_token_ids = logits.argmax(-1) posterior_mask = cart_candidates[..., 1:] == greedy_token_ids[ ..., :-1] accept_len, best_idx = torch.cumprod(posterior_mask, dim=-1).sum(-1).max(-1) - # accept_len = torch.where(accept_len==proposal_len-1, proposal_len, accept_len) - next_token_ids = cart_candidates[torch.arange(bs), best_idx] - # bonus_token_ids = greedy_token_ids[torch.arange(bs),best_idx,-1:] - # next_token_ids = torch.cat([best_candidates, bonus_token_ids], -1) + greedy_token_ids = greedy_token_ids[torch.arange(bs), best_idx] + next_token_ids = torch.cat( + [next_token_ids[:, None], greedy_token_ids], -1) mask_idx = torch.arange( - proposal_len, + proposal_len + 1, device=next_token_ids.device).expand_as(next_token_ids) - next_token_ids[mask_idx > accept_len[:, None]] = -1 - # next_token_ids = next_token_ids[...,:-1] # to be removed - num_appendable_ids = num_appendable_ids - accept_len - 1 + next_token_ids[mask_idx > (accept_len[:, None] + 1)] = -1 + num_appendable_ids = num_appendable_ids - accept_len - 2 # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py index e7ed5c8aae..bc9d086dc9 100644 --- a/lmdeploy/pytorch/models/medusa.py +++ b/lmdeploy/pytorch/models/medusa.py @@ -21,20 +21,71 @@ (0, 1, 2), (8, ), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), (9, ), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), - (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)] # noqa + (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)] + +vicuna_13b_stage2 = [(0, ), (0, 0), (1, ), (0, 0, 0), (0, 1), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 1, 0), (3, ), (0, 3), (2, 0), + (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4, ), + (0, 0, 3), (0, 5), (0, 2, 0), (5, ), (3, 0), (0, 1, 1), + (0, 6), (0, 0, 4), (0, 0, 0, 1), + (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0), + (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9), + 
(6, ), (7, ), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2), + (8, ), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7), + (1, 1, 0), (1, 3), (0, 0, 2, 0), (9, ), (0, 0, 8), + (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2), + (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)] + +vicuna_33b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 3), (3, ), + (0, 1, 0), (2, 0), (0, 4), (4, ), (0, 0, 2), (1, 1), + (1, 0, 0), (0, 5), (5, ), (0, 0, 0, 0), (0, 0, 3), (3, 0), + (0, 2, 0), (0, 6), (0, 1, 1), (6, ), (0, 0, 4), (0, 7), + (7, ), (1, 2), (4, 0), (8, ), (0, 3, 0), (0, 0, 5), + (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1), + (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3), + (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9, ), + (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0), + (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1), + (1, 0, 2), (2, 2)] + +zephyr_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (2, ), + (0, 0, 0), (1, 0), (0, 2), (3, ), (0, 3), (4, ), (2, 0), + (0, 0, 1), (0, 4), (5, ), (0, 5), (0, 1, 0), (1, 1), (6, ), + (0, 0, 2), (3, 0), (0, 6), (7, ), (0, 7), (0, 8), (0, 0, 3), + (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8, ), (9, ), + (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0), + (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1), + (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7), + (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5), + (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1), + (2, 3)] +mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], + [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, + 6], [6], + [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], + [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], + [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], + [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], + [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], + [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], + [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], + [0, 7, 0]] + TOPK = 10 def pad_path(path, length, pad_value=-2): """Pad the given path list with a specific value up to a specified length. - Parameters: - - path (list): The original list that needs padding. - - length (int): The desired length of the padded list. - - pad_value (optional, default=-2): The value to use for padding. + Args: + path (list): The original list that needs padding. + length (int): The desired length of the padded list. + pad_value (optional, default=-2): The value to use for padding. Returns: - - list: A new list based on the original path but padded to the desired length. + list: A new list based on the original path but padded to the desired + length. 
Example: >>> pad_path([1,2,3], 5) @@ -127,6 +178,14 @@ def __init__(self, self.medusa_choices = None if 'vicuna-7b' in config.base_model_name_or_path: self.medusa_choices = vicuna_7b_stage2 + elif 'vicuna-13b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_13b_stage2 + elif 'vicuna-33b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_33b_stage2 + elif 'zephyr' in config.base_model_name_or_path: + self.medusa_choices = zephyr_stage2 + else: + self.medusa_choices = mc_sim_7b_63 self.generate_medusa_buffers(device=device) def generate_medusa_buffers(self, device: torch.dtype = None): @@ -155,7 +214,8 @@ def generate_medusa_buffers(self, device: torch.dtype = None): key=lambda x: (len(x), x)) medusa_len = len(sorted_medusa_choices) + 1 - # Initialize depth_counts to keep track of how many choices have a particular depth + # Initialize depth_counts to keep track of how many choices have a + # particular depth depth_counts = [] prev_depth = 0 for path in sorted_medusa_choices: @@ -248,20 +308,22 @@ def generate_candidates(self, medusa_logits: torch.Tensor, 1. Cartesian candidates derived from the combined original and Medusa logits. 2. Tree candidates mapped from the Cartesian candidates using tree indices. """ # noqa - # Greedy decoding: Select the most probable candidate from the original logits. - # here we only implement greedy decoding + # Greedy decoding: Select the most probable candidate from the original + # logits. here we only implement greedy decoding bs = medusa_logits.shape[0] candidates_logit = base_token_id.unsqueeze(-1) # Extract the TOPK candidates from the medusa logits. candidates_medusa_logits = torch.topk(medusa_logits, TOPK, dim=-1).indices - # Combine the selected candidate from the original logits with the topk medusa logits. + # Combine the selected candidate from the original logits with the + # topk medusa logits. candidates = torch.cat( [candidates_logit, candidates_medusa_logits.view(bs, -1)], dim=-1) - # Map the combined candidates to the tree indices to get tree candidates. + # Map the combined candidates to the tree indices to get tree + # candidates. tree_candidates = candidates[:, self.tree_indices] # Extend the tree candidates by appending a zero. @@ -278,7 +340,8 @@ def generate_candidates(self, medusa_logits: torch.Tensor, # Unsqueeze the tree candidates for dimension consistency. 
tree_candidates = tree_candidates.unsqueeze( 1) # bs, 1, len(self.medusa_choices) - return cart_candidates, tree_candidates, self.medusa_attn_mask, self.medusa_position_ids, self.retrieve_indices + return (cart_candidates, tree_candidates, self.medusa_attn_mask, + self.medusa_position_ids, self.retrieve_indices) def support_cuda_graph( self, From dcc6e854b9b3e0140235722c2e962366572b9af1 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 5 Dec 2024 11:47:18 +0800 Subject: [PATCH 13/21] tp --- lmdeploy/pytorch/engine/model_agent.py | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index eff8d190a7..d550b6a030 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -875,6 +875,7 @@ async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, hidden_states = hidden_states.reshape( [-1, num_speculative_tokens + 1, hidden_states.shape[-1]]) logits = self.get_logits(hidden_states) + self.stream.synchronize() return logits def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, @@ -908,6 +909,36 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, self.stream.synchronize) return output + async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, + retrieve_indices: torch.Tensor): + bs = inputs.history_lengths.shape[0] + inputs.medusa_position_ids = inputs.medusa_position_ids.repeat( + inputs.history_lengths.shape[0], 1) + inputs.medusa_position_ids = inputs.medusa_position_ids.to( + inputs.history_lengths.device) + inputs.history_lengths[:, None] + with get_dist_manager().context(self._dist_ctx): + self.mp_bar.wait() + rank = 0 + _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], + self.stream) + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape(bs, -1, + hidden_states.shape[-1]) + logits = self.get_logits(hidden_states)[:, retrieve_indices] + self.stream.synchronize() + return logits + def generate_candidates(self, draft_logits: torch.Tensor, base_token_id: torch.Tensor): return self.speculative_model.generate_candidates( From bbc108cb3106205f032220edf1b8f6707b3b7c98 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Thu, 5 Dec 2024 19:46:17 +0800 Subject: [PATCH 14/21] remove tl.constexpr to avoid repeatly compile triton kernel --- lmdeploy/pytorch/kernels/cuda/flashattention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 3cca3f9b8d..51c82e7400 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -444,9 +444,9 @@ def _flash_prefill_fwd_kernel_with_mask( stride_os: tl.constexpr, stride_oh: tl.constexpr, stride_od: tl.constexpr, - stride_amb: tl.constexpr, - stride_amqs: tl.constexpr, - stride_amkvs: tl.constexpr, + stride_amb, + stride_amqs, + stride_amkvs, kv_group_num, head_dim_k, head_dim_v, From 2940dce0cfc07e0671c77ae48837b4f24e11c0da Mon Sep 17 00:00:00 2001 From: AllentDan Date: Fri, 6 Dec 2024 10:26:51 +0800 Subject: [PATCH 15/21] fix with deep copy --- lmdeploy/pytorch/engine/engine.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b66ab61acb..263ef784ee 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -838,7 +838,7 @@ def __update_inputs(next_token_ids): retrieve_indices) = self.model_agent.generate_candidates( spec_logits, next_token_ids) bs, _, tree_decode_len = tree_candidates.shape - spec_inputs = inputs + spec_inputs = copy.deepcopy(inputs) spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0) spec_inputs.history_lengths += spec_inputs.seq_length spec_inputs.seq_length = torch.ones_like( From 055e7c79516145e4733687ac0d32a21f1b01c7ff Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 10 Dec 2024 11:19:16 +0800 Subject: [PATCH 16/21] update kernels --- .../pytorch/kernels/cuda/flashattention.py | 445 ++++-------------- 1 file changed, 82 insertions(+), 363 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 51c82e7400..8651706c6e 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -61,11 +61,13 @@ def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr): @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask, - kv_min_loc, causal_mask: tl.constexpr, - window_size: tl.constexpr, + kv_min_loc, attn_mask_ptr, apply_mask: tl.constexpr, + causal_mask: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr): k_ptrs = tl.advance(k_ptrs, (0, loop_start)) + if apply_mask: + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, loop_start)) v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) if BLOCK_DK1: k1_ptrs = tl.advance(k1_ptrs, (0, loop_start)) @@ -75,6 +77,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, start_n = tl.multiple_of(start_n, BLOCK_N) k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) + if apply_mask: + attn_mask = tl.load(attn_mask_ptr) qk = tl.dot(q, k) if BLOCK_DK1 != 0: @@ -83,6 +87,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, if causal_mask: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) @@ -98,6 +104,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, qk -= m_i_new[:, None] elif window_size > 0: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) @@ -110,98 +118,20 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, qk -= m_i_new[:, None] elif logit_softcapping > 0: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) m_i_new = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_i_new[:, None] else: - qk_scale = sm_scale * tl_log2(math.e) - m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_i_new[:, None] - - # -- compute p, m_i and l_i - p = tl_exp2(qk) - alpha = tl_exp2(m_i - m_i_new) - l_i = alpha * l_i + tl.sum(p, 1) - # -- update output accumulator -- - # scale acc - acc = acc * alpha[:, None] - - # update acc - v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, )) - p = p.to(v.dtype) - acc += 
tl.dot(p, v) - # update m_i and l_i - m_i = m_i_new - - k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N)) - v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0)) - if BLOCK_DK1: - k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N)) - - return acc, l_i, m_i - - -@triton.jit -def _prefill_fwd_inner_with_mask( - acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, - qk_scale, history_mask, kv_min_loc, attn_mask_ptr, - causal_mask: tl.constexpr, window_size: tl.constexpr, - logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_DK1: tl.constexpr): - k_ptrs = tl.advance(k_ptrs, (0, loop_start)) - attn_mask_ptr = tl.advance(attn_mask_ptr, (0, loop_start)) - v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) - if BLOCK_DK1: - k1_ptrs = tl.advance(k1_ptrs, (0, loop_start)) - - offs_n = tl.arange(0, BLOCK_N) - for start_n in range(loop_start, loop_end, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) - attn_mask = tl.load(attn_mask_ptr) - qk = tl.dot(q, k) - - if BLOCK_DK1 != 0: - k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, )) - qk += tl.dot(q1, k1) - - if causal_mask: - qk *= qk_scale - qk = softcapping(qk, logit_softcapping) - qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) - if window_size > 0: - qk_mask = qk_mask and ( - (start_n + offs_n[None, :]) >= kv_min_loc[:, None]) - qk = tl.where( - qk_mask, - qk, - float(-1e30), - ) - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_i_new[:, None] - elif window_size > 0: - qk *= qk_scale - qk = softcapping(qk, logit_softcapping) - qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) - qk = tl.where( - qk_mask, - qk, - float(-1e30), - ) - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_i_new[:, None] - elif logit_softcapping > 0: - qk *= qk_scale - qk = softcapping(qk, logit_softcapping) + qk *= sm_scale + if apply_mask: + qk = qk + attn_mask + qk *= tl_log2(math.e) m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_i_new[:, None] - else: - m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_i_new[:, None] + qk = qk - m_i_new[:, None] - qk = qk + attn_mask # -- compute p, m_i and l_i p = tl_exp2(qk) alpha = tl_exp2(m_i - m_i_new) @@ -218,7 +148,6 @@ def _prefill_fwd_inner_with_mask( m_i = m_i_new k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N)) - attn_mask_ptr = tl.advance(attn_mask_ptr, (0, BLOCK_N)) v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0)) if BLOCK_DK1: k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N)) @@ -252,6 +181,8 @@ def _flash_prefill_fwd_kernel( kv_start_loc_ptr, kv_seqlens_ptr, sm_scale, + attention_mask, + apply_mask: tl.constexpr, stride_qs: tl.constexpr, stride_qh: tl.constexpr, stride_qd: tl.constexpr, @@ -264,6 +195,9 @@ def _flash_prefill_fwd_kernel( stride_os: tl.constexpr, stride_oh: tl.constexpr, stride_od: tl.constexpr, + stride_amb, + stride_amqs, + stride_amkvs, kv_group_num, head_dim_k, head_dim_v, @@ -329,6 +263,18 @@ def _flash_prefill_fwd_kernel( block_shape=(BLOCK_N, BLOCK_DV), order=(1, 0), ) + if apply_mask: + attn_mask_ptrs = tl.make_block_ptr( + base=attention_mask + batch_id * stride_amb + + start_m * BLOCK_M * stride_amqs, + shape=(q_seqlen, kv_seqlen), + strides=(stride_amqs, stride_amkvs), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(0, 1), + ) + else: + attn_mask_ptrs = tl.full([BLOCK_M, BLOCK_N], 0, dtype=tl.int32) if BLOCK_DK1 != 0: offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1) @@ -376,6 +322,8 @@ def _flash_prefill_fwd_kernel( sm_scale, history_mask, kv_min_loc, + attn_mask_ptrs, + 
apply_mask=apply_mask, causal_mask=False, window_size=window_size, logit_softcapping=logit_softcapping, @@ -400,6 +348,8 @@ def _flash_prefill_fwd_kernel( sm_scale, history_mask, kv_min_loc, + attn_mask_ptrs, + apply_mask=apply_mask, causal_mask=True, window_size=window_size, logit_softcapping=logit_softcapping, @@ -420,203 +370,6 @@ def _flash_prefill_fwd_kernel( mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :]) -@triton.jit -def _flash_prefill_fwd_kernel_with_mask( - q_ptr, - k_ptr, - v_ptr, - o_ptr, - q_start_loc_ptr, - q_seqlens_ptr, - kv_start_loc_ptr, - kv_seqlens_ptr, - sm_scale, - attention_mask, - stride_qs: tl.constexpr, - stride_qh: tl.constexpr, - stride_qd: tl.constexpr, - stride_ks: tl.constexpr, - stride_kh, - stride_kd: tl.constexpr, - stride_vs: tl.constexpr, - stride_vh, - stride_vd: tl.constexpr, - stride_os: tl.constexpr, - stride_oh: tl.constexpr, - stride_od: tl.constexpr, - stride_amb, - stride_amqs, - stride_amkvs, - kv_group_num, - head_dim_k, - head_dim_v, - causal: tl.constexpr, - window_size: tl.constexpr, - logit_softcapping: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DK: tl.constexpr, - BLOCK_DK1: tl.constexpr, - BLOCK_DV: tl.constexpr, -): - """flash attention kernel.""" - start_m = tl.program_id(0) - head_id = tl.program_id(1) - batch_id = tl.program_id(2) - - q_seqlen = tl.load(q_seqlens_ptr + batch_id) - - if BLOCK_M * start_m >= q_seqlen: - return - - kv_head_id = head_id // kv_group_num - q_seqlen = q_seqlen.to(tl.int32) - kv_seqlen = tl.load(kv_seqlens_ptr + batch_id).to(tl.int32) - q_start_loc = tl.load(q_start_loc_ptr + batch_id).to(tl.int32) - kv_start_loc = tl.load(kv_start_loc_ptr + batch_id).to(tl.int32) - - history_len = kv_seqlen - q_seqlen - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - loop_start = 0 - kv_min_loc = tl.zeros([BLOCK_M], dtype=tl.int32) - if window_size > 0: - start_block_id = tl.maximum( - history_len + start_m * BLOCK_M - window_size, 0) // BLOCK_N - kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0) - loop_start = start_block_id * BLOCK_N - - offs_dk = tl.arange(0, BLOCK_DK) - mask_dk = offs_dk < head_dim_k - offs_dk = tl.multiple_of(tl.max_contiguous(offs_dk % head_dim_k, BLOCK_DK), - BLOCK_DK) - off_q = ((q_start_loc + offs_m[:, None]) * stride_qs + - head_id * stride_qh + offs_dk[None, :] * stride_qd) - q_ptrs = q_ptr + off_q - q = tl.load(q_ptrs, mask=(offs_m[:, None] < q_seqlen and mask_dk[None, :])) - - k_ptrs = tl.make_block_ptr( - base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh, - shape=(head_dim_k, kv_seqlen), - strides=(stride_kd, stride_ks), - offsets=(0, 0), - block_shape=(BLOCK_DK, BLOCK_N), - order=(0, 1), - ) - v_ptrs = tl.make_block_ptr( - base=v_ptr + kv_start_loc * stride_vs + kv_head_id * stride_vh, - shape=(kv_seqlen, head_dim_v), - strides=(stride_vs, stride_vd), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DV), - order=(1, 0), - ) - attn_mask_ptrs = tl.make_block_ptr( - base=attention_mask + batch_id * stride_amb + - start_m * BLOCK_M * stride_amqs, - shape=(q_seqlen, kv_seqlen), - strides=(stride_amqs, stride_amkvs), - offsets=(0, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(0, 1), - ) - - if BLOCK_DK1 != 0: - offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1) - mask_dk1 = offs_dk1 < head_dim_k - offs_dk1 = tl.multiple_of( - tl.max_contiguous(offs_dk1 % head_dim_k, BLOCK_DK1), BLOCK_DK1) - offs_q1 = ((q_start_loc + offs_m[:, None]) * stride_qs + - head_id * stride_qh + offs_dk1[None, :] * stride_qd) - q1_ptrs = q_ptr + offs_q1 - q1 = 
tl.load(q1_ptrs, - mask=(offs_m[:, None] < q_seqlen and mask_dk1[None, :])) - k1_ptrs = tl.make_block_ptr( - base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh, - shape=(head_dim_k, kv_seqlen), - strides=(stride_kd, stride_ks), - offsets=(BLOCK_DK, 0), - block_shape=(BLOCK_DK1, BLOCK_N), - order=(0, 1), - ) - else: - q1 = q - k1_ptrs = k_ptrs - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) - - qk_scale = sm_scale * tl_log2(math.e) - if causal: - history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) - loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N - else: - history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32) - loop_end = kv_seqlen // BLOCK_N * BLOCK_N - - acc, l_i, m_i = _prefill_fwd_inner_with_mask( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - q1, - k1_ptrs, - loop_start, - loop_end, - qk_scale, - history_mask, - kv_min_loc, - attn_mask_ptrs, - causal_mask=False, - window_size=window_size, - logit_softcapping=logit_softcapping, - BLOCK_N=BLOCK_N, - BLOCK_DK1=BLOCK_DK1) - - loop_start = loop_end - if causal: - loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) - else: - loop_end = kv_seqlen - acc, l_i, m_i = _prefill_fwd_inner_with_mask( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - q1, - k1_ptrs, - loop_start, - loop_end, - qk_scale, - history_mask, - kv_min_loc, - attn_mask_ptrs, - causal_mask=True, - window_size=window_size, - logit_softcapping=logit_softcapping, - BLOCK_N=BLOCK_N, - BLOCK_DK1=BLOCK_DK1) - # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - - # initialize pointers to output - offs_dv = tl.arange(0, BLOCK_DV) - mask_dv = offs_dv < head_dim_v - off_o = ((q_start_loc + offs_m[:, None]) * stride_os + - head_id * stride_oh + offs_dv[None, :] * stride_od) - out_ptrs = o_ptr + off_o - tl.store(out_ptrs, - acc, - mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :]) - - _nv_cap = None @@ -694,83 +447,49 @@ def grid(args): num_stages = 3 else: num_stages = 4 + apply_mask = True if attention_mask is None: - _flash_prefill_fwd_kernel[grid]( - q_states, - k_states, - v_states, - o_states, - q_start_loc, - q_seqlens, - kv_start_loc, - kv_seqlens, - sm_scale=sm_scale, - stride_qs=q_states.stride(0), - stride_qh=q_states.stride(1), - stride_qd=q_states.stride(2), - stride_ks=k_states.stride(s_dim), - stride_kh=k_states.stride(h_dim), - stride_kd=k_states.stride(d_dim), - stride_vs=v_states.stride(s_dim), - stride_vh=v_states.stride(h_dim), - stride_vd=v_states.stride(d_dim), - stride_os=o_states.stride(0), - stride_oh=o_states.stride(1), - stride_od=o_states.stride(2), - kv_group_num=kv_group_num, - head_dim_k=head_dim_k, - head_dim_v=head_dim_v, - causal=causal, - window_size=window_size, - logit_softcapping=logit_softcapping, - BLOCK_DK=BLOCK_DK, - BLOCK_DK1=BLOCK_DK1, - BLOCK_DV=BLOCK_DV, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - else: - _flash_prefill_fwd_kernel_with_mask[grid]( - q_states, - k_states, - v_states, - o_states, - q_start_loc, - q_seqlens, - kv_start_loc, - kv_seqlens, - sm_scale=sm_scale, - attention_mask=attention_mask, - stride_qs=q_states.stride(0), - stride_qh=q_states.stride(1), - stride_qd=q_states.stride(2), - stride_ks=k_states.stride(s_dim), - stride_kh=k_states.stride(h_dim), - stride_kd=k_states.stride(d_dim), - stride_vs=v_states.stride(s_dim), - stride_vh=v_states.stride(h_dim), - 
stride_vd=v_states.stride(d_dim), - stride_os=o_states.stride(0), - stride_oh=o_states.stride(1), - stride_od=o_states.stride(2), - stride_amb=attention_mask.stride(0), - stride_amqs=attention_mask.stride(1), - stride_amkvs=attention_mask.stride(2), - kv_group_num=kv_group_num, - head_dim_k=head_dim_k, - head_dim_v=head_dim_v, - causal=causal, - window_size=window_size, - logit_softcapping=logit_softcapping, - BLOCK_DK=BLOCK_DK, - BLOCK_DK1=BLOCK_DK1, - BLOCK_DV=BLOCK_DV, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) - + apply_mask = False + attention_mask = q_states.new_empty((1, 1, 1)) + _flash_prefill_fwd_kernel[grid]( + q_states, + k_states, + v_states, + o_states, + q_start_loc, + q_seqlens, + kv_start_loc, + kv_seqlens, + sm_scale=sm_scale, + attention_mask=attention_mask, + apply_mask=apply_mask, + stride_qs=q_states.stride(0), + stride_qh=q_states.stride(1), + stride_qd=q_states.stride(2), + stride_ks=k_states.stride(s_dim), + stride_kh=k_states.stride(h_dim), + stride_kd=k_states.stride(d_dim), + stride_vs=v_states.stride(s_dim), + stride_vh=v_states.stride(h_dim), + stride_vd=v_states.stride(d_dim), + stride_os=o_states.stride(0), + stride_oh=o_states.stride(1), + stride_od=o_states.stride(2), + stride_amb=attention_mask.stride(0), + stride_amqs=attention_mask.stride(1), + stride_amkvs=attention_mask.stride(2), + kv_group_num=kv_group_num, + head_dim_k=head_dim_k, + head_dim_v=head_dim_v, + causal=causal, + window_size=window_size, + logit_softcapping=logit_softcapping, + BLOCK_DK=BLOCK_DK, + BLOCK_DK1=BLOCK_DK1, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) return o_states From a6805d8697b8d09162137766855a308177bf249b Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 10 Dec 2024 15:10:05 +0800 Subject: [PATCH 17/21] fix UT --- lmdeploy/model.py | 2 +- .../pytorch/kernels/cuda/flashattention.py | 2 ++ requirements/runtime.txt | 2 +- tests/pytorch/kernel/test_flash_attention.py | 21 +++++++++++++------ tests/pytorch/paging/test_scheduler.py | 18 +++++++++------- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/lmdeploy/model.py b/lmdeploy/model.py index b4c5eaab7d..a4355ea131 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -310,7 +310,7 @@ def __init__( eosys=' ', user='USER: ', eoh=' ', - assistant='ASSISTANT:', + assistant='ASSISTANT: ', eoa='', stop_words=[''], **kwargs): diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 8651706c6e..645868ea87 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -149,6 +149,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N)) v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0)) + if apply_mask: + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, BLOCK_N)) if BLOCK_DK1: k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N)) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 400c492b09..ec4957608c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -18,5 +18,5 @@ tiktoken torch<=2.4.0,>=2.0.0 torchvision<=0.19.0,>=0.15.0 transformers -triton>=2.2.0,<=3.0.0; sys_platform == "linux" +triton==3.0.0; sys_platform == "linux" uvicorn diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index ae4bf798f4..5edafc95a1 100644 --- 
a/tests/pytorch/kernel/test_flash_attention.py +++ b/tests/pytorch/kernel/test_flash_attention.py @@ -49,7 +49,9 @@ def _naive_attention(batched_q, batched_kv, bias, rand_mask): v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2) qk = torch.matmul(q, k) / math.sqrt(head_dim) - attn_weight = qk + bias[:, None] + rand_mask[:, None] + attn_weight = qk + bias[:, None] + if rand_mask is not None: + attn_weight += rand_mask[:, None] attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) attn_weight = attn_weight.to(q.dtype) attn_output = torch.matmul(attn_weight, v) @@ -110,6 +112,10 @@ def num_heads_k(self, request): def causal(self, request): yield request.param + @pytest.fixture + def with_attention_mask(self, request): + yield request.param + @pytest.fixture def q_seqlens(self, request): yield torch.tensor(request.param, device='cuda') @@ -182,9 +188,12 @@ def mask(self, q_seqlens, history_lens, causal): yield _make_bias(q_seqlens, history_lens, neg_val, causal) @pytest.fixture - def rand_mask(self, mask): + def rand_mask(self, mask, with_attention_mask): neg_val = -1e30 - yield torch.rand_like(mask).round() * neg_val + if with_attention_mask: + yield torch.rand_like(mask).round() * neg_val + else: + yield None @pytest.fixture def gt(self, batched_q, batched_kv, mask, rand_mask): @@ -202,17 +211,17 @@ def conti_gt(self, gt, q_seqlens): @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True) - @pytest.mark.parametrize('with_attention_mask', [True]) + @pytest.mark.parametrize('with_attention_mask', [True, False], + indirect=True) def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv_seqlens, head_dim_v, causal, - conti_gt, rand_mask, with_attention_mask): + conti_gt, rand_mask): from lmdeploy.pytorch.kernels.cuda.flashattention import \ flash_attention_fwd max_seq_len = q_seqlens.max().item() conti_k, conti_v = conti_kv out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v) - rand_mask = rand_mask if with_attention_mask else rand_mask flash_attention_fwd(conti_q, conti_k, conti_v, diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index 2e7c1e1a0f..c988e9e5bf 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -14,11 +14,11 @@ def block_size(self): @pytest.fixture def num_cpu_blocks(self): - yield 4 + yield 12 @pytest.fixture def num_gpu_blocks(self): - yield 4 + yield 12 @pytest.fixture def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): @@ -60,9 +60,9 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): assert seq.status == MessageStatus.RUNNING assert seq in output.running assert len(block_tables) == 1 - assert len(block_tables[0]) == num_blocks - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - num_blocks + assert len(block_tables[0]) == num_blocks + 4 # medusa needs 4 more + assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - ( + num_blocks + 4) assert scheduler.has_unfinished() @@ -99,7 +99,8 @@ def test_update(self, scheduler, block_size, num_gpu_blocks): assert session_id1 in scheduler.sessions assert seq1 not in scheduler.running assert seq1 not in scheduler.hanging - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2 + assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - (2 + + 4) # stop session scheduler.stop_session(session_id2) @@ -136,7 +137,8 @@ def test_evict(self, 
scheduler, block_size, num_gpu_blocks, assert seq1.status == MessageStatus.RUNNING assert seq2.status == MessageStatus.RUNNING assert seq3.status == MessageStatus.WAITING - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3 + assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - ( + 3 + 4 * 2) # test: waiting alloc seq2.status = MessageStatus.STOPPED @@ -180,4 +182,4 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, # seq3: 3 nan assert seq1.status == MessageStatus.WAITING assert seq2.status == MessageStatus.RUNNING - assert block_manager.get_num_free_gpu_blocks() == 0 + assert block_manager.get_num_free_gpu_blocks() == 4 From fbcfd3a1b7fce163c48c4995ee57b5646f6ad802 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Tue, 31 Dec 2024 17:19:32 +0800 Subject: [PATCH 18/21] prefill cuda graph for medusa --- .../pytorch/backends/cuda/graph_runner.py | 3 +- lmdeploy/pytorch/backends/cuda/op_backend.py | 6 ++-- .../pytorch/kernels/cuda/flatten_kv_cache.py | 5 +-- lmdeploy/pytorch/models/utils/cudagraph.py | 31 +++++++++++++++++-- 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 33b1c85280..7b27ccaef6 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -139,7 +139,8 @@ def get_graph_key(self, input_ids: torch.Tensor, is_decoding = context.is_decoding num_tokens = input_ids.numel() new_num_tokens = next_power_of_2(num_tokens) - return (new_num_tokens, is_decoding) + mask_input = attn_metadata.medusa_attn_mask is not None + return (new_num_tokens, is_decoding, mask_input) def __call__(self, **kwargs): """call.""" diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 084a43b8d9..bfd77a250f 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -124,8 +124,10 @@ def update_step_context(cls, step_context): attention_mask = torch.zeros((bs, max_q_seqlen, max_kv_seqlen), dtype=dtype, device=device) - attention_mask[:, -medusa_len:, -medusa_len:] = ( - 1 - step_context.medusa_attn_mask) * (-1e30) + masked_value = (1 - step_context.medusa_attn_mask) * (-1e30) + for i in range(bs): + attention_mask[i, :, kv_seqlens[i] - + medusa_len:kv_seqlens[i]] = masked_value # noqa step_context.medusa_attn_mask = attention_mask attn_metadata = attn_meta_cls( diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index 90b135743e..e2b2091b84 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -48,7 +48,8 @@ def _flatten_kv_cache( seqlen = tl.load(seqlens_ptr + batch_id) start_loc = tl.load(start_loc_ptr + batch_id) # fill last block to prevent attention nan - if batch_id == num_batches - 1: + # seqlen>0 to filter cuda graph padding + if batch_id == num_batches - 1 and seqlen > 0: seqlen = OUT_SIZE - start_loc if page_id * BLOCK_BS >= seqlen: return @@ -140,7 +141,7 @@ def _flatten_kv_cache_quant( seqlen = tl.load(seqlens_ptr + batch_id) start_loc = tl.load(start_loc_ptr + batch_id) - if batch_id == num_batches - 1: + if batch_id == num_batches - 1 and seqlen > 0: seqlen = OUT_SIZE - start_loc if page_id * BLOCK_BS >= seqlen: return diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 74d090a9a3..3a7d2ec3d2 
100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -8,6 +8,8 @@ from lmdeploy.pytorch.model_inputs import StepContext BuffType = Dict[str, Tensor] +CUDA_GRAPH_PREFILL_SEQLEN = 256 +CUDA_GRAPH_PREFILL_KVLEN = 4096 def next_power_of_2(n: int): @@ -48,7 +50,11 @@ def support_cuda_graph( **kwargs, ): """return True is model support cudagraph.""" - return attn_metadata.is_decoding + # return attn_metadata.is_decoding + seq_lens = input_ids.size(1) + if attn_metadata.kv_flatten_size is None: + return seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN + return seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN and attn_metadata.kv_flatten_size <= CUDA_GRAPH_PREFILL_KVLEN # noqa def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> BuffType: @@ -78,6 +84,9 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2] + input_buffers['kv_start_loc'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device) @@ -100,6 +109,7 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, q_start_loc: Tensor = attn_metadata.q_start_loc q_seqlens: Tensor = attn_metadata.q_seqlens kv_seqlens: Tensor = attn_metadata.kv_seqlens + kv_start_loc: Tensor = attn_metadata.kv_start_loc input_buffers: BuffType = graph_meta.input_buffers batch_size, num_blocks = block_offsets.size() @@ -121,9 +131,26 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_buffers['inputs_embeds'] = inputs_embeds.new_zeros( 1, max_num_tokens, emb_size) input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds + new_batch_size = next_power_of_2(batch_size) + if attn_metadata.medusa_attn_mask is not None: + medusa_attn_mask = attn_metadata.medusa_attn_mask + if 'medusa_attn_mask' not in input_buffers: + input_buffers['medusa_attn_mask'] = medusa_attn_mask.new_zeros( + graph_meta.max_batchs, 64, CUDA_GRAPH_PREFILL_KVLEN) + input_buffers[ + 'medusa_attn_mask'][:batch_size, :, :medusa_attn_mask. 
+ shape[-1]] = medusa_attn_mask[:batch_size] + attn_metadata.medusa_attn_mask = input_buffers[ + 'medusa_attn_mask'][:new_batch_size] + if kv_start_loc is not None: + input_buffers['kv_start_loc'][:batch_size] = kv_start_loc + attn_metadata.kv_start_loc = input_buffers[ + 'kv_start_loc'][:new_batch_size] # create inputs - new_batch_size = next_power_of_2(batch_size) + if attn_metadata.kv_flatten_size is not None: + attn_metadata.kv_flatten_size = max(attn_metadata.kv_flatten_size, + CUDA_GRAPH_PREFILL_KVLEN) attn_metadata.block_offsets = input_buffers[ 'block_offsets'][:new_batch_size] attn_metadata.q_start_loc = input_buffers[ From 20f426c9899512f9945ac8fdba89a299ed5b2b2e Mon Sep 17 00:00:00 2001 From: AllentDan Date: Mon, 6 Jan 2025 10:21:17 +0800 Subject: [PATCH 19/21] fix cudagraph batch error --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 5 ++++- lmdeploy/pytorch/models/utils/cudagraph.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 7b27ccaef6..638177e32e 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -140,7 +140,10 @@ def get_graph_key(self, input_ids: torch.Tensor, num_tokens = input_ids.numel() new_num_tokens = next_power_of_2(num_tokens) mask_input = attn_metadata.medusa_attn_mask is not None - return (new_num_tokens, is_decoding, mask_input) + seq_num = 0 + if is_decoding is False: + seq_num = next_power_of_2(attn_metadata.q_seqlens.shape[0]) + return (new_num_tokens, is_decoding, mask_input, seq_num) def __call__(self, **kwargs): """call.""" diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 3a7d2ec3d2..6973276d2b 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -54,7 +54,8 @@ def support_cuda_graph( seq_lens = input_ids.size(1) if attn_metadata.kv_flatten_size is None: return seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN - return seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN and attn_metadata.kv_flatten_size <= CUDA_GRAPH_PREFILL_KVLEN # noqa + enabled = seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN and attn_metadata.kv_flatten_size <= CUDA_GRAPH_PREFILL_KVLEN # noqa + return enabled def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> BuffType: From 28f3d75b61a934a28ce50419be428aaf0b3cf0c3 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Mon, 6 Jan 2025 16:23:47 +0800 Subject: [PATCH 20/21] boundary check --- lmdeploy/pytorch/kernels/cuda/flashattention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 645868ea87..cec5c44892 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -78,7 +78,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) if apply_mask: - attn_mask = tl.load(attn_mask_ptr) + attn_mask = tl.load(attn_mask_ptr, + boundary_check=(0, 1), + padding_option='zero') qk = tl.dot(q, k) if BLOCK_DK1 != 0: From 3fe5fe4403911d1cce9cbfb097dbac77f7385396 Mon Sep 17 00:00:00 2001 From: AllentDan Date: Mon, 13 Jan 2025 11:01:59 +0800 Subject: [PATCH 21/21] fix stop --- lmdeploy/pytorch/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py 
b/lmdeploy/pytorch/engine/engine.py index 19c8017b7a..908ed93f99 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -613,12 +613,12 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, continue eos_token_id = self.model_config.eos_token_id update_token, eos_stop = self.extract_tokens(token, eos_token_id) - stop = stop or eos_stop if stop: update_token = _EMPTY_TOKEN else: msg.num_new_tokens += len(update_token) msg.update_token_ids(update_token, model_meta=model_meta) + stop = stop or eos_stop if stop: msg.status = MessageStatus.STOPPED
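
The final hunk above moves `stop = stop or eos_stop` below `msg.update_token_ids(...)`, so the draft tokens accepted by speculative decoding are recorded before the EOS-driven stop is folded into the stop flag; with the old ordering, an EOS inside the accepted tokens flipped `stop` early and the whole update was replaced by `_EMPTY_TOKEN`. The Python below is a minimal sketch of that ordering only, assuming `extract_tokens` truncates at the first EOS; `Sequence`, `extract_tokens`, and `update_running_sketch` are simplified stand-ins, not the lmdeploy implementations.

from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Sequence:
    """Simplified stand-in for a running sequence message."""
    token_ids: List[int] = field(default_factory=list)
    stopped: bool = False


def extract_tokens(tokens: List[int], eos_token_id: int) -> Tuple[List[int], bool]:
    """Keep tokens up to and including the first EOS; report whether EOS was hit."""
    if eos_token_id in tokens:
        idx = tokens.index(eos_token_id)
        return tokens[:idx + 1], True
    return tokens, False


def update_running_sketch(seq: Sequence, new_tokens: List[int],
                          eos_token_id: int, stop: bool = False) -> Sequence:
    """Mirror the fixed ordering: record accepted tokens first, then fold in EOS."""
    update, eos_stop = extract_tokens(new_tokens, eos_token_id)
    if stop:                     # an external stop discards the update entirely
        update = []
    else:
        seq.token_ids.extend(update)
    stop = stop or eos_stop      # folded in only after the tokens are recorded
    if stop:
        seq.stopped = True
    return seq


seq = update_running_sketch(Sequence(), [5, 7, 2, 9], eos_token_id=2)
assert seq.token_ids == [5, 7, 2] and seq.stopped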
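
Earlier in this series, `_flash_prefill_fwd_kernel_with_mask` is folded into `_flash_prefill_fwd_kernel`: the kernel always receives an `attention_mask` argument plus an `apply_mask` flag, and the host passes a tiny dummy tensor when no mask exists, so one kernel and one launch path cover both cases. The sketch below illustrates only that host-side dispatch pattern, using a naive PyTorch reference attention; `prefill_attention` and `naive_prefill_attention` are illustrative names, and the bias here is a single [q_seq, kv_seq] additive mask rather than the per-batch, stride-addressed mask the Triton kernel consumes.

import torch


def naive_prefill_attention(q, k, v, attention_mask, apply_mask, sm_scale):
    """Reference attention with the same optional-mask contract as the kernel:
    q, k, v are [seq, heads, dim]; attention_mask is an additive bias of shape
    [q_seq, kv_seq] (or an unread dummy tensor when apply_mask is False)."""
    qk = torch.einsum('qhd,khd->hqk', q, k) * sm_scale
    if apply_mask:                       # the dummy mask is never read here
        qk = qk + attention_mask         # broadcast over the head dimension
    attn = torch.softmax(qk, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn, v)


def prefill_attention(q, k, v, attention_mask=None, sm_scale=None):
    """Host-side dispatch: one code path, with a dummy mask when none is given."""
    sm_scale = sm_scale if sm_scale is not None else q.shape[-1] ** -0.5
    apply_mask = attention_mask is not None
    if not apply_mask:
        # keep the argument list fixed so a single kernel services both cases
        attention_mask = q.new_empty((1, 1))
    return naive_prefill_attention(q, k, v, attention_mask, apply_mask, sm_scale)


q, k, v = (torch.randn(8, 4, 64) for _ in range(3))
out_dense = prefill_attention(q, k, v)                                # no mask
out_masked = prefill_attention(q, k, v, attention_mask=torch.zeros(8, 8))
assert torch.allclose(out_dense, out_masked, atol=1e-6)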