diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 42b7c6e4c1..4af1ed9d9f 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -11,6 +11,7 @@ def pipeline(model_path: str, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, + speculative_model: str = None, log_level: str = 'WARNING', max_log_len: int = None, **kwargs): @@ -33,6 +34,7 @@ def pipeline(model_path: str, config instance. Default to None. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. + speculative_model (str): the path of the draft model. log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] max_log_len(int): Max number of prompt characters or prompt tokens @@ -86,6 +88,7 @@ def pipeline(model_path: str, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, + speculative_model=speculative_model, max_log_len=max_log_len, **kwargs) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index a2083c6e64..48ee6b9978 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -43,6 +43,7 @@ def add_parser_convert(): help='The directory path of the model') ArgumentHelper.model_format(parser) ArgumentHelper.tp(parser) + ArgumentHelper.speculative_model(parser) # other args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) @@ -121,6 +122,7 @@ def add_parser_chat(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) + ArgumentHelper.speculative_model(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) # common engine args @@ -270,6 +272,7 @@ def chat(args): quant_policy=args.quant_policy) run_chat(args.model_path, engine_config, + speculative_model=args.speculative_model, chat_template_config=chat_template_config) else: from lmdeploy.turbomind.chat import main as run_chat diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 5cf3453b7e..939d7a2f7b 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -159,6 +159,7 @@ def add_parser_api_server(): pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) + ArgumentHelper.speculative_model(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) @@ -338,6 +339,7 @@ def api_server(args): vision_config = VisionConfig(args.vision_max_batch_size) run_api_server(args.model_path, model_name=args.model_name, + speculative_model=args.speculative_model, backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 4edf23d684..33d5d339cf 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -100,6 +100,15 @@ def model_name(parser): 'by the RESTful API `/v1/models`. 
If it is not specified, ' '`model_path` will be adopted') + @staticmethod + def speculative_model(parser): + """Add argument speculative_model to parser.""" + + return parser.add_argument('--speculative-model', + type=str, + default=None, + help='The path of the spculative model.') + @staticmethod def dtype(parser, default: str = 'auto'): return parser.add_argument( diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index f9227497f2..31546ae0e1 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -21,6 +21,7 @@ class TritonAttentionMetadata(AttentionMetadata): fill_seqlens: torch.Tensor = None quant_policy: Literal[0, 4, 8] = 0 kv_flatten_size: int = None + medusa_attn_mask: torch.Tensor = None def _cdiv(a, b): @@ -106,6 +107,9 @@ def forward( fill_seqlens = attn_metadata.fill_seqlens fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2)) fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens + attention_mask = None + if attn_metadata.medusa_attn_mask is not None: + attention_mask = attn_metadata.medusa_attn_mask # fill kv cache if key is not None and value is not None: @@ -167,6 +171,7 @@ def forward( flatten_k, flatten_v, attn_output, + attention_mask=attention_mask, q_start_loc=q_start_loc, q_seqlens=q_seqlens, kv_start_loc=kv_start_loc, diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 33b1c85280..638177e32e 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -139,7 +139,11 @@ def get_graph_key(self, input_ids: torch.Tensor, is_decoding = context.is_decoding num_tokens = input_ids.numel() new_num_tokens = next_power_of_2(num_tokens) - return (new_num_tokens, is_decoding) + mask_input = attn_metadata.medusa_attn_mask is not None + seq_num = 0 + if is_decoding is False: + seq_num = next_power_of_2(attn_metadata.q_seqlens.shape[0]) + return (new_num_tokens, is_decoding, mask_input, seq_num) def __call__(self, **kwargs): """call.""" diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 7b2134aeef..cbc46352a5 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -116,6 +116,22 @@ def update_step_context(cls, step_context): if not step_context.is_decoding: kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens kv_flatten_size = kv_seqlens.sum().item() + if step_context.medusa_attn_mask is not None: + max_q_seqlen = q_seqlens.max() + max_kv_seqlen = kv_seqlens.max() + bs = q_seqlens.shape[0] + medusa_len = step_context.medusa_attn_mask.shape[-1] + dtype = step_context.medusa_attn_mask.dtype + device = step_context.medusa_attn_mask.device + attention_mask = torch.zeros((bs, max_q_seqlen, max_kv_seqlen), + dtype=dtype, + device=device) + masked_value = (1 - step_context.medusa_attn_mask) * (-1e30) + for i in range(bs): + attention_mask[i, :, kv_seqlens[i] - + medusa_len:kv_seqlens[i]] = masked_value # noqa + step_context.medusa_attn_mask = attention_mask + attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, @@ -125,6 +141,7 @@ def update_step_context(cls, step_context): kv_seqlens=kv_seqlens, kv_flatten_size=kv_flatten_size, quant_policy=step_context.kv_quant_policy, + medusa_attn_mask=step_context.medusa_attn_mask, ) cross_seqlens = step_context.cross_seqlens diff --git a/lmdeploy/pytorch/chat.py 
b/lmdeploy/pytorch/chat.py index 2b5ee85edc..3783b8255e 100644 --- a/lmdeploy/pytorch/chat.py +++ b/lmdeploy/pytorch/chat.py @@ -52,6 +52,7 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer): def run_chat(model_path: str, engine_config: PytorchEngineConfig, + speculative_model: Optional[str] = None, gen_config: GenerationConfig = None, session_id: int = 1, trust_remote_code: bool = True, @@ -62,12 +63,14 @@ def run_chat(model_path: str, Args: model_path (str): the huggingface model path. engine_config (PytorchEngineConfig): Config of engine. + speculative_model (str): the path of the speculative model. gen_config (GenerationConfig): Config of generation. session_id (int): the identical id of a session. trust_remote_code (bool): trust remote code. """ from lmdeploy.pytorch.engine import Engine tm_model = Engine.from_pretrained(model_path, + speculative_model=speculative_model, engine_config=engine_config, trust_remote_code=trust_remote_code) tokenizer = tm_model.tokenizer diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 7783afd970..b2cdc304b7 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -112,6 +112,9 @@ class ModelConfig: hf_config: Any = None cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None + # medusa config + medusa_num_heads: int = None + medusa_num_layers: int = None def get_head_size(self): """get head size.""" @@ -134,12 +137,22 @@ def from_pretrained(cls, activations. Refer to `PyTorchEngineConfig` for details """ from transformers import AutoConfig - hf_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - if getattr(hf_config, 'model_type', None) in ['phi3']: - # phi3 + trust_remote_code leads to error when tp. + try: hf_config = AutoConfig.from_pretrained( - pretrained_model_name_or_path) + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) + if getattr(hf_config, 'model_type', None) in ['phi3']: + # phi3 + trust_remote_code leads to error when tp. + hf_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path) + except Exception as e: # noqa + from transformers import PretrainedConfig + hf_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code) + # medusa model config + if getattr(hf_config, 'medusa_num_heads', None) is not None: + setattr(hf_config, 'architectures', ['MedusaModel']) return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype, diff --git a/lmdeploy/pytorch/configurations/medusa.py b/lmdeploy/pytorch/configurations/medusa.py new file mode 100644 index 0000000000..a4f705cd3f --- /dev/null +++ b/lmdeploy/pytorch/configurations/medusa.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
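+# Config builder for Medusa draft models: the draft checkpoint's config only
+# describes the extra heads, so the base model's config (looked up through
+# `base_model_name_or_path`) supplies the hidden size, vocab size and KV-cache
+# layout used to build the draft heads.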
+from lmdeploy.pytorch.config import ModelConfig + +from .builder import AutoModelConfigBuilder + + +class MedusaModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.architectures[0] == 'MedusaModel' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + from transformers import AutoConfig + base_config = AutoConfig.from_pretrained( + hf_config.base_model_name_or_path) + head_dim = base_config.hidden_size // base_config.num_attention_heads + # config is wrong + # https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3/blob/main/config.json#L3 + hf_config.medusa_num_heads = 5 + medusa_num_heads = hf_config.medusa_num_heads + medusa_num_layers = hf_config.medusa_num_layers + if getattr(hf_config, 'hidden_size', None) is None: + setattr(hf_config, 'hidden_size', base_config.hidden_size) + if getattr(hf_config, 'vocab_size', None) is None: + setattr(hf_config, 'vocab_size', base_config.vocab_size) + return ModelConfig( + hidden_size=base_config.hidden_size, + num_attention_heads=base_config.num_attention_heads, + num_layers=base_config.num_hidden_layers, + num_key_value_heads=base_config.num_key_value_heads, + bos_token_id=base_config.bos_token_id, + eos_token_id=base_config.eos_token_id, + head_dim=head_dim, + vocab_size=base_config.vocab_size, + hf_config=hf_config, + medusa_num_heads=medusa_num_heads, + medusa_num_layers=medusa_num_layers, + ) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index a674d609af..908ed93f99 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -3,7 +3,7 @@ import copy import os from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -105,12 +105,14 @@ class Engine: Args: model_path (str): The hugging face model path. engine_config (PytorchEngineConfig): The config of the Engine. + speculative_model (str): The path of the speculative model. trust_remote_code (bool): Trust remote code. """ def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, + speculative_model: Optional[str] = None, trust_remote_code: bool = True) -> None: if engine_config is None: engine_config = PytorchEngineConfig() @@ -150,6 +152,7 @@ def __init__(self, model_path, cache_config=cache_config, backend_config=backend_config, + speculative_model=speculative_model, trust_remote_code=trust_remote_code, adapters=adapters, tp=self.tp, @@ -179,6 +182,7 @@ def __init__(self, def from_pretrained(cls, pretrained_model_name_or_path: str, engine_config: PytorchEngineConfig = None, + speculative_model: Optional[str] = None, trust_remote_code: bool = True, **kwargs): """lmdeploy python inference engine. @@ -195,12 +199,14 @@ def from_pretrained(cls, "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. engine_config (PytorchEngineConfig): Pytorch engine config. + speculative_model (str): The path of the speculative model. 
trust_remote_code (bool): Trust remote code """ if len(kwargs) > 0: logger.debug(f'Get unexpected kwargs: {kwargs}') return cls(model_path=pretrained_model_name_or_path, engine_config=engine_config, + speculative_model=speculative_model, trust_remote_code=trust_remote_code) @property @@ -546,7 +552,10 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor, # one more step to cache last token(stop word) stopped = num_appendable_ids < 0 if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) + if len(token_ids.shape) == 1: + token_ids = token_ids[:, None] + # TODO speculative model supports multiple stop word + sw_stopped = (token_ids == stop_words).any(1) one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) @@ -579,6 +588,17 @@ def __get_last_logits(): return next_token_ids + def extract_tokens(self, token_ids, eos_token_ids): + """Token list containing eos.""" + if not isinstance(token_ids, np.ndarray): + return [token_ids], token_ids in eos_token_ids + for i, token_id in enumerate(token_ids): + if token_id in eos_token_ids: + return token_ids[:i], True + if token_id == -1: + return token_ids[:i], False + return token_ids, False + @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped: torch.Tensor, model_metas: List[Dict[str, @@ -591,12 +611,14 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped, model_metas): if msg.status != MessageStatus.RUNNING: continue - update_token = token + eos_token_id = self.model_config.eos_token_id + update_token, eos_stop = self.extract_tokens(token, eos_token_id) if stop: update_token = _EMPTY_TOKEN else: - msg.num_new_tokens += 1 + msg.num_new_tokens += len(update_token) msg.update_token_ids(update_token, model_meta=model_meta) + stop = stop or eos_stop if stop: msg.status = MessageStatus.STOPPED @@ -680,14 +702,24 @@ async def __long_context_single_forward(inputs): if not return_logits and not inputs.is_decoding: last_token_loc = inputs.seq_length.cumsum(0) - 1 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] + if 'spec_hidden_states' in ret: + ret['spec_hidden_states'] = ret[ + 'spec_hidden_states'][:, last_token_loc] else: ret = await __long_context_single_forward(inputs) if not return_logits and not inputs.is_decoding: last_token_loc = [-1] ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] + if 'spec_hidden_states' in ret: + ret['spec_hidden_states'] = ret[ + 'spec_hidden_states'][:, last_token_loc] hidden_states = ret.pop('hidden_states') + spec_hidden_states = ret.pop('spec_hidden_states', None) logits = self.model_agent.get_logits(hidden_states) + if spec_hidden_states is not None: + spec_logits = self.model_agent.get_spec_logits(spec_hidden_states) + ret['spec_logits'] = spec_logits ret['logits'] = logits return ret @@ -797,6 +829,42 @@ def __update_inputs(next_token_ids): logits, all_ids, guided_input_ids, sampling_inputs, inputs, num_ignore_eos > 0) num_ignore_eos = num_ignore_eos - 1 + if 'spec_logits' in output: + spec_logits = output['spec_logits'] + (cart_candidates, tree_candidates, medusa_attn_mask, + medusa_position_ids, + retrieve_indices) = self.model_agent.generate_candidates( + spec_logits, next_token_ids) + bs, _, tree_decode_len = tree_candidates.shape + spec_inputs = copy.deepcopy(inputs) + spec_inputs.input_ids = tree_candidates.flatten().unsqueeze(0) + spec_inputs.history_lengths += spec_inputs.seq_length 
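+                # after folding the current tokens into the history, every
+                # sequence decodes its flattened tree candidates
+                # (tree_decode_len tokens) in a single forward pass under the
+                # tree attention mask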
+ spec_inputs.seq_length = torch.ones_like( + spec_inputs.seq_length) * tree_decode_len + spec_inputs.medusa_attn_mask = medusa_attn_mask + spec_inputs.medusa_position_ids = medusa_position_ids + logits = await self.model_agent.tree_decoding( + spec_inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, + retrieve_indices=retrieve_indices) + # NOTE currently only greedy sampling supported + # besides, we used the bonus token id predicted during + # tree decoding while original Medusa did not + proposal_len = cart_candidates.shape[-1] + greedy_token_ids = logits.argmax(-1) + posterior_mask = cart_candidates[..., 1:] == greedy_token_ids[ + ..., :-1] + accept_len, best_idx = torch.cumprod(posterior_mask, + dim=-1).sum(-1).max(-1) + greedy_token_ids = greedy_token_ids[torch.arange(bs), best_idx] + next_token_ids = torch.cat( + [next_token_ids[:, None], greedy_token_ids], -1) + mask_idx = torch.arange( + proposal_len + 1, + device=next_token_ids.device).expand_as(next_token_ids) + next_token_ids[mask_idx > (accept_len[:, None] + 1)] = -1 + num_appendable_ids = num_appendable_ids - accept_len - 2 # stopping criteria stopped, num_appendable_ids = self._batch_stopping_criteria( @@ -1049,6 +1117,8 @@ async def __step(): schedule_output = self.scheduler.schedule( is_prefill=prefill, prealloc_size=prefill_interval) + if self.model_agent.speculative_model is not None: + prefill = True in_que.put_nowait((prefill, schedule_output)) finish = False while not finish: diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 5487639d29..9b98dfb1fe 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -204,6 +204,8 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig, + speculative_model: str = None, + speculative_model_config: ModelConfig = None, adapters: Dict[str, str] = None, trust_remote_code: bool = True): super().__init__(model_config=model_config, cache_config=cache_config) @@ -212,6 +214,7 @@ def __init__(self, self._adapters = adapters self.patched_model = self._build_model(model_path, + self.model_config, adapters, device=device) @@ -224,6 +227,13 @@ def __init__(self, cache_config=cache_config, backend_config=backend_config, device=device) + self.speculative_model = None + if speculative_model is not None: + self.speculative_model_config = speculative_model_config + self.speculative_model = self._build_model( + speculative_model, + self.speculative_model_config, + device=device) self.cache_engine = CacheEngine(cache_config, model_config) @@ -231,21 +241,22 @@ def __init__(self, def _build_model(self, model_path: str, + model_config: ModelConfig, adapters: Dict[str, str] = None, device: torch.device = 'cuda'): """build patched model.""" - custom_module_map = self.model_config.custom_module_map + custom_module_map = model_config.custom_module_map if custom_module_map is not None: update_custom_module_map(custom_module_map) logger.info('build model.') - patched_model = build_patched_model(self.model_config, device=device) + patched_model = build_patched_model(model_config, device=device) logger.info('loading weights.') load_model_weights(patched_model, model_path, device=device) logger.info('loading adapters.') if adapters is not None: add_adapters(patched_model, adapters, - dtype=self.model_config.dtype, + dtype=model_config.dtype, device=device) return patched_model @@ -261,6 +272,16 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: 
SwapMap, world_size=1, stream=self.stream, ) + if self.speculative_model is not None: + inputs.last_hidden_states = output['hidden_states'] + spec_outputs = model_forward( + self.speculative_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + output['spec_hidden_states'] = spec_outputs['hidden_states'] return output async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, @@ -278,10 +299,72 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, await asyncio.sleep(0) return output + async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, num_speculative_tokens): + """score the proposal. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + num_speculative_tokens (int): The number of the proposal tokens. + """ + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + await asyncio.get_event_loop().run_in_executor(None, + self.stream.synchronize) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape( + [-1, num_speculative_tokens + 1, hidden_states.shape[-1]]) + logits = self.get_logits(hidden_states) + return logits + + async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, + retrieve_indices: torch.Tensor): + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + bs = inputs.history_lengths.shape[0] + inputs.medusa_position_ids = inputs.medusa_position_ids.repeat( + inputs.history_lengths.shape[0], 1) + inputs.medusa_position_ids = inputs.medusa_position_ids.to( + inputs.history_lengths.device) + inputs.history_lengths[:, None] + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + await asyncio.get_event_loop().run_in_executor(None, + self.stream.synchronize) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape(bs, -1, hidden_states.shape[-1]) + logits = self.get_logits(hidden_states)[:, retrieve_indices] + return logits + + def generate_candidates(self, draft_logits: torch.Tensor, + base_token_id: torch.Tensor): + return self.speculative_model.generate_candidates( + draft_logits, base_token_id) + def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): + """get logits of model output.""" + return self.speculative_model.get_logits(hidden_states_list) + def get_input_processor(self): """get input processor..""" return self.patched_model.get_input_processor() @@ -543,6 +626,8 @@ def __init__(self, backend_config: BackendConfig, world_size: int, adapters: Dict[str, str] = None, + speculative_model: str = None, + speculative_model_config: ModelConfig = None, trust_remote_code: bool = True) -> None: import signal @@ -575,6 +660,13 @@ def __signal_term_handler(sig, frame): world_size=world_size, barrier=self.mp_bar) + self.speculative_model = None + if speculative_model is not None: + self.speculative_model_config = speculative_model_config + self.speculative_model = self._build_speculative_model( + speculative_model, + self.speculative_model_config, 
+ world_size=world_size) model, cache_engine, cache_config = self._build_model( model_path=model_path, model_config=model_config, @@ -685,6 +777,35 @@ def _build_model( return model, cache_engine, cache_config + @torch.inference_mode() + def _build_speculative_model( + self, + model_path: str, + model_config: ModelConfig, + world_size: int, + cache_config: CacheConfig = None, + backend_config: BackendConfig = None, + ): + """build model. + + Currently, cache engine and backend config not used. + """ + with get_dist_manager().context(self._dist_ctx): + rank = 0 + device_map = torch.device('cuda') + + custom_module_map = model_config.custom_module_map + if custom_module_map is not None: + update_custom_module_map(custom_module_map) + if rank == 0: + logger.info('build model.') + patched_model = build_patched_model(model_config, + device=device_map) + if rank == 0: + logger.info('loading weights.') + load_model_weights(patched_model, model_path, device=device_map) + return patched_model + def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" @@ -704,8 +825,50 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, world_size=1, stream=self.stream, ) + if self.speculative_model is not None: + inputs.last_hidden_states = output['hidden_states'] + spec_outputs = model_forward( + self.speculative_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + output['spec_hidden_states'] = spec_outputs['hidden_states'] return output + async def score_proposal(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, num_speculative_tokens): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + num_speculative_tokens (int): The number of the proposal tokens. + """ + with get_dist_manager().context(self._dist_ctx): + self.mp_bar.wait() + rank = 0 + _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], + self.stream) + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape( + [-1, num_speculative_tokens + 1, hidden_states.shape[-1]]) + logits = self.get_logits(hidden_states) + self.stream.synchronize() + return logits + async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. 
@@ -721,10 +884,49 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, await asyncio.sleep(0) return output + async def tree_decoding(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap, + retrieve_indices: torch.Tensor): + bs = inputs.history_lengths.shape[0] + inputs.medusa_position_ids = inputs.medusa_position_ids.repeat( + inputs.history_lengths.shape[0], 1) + inputs.medusa_position_ids = inputs.medusa_position_ids.to( + inputs.history_lengths.device) + inputs.history_lengths[:, None] + with get_dist_manager().context(self._dist_ctx): + self.mp_bar.wait() + rank = 0 + _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], + self.stream) + cache_swapping(self.cache_engine, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + spec_outputs = model_forward( + self.patched_model, + inputs, + self.cache_engine, + world_size=1, + stream=self.stream, + ) + hidden_states = spec_outputs['hidden_states'] + hidden_states = hidden_states.reshape(bs, -1, + hidden_states.shape[-1]) + logits = self.get_logits(hidden_states)[:, retrieve_indices] + self.stream.synchronize() + return logits + + def generate_candidates(self, draft_logits: torch.Tensor, + base_token_id: torch.Tensor): + return self.speculative_model.generate_candidates( + draft_logits, base_token_id) + def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_spec_logits(self, hidden_states_list: List[torch.Tensor]): + """get logits of model output.""" + return self.speculative_model.get_logits(hidden_states_list) + def get_input_processor(self): """get input processor..""" return self.patched_model.get_input_processor() @@ -739,6 +941,7 @@ def build_model_agent(model_path: str, cache_config: CacheConfig, backend_config: BackendConfig, trust_remote_code: bool, + speculative_model: str = None, adapters: Dict[str, str] = None, tp: int = 1, dtype: str = 'auto', @@ -750,6 +953,7 @@ def build_model_agent(model_path: str, cache_config (CacheConfig): config of kv cache backend_config (BackendConfig): config of backend devices trust_remote_code (bool): To use the remote modeling code or not + speculative_model (str): The path of the speculative model adapters (Dict): lora adapters tp (int): the number of devices to be used in tensor parallelism dtype (str): the data type of model weights and activations @@ -758,19 +962,31 @@ def build_model_agent(model_path: str, model_config = ModelConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code, dtype=dtype, tp=tp) model_config.custom_module_map = custom_module_map + speculative_model_config = None + if speculative_model is not None: + speculative_model_config = ModelConfig.from_pretrained( + speculative_model, + trust_remote_code=trust_remote_code, + dtype=dtype) if tp == 1: - model_agent = BaseModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - adapters=adapters, - trust_remote_code=trust_remote_code) + model_agent = BaseModelAgent( + model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + speculative_model=speculative_model, + speculative_model_config=speculative_model_config, + adapters=adapters, + trust_remote_code=trust_remote_code) else: - model_agent = TPModelAgent(model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - world_size=tp, - adapters=adapters, - 
trust_remote_code=trust_remote_code) + model_agent = TPModelAgent( + model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + speculative_model=speculative_model, + speculative_model_config=speculative_model_config, + world_size=tp, + adapters=adapters, + trust_remote_code=trust_remote_code) return model_agent diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 3d07225e43..cec5c44892 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -61,11 +61,13 @@ def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr): @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask, - kv_min_loc, causal_mask: tl.constexpr, - window_size: tl.constexpr, + kv_min_loc, attn_mask_ptr, apply_mask: tl.constexpr, + causal_mask: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr): k_ptrs = tl.advance(k_ptrs, (0, loop_start)) + if apply_mask: + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, loop_start)) v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) if BLOCK_DK1: k1_ptrs = tl.advance(k1_ptrs, (0, loop_start)) @@ -75,6 +77,10 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, start_n = tl.multiple_of(start_n, BLOCK_N) k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) + if apply_mask: + attn_mask = tl.load(attn_mask_ptr, + boundary_check=(0, 1), + padding_option='zero') qk = tl.dot(q, k) if BLOCK_DK1 != 0: @@ -83,6 +89,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, if causal_mask: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) @@ -98,6 +106,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, qk -= m_i_new[:, None] elif window_size > 0: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) @@ -110,14 +120,19 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, qk -= m_i_new[:, None] elif logit_softcapping > 0: qk *= sm_scale + if apply_mask: + qk = qk + attn_mask qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) m_i_new = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_i_new[:, None] else: - qk_scale = sm_scale * tl_log2(math.e) - m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_i_new[:, None] + qk *= sm_scale + if apply_mask: + qk = qk + attn_mask + qk *= tl_log2(math.e) + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_i_new[:, None] # -- compute p, m_i and l_i p = tl_exp2(qk) @@ -136,6 +151,8 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N)) v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0)) + if apply_mask: + attn_mask_ptr = tl.advance(attn_mask_ptr, (0, BLOCK_N)) if BLOCK_DK1: k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N)) @@ -168,6 +185,8 @@ def _flash_prefill_fwd_kernel( kv_start_loc_ptr, kv_seqlens_ptr, sm_scale, + attention_mask, + apply_mask: tl.constexpr, stride_qs: tl.constexpr, stride_qh: tl.constexpr, stride_qd: tl.constexpr, @@ -180,6 +199,9 @@ def _flash_prefill_fwd_kernel( stride_os: tl.constexpr, stride_oh: tl.constexpr, 
stride_od: tl.constexpr, + stride_amb, + stride_amqs, + stride_amkvs, kv_group_num, head_dim_k, head_dim_v, @@ -245,6 +267,18 @@ def _flash_prefill_fwd_kernel( block_shape=(BLOCK_N, BLOCK_DV), order=(1, 0), ) + if apply_mask: + attn_mask_ptrs = tl.make_block_ptr( + base=attention_mask + batch_id * stride_amb + + start_m * BLOCK_M * stride_amqs, + shape=(q_seqlen, kv_seqlen), + strides=(stride_amqs, stride_amkvs), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(0, 1), + ) + else: + attn_mask_ptrs = tl.full([BLOCK_M, BLOCK_N], 0, dtype=tl.int32) if BLOCK_DK1 != 0: offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1) @@ -292,6 +326,8 @@ def _flash_prefill_fwd_kernel( sm_scale, history_mask, kv_min_loc, + attn_mask_ptrs, + apply_mask=apply_mask, causal_mask=False, window_size=window_size, logit_softcapping=logit_softcapping, @@ -316,6 +352,8 @@ def _flash_prefill_fwd_kernel( sm_scale, history_mask, kv_min_loc, + attn_mask_ptrs, + apply_mask=apply_mask, causal_mask=True, window_size=window_size, logit_softcapping=logit_softcapping, @@ -348,6 +386,7 @@ def flash_attention_fwd( q_seqlens: Tensor, kv_start_loc: Tensor, kv_seqlens: Tensor, + attention_mask: Tensor = None, max_seqlen: int = None, window_size: int = None, sm_scale: float = None, @@ -412,6 +451,10 @@ def grid(args): num_stages = 3 else: num_stages = 4 + apply_mask = True + if attention_mask is None: + apply_mask = False + attention_mask = q_states.new_empty((1, 1, 1)) _flash_prefill_fwd_kernel[grid]( q_states, k_states, @@ -422,6 +465,8 @@ def grid(args): kv_start_loc, kv_seqlens, sm_scale=sm_scale, + attention_mask=attention_mask, + apply_mask=apply_mask, stride_qs=q_states.stride(0), stride_qh=q_states.stride(1), stride_qd=q_states.stride(2), @@ -434,6 +479,9 @@ def grid(args): stride_os=o_states.stride(0), stride_oh=o_states.stride(1), stride_od=o_states.stride(2), + stride_amb=attention_mask.stride(0), + stride_amqs=attention_mask.stride(1), + stride_amkvs=attention_mask.stride(2), kv_group_num=kv_group_num, head_dim_k=head_dim_k, head_dim_v=head_dim_v, @@ -448,5 +496,4 @@ def grid(args): num_warps=num_warps, num_stages=num_stages, ) - return o_states diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index 3a77164046..5f59ac4651 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -48,7 +48,8 @@ def _flatten_kv_cache( seqlen = tl.load(seqlens_ptr + batch_id) start_loc = tl.load(start_loc_ptr + batch_id) # fill last block to prevent attention nan - if batch_id == num_batches - 1: + # seqlen>0 to filter cuda graph padding + if batch_id == num_batches - 1 and seqlen > 0: seqlen = OUT_SIZE - start_loc if page_id * BLOCK_BS >= seqlen: return @@ -140,7 +141,7 @@ def _flatten_kv_cache_quant( seqlen = tl.load(seqlens_ptr + batch_id) start_loc = tl.load(start_loc_ptr + batch_id) - if batch_id == num_batches - 1: + if batch_id == num_batches - 1 and seqlen > 0: seqlen = OUT_SIZE - start_loc if page_id * BLOCK_BS >= seqlen: return diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index d10da8557a..e984e39abe 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -136,6 +136,9 @@ class ModelInputs: vision_inputs: VisionModelInputs = None cross_length: torch.LongTensor = None history_cross_length: torch.LongTensor = None + last_hidden_states: torch.Tensor = None + medusa_attn_mask: torch.Tensor = None + medusa_position_ids: torch.Tensor = None 
model_metas: List[Dict[str, Any]] = None def update(self, input_ids: torch.LongTensor): @@ -293,6 +296,8 @@ class StepContext: cross_kv_seqlens: torch.LongTensor = None cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 + last_hidden_states: torch.Tensor = None + medusa_attn_mask: torch.Tensor = None model_metas: List[Dict[str, Any]] = None _outputs: Dict = field(default_factory=dict) @@ -328,12 +333,14 @@ def new( input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) + # for speculative decoding + last_hidden_states = inputs.last_hidden_states # kv_seqlens if inputs.is_decoding: attention_mask = torch.ones_like(q_seqlens)[:, None] position_ids = history_seqlens.unsqueeze(-1).clone() else: - max_q_seqlen = q_seqlens.max().item() + max_q_seqlen = q_seqlens.contiguous().max().item() mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() position_ids = attention_mask.long().cumsum(-1) - 1 @@ -352,6 +359,9 @@ def new( # seq_len + history_length kv_seqlens = q_seqlens + history_seqlens kv_seqlens -= inputs.num_ignored_history + # medusa + if inputs.medusa_position_ids is not None: + position_ids = inputs.medusa_position_ids.reshape(1, -1) ret = StepContext( input_ids=inputs.input_ids, @@ -370,6 +380,8 @@ def new( world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, vision_inputs=inputs.vision_inputs, + last_hidden_states=last_hidden_states, + medusa_attn_mask=inputs.medusa_attn_mask, kv_quant_policy=kv_quant_policy, model_metas=inputs.model_metas, cross_seqlens=cross_seqlens, diff --git a/lmdeploy/pytorch/models/medusa.py b/lmdeploy/pytorch/models/medusa.py new file mode 100644 index 0000000000..28da3bad55 --- /dev/null +++ b/lmdeploy/pytorch/models/medusa.py @@ -0,0 +1,400 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
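+# Medusa speculative-decoding heads: each head stacks `medusa_num_layers`
+# residual blocks followed by an lm-head style linear layer that predicts a
+# future token from the base model's last hidden states. The module also
+# precomputes the tree attention mask, tree/retrieve indices and position ids
+# used for tree decoding, and assembles candidate sequences from the head
+# logits.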
+from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.models.llama import LlamaConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin + +vicuna_7b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 3), (3, ), (0, 1, 0), (2, 0), (4, ), + (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5, ), + (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), + (6, ), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7, ), + (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), + (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), + (0, 1, 2), (8, ), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), + (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), + (9, ), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), + (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)] + +vicuna_13b_stage2 = [(0, ), (0, 0), (1, ), (0, 0, 0), (0, 1), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 1, 0), (3, ), (0, 3), (2, 0), + (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4, ), + (0, 0, 3), (0, 5), (0, 2, 0), (5, ), (3, 0), (0, 1, 1), + (0, 6), (0, 0, 4), (0, 0, 0, 1), + (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0), + (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9), + (6, ), (7, ), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2), + (8, ), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7), + (1, 1, 0), (1, 3), (0, 0, 2, 0), (9, ), (0, 0, 8), + (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2), + (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)] + +vicuna_33b_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (0, 0, 0), (1, 0), (2, ), + (0, 2), (0, 0, 1), (0, 3), (3, ), + (0, 1, 0), (2, 0), (0, 4), (4, ), (0, 0, 2), (1, 1), + (1, 0, 0), (0, 5), (5, ), (0, 0, 0, 0), (0, 0, 3), (3, 0), + (0, 2, 0), (0, 6), (0, 1, 1), (6, ), (0, 0, 4), (0, 7), + (7, ), (1, 2), (4, 0), (8, ), (0, 3, 0), (0, 0, 5), + (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1), + (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3), + (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9, ), + (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0), + (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1), + (1, 0, 2), (2, 2)] + +zephyr_stage2 = [(0, ), (0, 0), (1, ), (0, 1), (2, ), + (0, 0, 0), (1, 0), (0, 2), (3, ), (0, 3), (4, ), (2, 0), + (0, 0, 1), (0, 4), (5, ), (0, 5), (0, 1, 0), (1, 1), (6, ), + (0, 0, 2), (3, 0), (0, 6), (7, ), (0, 7), (0, 8), (0, 0, 3), + (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8, ), (9, ), + (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0), + (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1), + (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7), + (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5), + (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1), + (2, 3)] +mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], + [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, + 6], [6], + [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], + [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], + [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], + [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], + [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], + [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], + [1, 0, 1], 
[0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], + [0, 7, 0]] + +TOPK = 10 + + +def pad_path(path, length, pad_value=-2): + """Pad the given path list with a specific value up to a specified length. + + Args: + path (list): The original list that needs padding. + length (int): The desired length of the padded list. + pad_value (optional, default=-2): The value to use for padding. + + Returns: + list: A new list based on the original path but padded to the desired + length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + + +class ResBlock(nn.Module): + """A Residual Block module. + + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual + connection. + + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, + hidden_size, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + dtype=dtype, + device=device) + # Use SiLU activation to keep consistent with the Llama model + self.act = nn.SiLU() + + def forward(self, x): + """Forward pass of the ResBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output after the residual connection and activation. + """ + return x + self.act(self.linear(x)) + + +class MedusaModel(nn.Module, CudaGraphMixin, DeployModelMixin): + """The medusa model architecture.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: LlamaConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build medusa + self.medusa_head = nn.ModuleList([ + nn.Sequential( + *([ + ResBlock( + self.config.hidden_size, device=device, dtype=dtype) + ] * self.config.medusa_num_layers), + build_rowwise_linear(self.config.hidden_size, + self.config.vocab_size, + bias=False, + dtype=dtype, + device=device), + ) for _ in range(self.config.medusa_num_heads) + ]) + self.medusa_choices = None + if 'vicuna-7b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_7b_stage2 + elif 'vicuna-13b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_13b_stage2 + elif 'vicuna-33b' in config.base_model_name_or_path: + self.medusa_choices = vicuna_33b_stage2 + elif 'zephyr' in config.base_model_name_or_path: + self.medusa_choices = zephyr_stage2 + else: + self.medusa_choices = mc_sim_7b_63 + self.generate_medusa_buffers(device=device) + + def generate_medusa_buffers(self, device: torch.dtype = None): + """Generate buffers for the Medusa structure based on the provided + choices. + + Args: + medusa_choices (list): A nested list representing tree in the + Medusa structure. + device (str): Device to which the tensors should be moved. + Default is "cuda". + + Returns: + dict: A dictionary containing buffers related to the + Medusa structure. 
+ """ + if self.medusa_choices is None: + self.medusa_attn_mask = None + self.tree_indices = None + self.medusa_position_ids = None + self.retrieve_indices = None + return + + # Sort the medusa_choices based on their lengths and then their values + sorted_medusa_choices = sorted(self.medusa_choices, + key=lambda x: (len(x), x)) + medusa_len = len(sorted_medusa_choices) + 1 + + # Initialize depth_counts to keep track of how many choices have a + # particular depth + depth_counts = [] + prev_depth = 0 + for path in sorted_medusa_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + + # Create the attention mask for Medusa + medusa_attn_mask = torch.eye(medusa_len, medusa_len) + medusa_attn_mask[:, 0] = 1 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + # retrieve ancestor position + if len(cur_medusa_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_medusa_choice) - 1): + ancestor_idx.append( + sorted_medusa_choices.index(cur_medusa_choice[:c + + 1]) + 1) + medusa_attn_mask[j + start + 1, ancestor_idx] = 1 + start += depth_counts[i] + + # Generate tree indices for the Medusa structure + medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) + medusa_tree_indices[0] = 0 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + medusa_tree_indices[start + j + + 1] = cur_medusa_choice[-1] + TOPK * i + 1 + start += depth_counts[i] + + # Generate position IDs for the Medusa structure + medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) + start = 0 + for i in range(len(depth_counts)): + medusa_position_ids[start + 1:start + depth_counts[i] + 1] = i + 1 + start += depth_counts[i] + + # Generate retrieval indices for Medusa structure verification + retrieve_indices_nest = [] + retrieve_paths = [] + for i in range(len(sorted_medusa_choices)): + cur_medusa_choice = sorted_medusa_choices[-i - 1] + retrieve_indice = [] + if cur_medusa_choice in retrieve_paths: + continue + else: + for c in range(len(cur_medusa_choice)): + retrieve_indice.append( + sorted_medusa_choices.index(cur_medusa_choice[:c + 1])) + retrieve_paths.append(cur_medusa_choice[:c + 1]) + retrieve_indices_nest.append(retrieve_indice) + max_length = max([len(x) for x in retrieve_indices_nest]) + retrieve_indices = [ + pad_path(path, max_length) for path in retrieve_indices_nest + ] + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + retrieve_indices = retrieve_indices + 1 + retrieve_indices = torch.cat([ + torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), + retrieve_indices + ], + dim=1) + self.medusa_attn_mask = medusa_attn_mask.unsqueeze(0).unsqueeze(0).to( + device) + self.tree_indices = medusa_tree_indices.to(device) + self.medusa_position_ids = medusa_position_ids.to(device) + self.retrieve_indices = retrieve_indices.to(device) + + def generate_candidates(self, medusa_logits: torch.Tensor, + base_token_id: torch.Tensor): + """Generate candidates based on provided logits and indices. + + Args: + medusa_logits (torch.Tensor): Logits from a specialized Medusa + structure, aiding in candidate selection. Shape + [bs, speculative_num, vocab_size] + base_token_id (torch.Tensor): Standard logits from a language + model. 
Shape [bs] + + Returns: + tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates: + 1. Cartesian candidates derived from the combined original and Medusa logits. + 2. Tree candidates mapped from the Cartesian candidates using tree indices. + """ # noqa + # Greedy decoding: Select the most probable candidate from the original + # logits. here we only implement greedy decoding + bs = medusa_logits.shape[0] + candidates_logit = base_token_id.unsqueeze(-1) + # Extract the TOPK candidates from the medusa logits. + candidates_medusa_logits = torch.topk(medusa_logits, TOPK, + dim=-1).indices + + # Combine the selected candidate from the original logits with the + # topk medusa logits. + candidates = torch.cat( + [candidates_logit, + candidates_medusa_logits.view(bs, -1)], dim=-1) + + # Map the combined candidates to the tree indices to get tree + # candidates. + tree_candidates = candidates[:, self.tree_indices] + + # Extend the tree candidates by appending a zero. + tree_candidates_ext = torch.cat([ + tree_candidates, + torch.zeros( + (bs, 1), dtype=torch.long, device=tree_candidates.device) + ], + dim=-1) + + # Retrieve the cartesian candidates using the retrieve indices. + cart_candidates = tree_candidates_ext[:, self.retrieve_indices] + + # Unsqueeze the tree candidates for dimension consistency. + tree_candidates = tree_candidates.unsqueeze( + 1) # bs, 1, len(self.medusa_choices) + return (cart_candidates, tree_candidates, self.medusa_attn_mask, + self.medusa_position_ids, self.retrieve_indices) + + def support_cuda_graph( + self, + *args, + **kwargs, + ): + """support cudagraph.""" + return True + + def forward(self, last_hidden_states: torch.Tensor, + **kwargs) -> List[torch.Tensor]: + outputs = [head[0](last_hidden_states) for head in self.medusa_head] + outputs = torch.cat(outputs, 0) + return outputs + + def get_logits(self, hidden_states: List[torch.Tensor]): + """compute logits of the model output.""" + outputs = [] + for medusa_head, hidden_state in zip(self.medusa_head, hidden_states): + outputs.append(medusa_head[-1](hidden_state)) + outputs = torch.stack(outputs, 1) + return outputs + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + **kwargs, + ): + """prepare input.""" + return dict(last_hidden_states=context.last_hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + name = 'medusa_head.' 
+ name + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index c1b62736f7..b47ff77b3a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -187,4 +187,9 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mllama.MllamaForConditionalGeneration', }) +# medusa +MODULE_MAP.update({ + 'MedusaModel': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.medusa.MedusaModel', +}) CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 74d090a9a3..6973276d2b 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -8,6 +8,8 @@ from lmdeploy.pytorch.model_inputs import StepContext BuffType = Dict[str, Tensor] +CUDA_GRAPH_PREFILL_SEQLEN = 256 +CUDA_GRAPH_PREFILL_KVLEN = 4096 def next_power_of_2(n: int): @@ -48,7 +50,12 @@ def support_cuda_graph( **kwargs, ): """return True is model support cudagraph.""" - return attn_metadata.is_decoding + # return attn_metadata.is_decoding + seq_lens = input_ids.size(1) + if attn_metadata.kv_flatten_size is None: + return seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN + enabled = seq_lens <= CUDA_GRAPH_PREFILL_SEQLEN and attn_metadata.kv_flatten_size <= CUDA_GRAPH_PREFILL_KVLEN # noqa + return enabled def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> BuffType: @@ -78,6 +85,9 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2] + input_buffers['kv_start_loc'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device) @@ -100,6 +110,7 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, q_start_loc: Tensor = attn_metadata.q_start_loc q_seqlens: Tensor = attn_metadata.q_seqlens kv_seqlens: Tensor = attn_metadata.kv_seqlens + kv_start_loc: Tensor = attn_metadata.kv_start_loc input_buffers: BuffType = graph_meta.input_buffers batch_size, num_blocks = block_offsets.size() @@ -121,9 +132,26 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_buffers['inputs_embeds'] = inputs_embeds.new_zeros( 1, max_num_tokens, emb_size) input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds + new_batch_size = next_power_of_2(batch_size) + if attn_metadata.medusa_attn_mask is not None: + medusa_attn_mask = attn_metadata.medusa_attn_mask + if 'medusa_attn_mask' not in input_buffers: + input_buffers['medusa_attn_mask'] = medusa_attn_mask.new_zeros( + graph_meta.max_batchs, 64, CUDA_GRAPH_PREFILL_KVLEN) + input_buffers[ + 'medusa_attn_mask'][:batch_size, :, :medusa_attn_mask. 
+ shape[-1]] = medusa_attn_mask[:batch_size] + attn_metadata.medusa_attn_mask = input_buffers[ + 'medusa_attn_mask'][:new_batch_size] + if kv_start_loc is not None: + input_buffers['kv_start_loc'][:batch_size] = kv_start_loc + attn_metadata.kv_start_loc = input_buffers[ + 'kv_start_loc'][:new_batch_size] # create inputs - new_batch_size = next_power_of_2(batch_size) + if attn_metadata.kv_flatten_size is not None: + attn_metadata.kv_flatten_size = max(attn_metadata.kv_flatten_size, + CUDA_GRAPH_PREFILL_KVLEN) attn_metadata.block_offsets = input_buffers[ 'block_offsets'][:new_batch_size] attn_metadata.q_start_loc = input_buffers[ diff --git a/lmdeploy/pytorch/nn/rejection_sampling.py b/lmdeploy/pytorch/nn/rejection_sampling.py new file mode 100644 index 0000000000..a89697cf5a --- /dev/null +++ b/lmdeploy/pytorch/nn/rejection_sampling.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + + +class RejectionSampler(nn.Module): + """apply rejection sampling according to "Accelerating Large Language Model + Decoding with Speculative Sampling". + + https://arxiv.org/pdf/2302.01318 + """ + + def __init__(self): + super().__init__() + + @staticmethod + def create_uniform_random(batch_size: int, num_speculative_tokens: int, + device: torch.device, dtype: torch.dtype): + """Generate a batch of random uniform samples. + + Args: + batch_size (int): The batch size. + num_speculative_tokens (int): The number of speculative tokens. + device (torch.device): The device of the output tensor. + dtype (torch.dtype): The dtype of the output tensor. + + Returns: + random_uniform (Tensor): The uniform tensor, shape + (batch_size, num_speculative_tokens). + """ + return torch.rand(batch_size, + num_speculative_tokens, + device=device, + dtype=dtype) + + @staticmethod + def adjusted_distribution(target_probs_without_bonus: torch.Tensor, + draft_probs: torch.Tensor): + """Adjust the distribution from the draft_probs if needed. + + Args: + target_probs_without_bonus (Tensor): The probability distribution + over token ids from the target model, shape + (batch_size, num_speculative_tokens, vocab_size). + + draft_probs (Tensor): The probability distribution over token ids + from the draft model, shape (batch_size, + num_speculative_tokens, vocab_size). + Returns: + adjusted_probs (Tensor): The adjusted probability distribution, + shape (batch_size, num_speculative_tokens, vocab_size). + """ + adjusted_probs = (target_probs_without_bonus - + draft_probs).clamp_min_(0) + adjusted_probs = adjusted_probs / adjusted_probs.sum( + -1, keepdim=True).clamp_min_(1e-5) # clamp to avoid div zero + return adjusted_probs + + # TODO fuse into a triton kernel + # TODO add seed + def forward(self, target_probs: torch.Tensor, draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor) -> torch.Tensor: + """Reject sampling probs and return token_ids. + + Args: + target_probs (Tensor): The probability distribution + over token ids from the target model, shape + (batch_size, num_speculative_tokens + 1, vocab_size). + + draft_probs (Tensor): The probability distribution over token ids + from the draft model, shape (batch_size, + num_speculative_tokens, vocab_size). + + draft_token_ids (Tensor): The proposal token id from the draft + model, shape (batch_size, num_speculative_tokens). + + Returns: + output_token_ids (Tensor): Token ids sampled through rejection + sampling. shape (batch)size, num_speculative_tokens + 1). 
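+            reject_idx (Tensor): Index of the first rejected draft position
+                per sequence (i.e. the number of accepted draft tokens),
+                shape (batch_size, ).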
+ """ + target_probs_without_bonus = target_probs[:, :-1] + batch_size, num_speculative_tokens, _ = draft_probs.shape + device = draft_probs.device + batch_indices = torch.arange(batch_size, device=device) + probs_indicies = torch.arange(num_speculative_tokens, device=device) + draft_token_probs = draft_probs[batch_indices[:, None], probs_indicies, + draft_token_ids] + target_token_probs = target_probs_without_bonus[batch_indices[:, None], + probs_indicies, + draft_token_ids] + # target model scores draft token ids + scores = target_token_probs / draft_token_probs + random_uniform = self.create_uniform_random(batch_size, + num_speculative_tokens, + device=device, + dtype=scores.dtype) + rejected = scores < random_uniform + rejected_mask = rejected.cumsum(-1) > 0 + accepted_mask = ~rejected_mask + rejected_mask = torch.cat( + [rejected_mask, + rejected_mask.new_ones(batch_size, 1)], -1) + reject_idx = rejected_mask.float().argmax(-1, False) + # compute adjusted token ids + adjusted_probs = self.adjusted_distribution(target_probs_without_bonus, + draft_probs) + adjusted_probs = torch.cat([adjusted_probs, target_probs[:, -1:]], 1) + adjusted_probs = adjusted_probs[batch_indices, reject_idx] + adjusted_token_ids = torch.multinomial(adjusted_probs, + num_samples=1, + replacement=True).squeeze(-1) + output_token_ids = draft_token_ids.new_full( + (batch_size, num_speculative_tokens + 1), -1) + output_token_ids[~rejected_mask] = draft_token_ids[accepted_mask] + output_token_ids[batch_indices, reject_idx] = adjusted_token_ids + return output_token_ids, reject_idx + + +def test_rejection_sampler(): + batch_size = 4 + num_speculative_tokens = 5 + vocab_size = 1024 + dtype = torch.float32 + device = torch.device('cuda') + target_logits_with_bonus = torch.rand( + (batch_size, num_speculative_tokens + 1, vocab_size), + dtype=dtype, + device=device) + draft_logits = torch.rand((batch_size, num_speculative_tokens, vocab_size), + dtype=dtype, + device=device) + draft_token_ids = torch.randint(0, + vocab_size, + (batch_size, num_speculative_tokens), + device=device) + rejection_sampler = RejectionSampler() + rejection_sampler.forward(target_logits_with_bonus, draft_logits, + draft_token_ids) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 722329a906..0d901d75a3 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -115,7 +115,7 @@ def add_sequence(self, seq: SchedulerSequence): self._set_message_status(seq, MessageStatus.WAITING) @logging_timer('SchedulePrefilling', logger) - def _schedule_prefill(self): + def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" current_running = self.running @@ -140,7 +140,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, 0) + return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) def _reorder_waiting(): """reorder waiting.""" @@ -164,7 +164,7 @@ def _reorder_waiting(): break # allocate session memory - self.block_manager.allocate(seq) + self.block_manager.allocate(seq, prealloc_size) _to_running(seq) return running, swap_in_map, swap_out_map, copy_map @@ -214,8 +214,9 @@ def __evict_for_seq(seq: SchedulerSequence): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" + prealloc_size = max(prealloc_size, 64) # 64 for medusa tree decode 
         if is_prefill:
-            output = self._schedule_prefill()
+            output = self._schedule_prefill(prealloc_size)
         else:
             output = self._schedule_decoding(prealloc_size)
         running, swap_in_map, swap_out_map, copy_map = output
diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py
index cb548614c7..26169432db 100644
--- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py
+++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py
@@ -12,6 +12,8 @@
 
 logger = get_logger('lmdeploy')
 
+MEDUSA_WEIGHT_NAME = 'medusa_lm_head.pt'
+
 
 def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor,
                 **kwargs):
@@ -55,6 +57,9 @@ def _get_weight_type(model_path: str, use_safetensors: bool = None):
         # Load from a sharded PyTorch checkpoint
         weight_type = 'pytorch'
         is_sharded = True
+    elif osp.isfile(osp.join(model_path, MEDUSA_WEIGHT_NAME)):
+        # Load from a medusa head
+        weight_type = 'medusa'
     else:
         raise RuntimeError('Unknown weight type.')
 
@@ -83,6 +88,8 @@ def _get_weight_path(model_path: str, weight_type: str):
         weight_name = SAFE_WEIGHTS_NAME
     elif weight_type == 'pytorch':
         weight_name = WEIGHTS_NAME
+    elif weight_type == 'medusa':
+        weight_name = MEDUSA_WEIGHT_NAME
     else:
         raise RuntimeError('Unknown weight type.')
 
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 2b2d02b38f..d2bd79b51d 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -250,6 +250,8 @@ class AsyncEngine(LogitsMixin):
             config instance. Default to none.
         chat_template_config (ChatTemplateConfig): chat template configuration.
             Default to None.
+        speculative_model (str): The path of the draft model. It can only be
+            used with the pytorch backend for speculative decoding.
         max_log_len (int): Max number of prompt characters or prompt tokens
            being printed in log. Default: Unlimited
     """
 
@@ -261,6 +263,7 @@ def __init__(self,
                  backend_config: Optional[Union[TurbomindEngineConfig,
                                                 PytorchEngineConfig]] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None,
+                 speculative_model: Optional[str] = None,
                  max_log_len: int = None,
                  **kwargs) -> None:
         logger.info(
@@ -279,12 +282,15 @@ def __init__(self,
 
         # build backend engine
         if backend == 'turbomind':
+            assert speculative_model is None, 'please use '\
+                '--backend pytorch to use speculative decoding'
             self._build_turbomind(model_path=model_path,
                                   backend_config=backend_config,
                                   **kwargs)
         elif backend == 'pytorch':
             self._build_pytorch(model_path=model_path,
                                 backend_config=backend_config,
+                                speculative_model=speculative_model,
                                 **kwargs)
         else:
             raise ValueError(f'unsupported backend {backend}')
@@ -341,11 +347,13 @@ def _build_pytorch(
             model_path: str,
             backend_config: Optional[Union[TurbomindEngineConfig,
                                            PytorchEngineConfig]] = None,
+            speculative_model: Optional[str] = None,
             **kwargs):
         """Innter build method for pytorch backend."""
         from lmdeploy.pytorch.engine import Engine
         self.engine = Engine(model_path=model_path,
-                             engine_config=backend_config)
+                             engine_config=backend_config,
+                             speculative_model=speculative_model)
         self.backend_config = self.engine.engine_config
         self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None)
 
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index a284250f21..f62c3301d1 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -1014,6 +1014,7 @@ async def startup_event():
 
 
 def serve(model_path: str,
+          speculative_model: Optional[str] = None,
           model_name: Optional[str] = None,
           backend: Literal['turbomind', 'pytorch'] = 'turbomind',
           backend_config: Optional[Union[PytorchEngineConfig,
@@ -1049,6 +1050,7 @@ def serve(model_path: str,
                     on huggingface.co, such as "internlm/internlm-chat-7b",
                     "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                     and so on.
+        speculative_model (str): the path of the speculative model.
         model_name (str): the name of the served model. It can be accessed
             by the RESTful API `/v1/models`. If it is not specified,
             `model_path` will be adopted
@@ -1113,6 +1115,7 @@ def serve(model_path: str,
     _, pipeline_class = get_task(model_path)
     VariableInterface.async_engine = pipeline_class(
         model_path=model_path,
+        speculative_model=speculative_model,
         model_name=model_name,
         backend=backend,
         backend_config=backend_config,
diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py
index e56de44b37..5edafc95a1 100644
--- a/tests/pytorch/kernel/test_flash_attention.py
+++ b/tests/pytorch/kernel/test_flash_attention.py
@@ -32,7 +32,7 @@ def _make_bias(q_seqlens, history_lens, neg_val, causal):
     return (~mask).float() * neg_val
 
 
-def _naive_attention(batched_q, batched_kv, bias):
+def _naive_attention(batched_q, batched_kv, bias, rand_mask):
     batched_k, batched_v = batched_kv
     num_heads_q = batched_q.shape[2]
 
@@ -50,6 +50,8 @@ def _naive_attention(batched_q, batched_kv, bias):
 
         qk = torch.matmul(q, k) / math.sqrt(head_dim)
         attn_weight = qk + bias[:, None]
+        if rand_mask is not None:
+            attn_weight += rand_mask[:, None]
         attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
         attn_weight = attn_weight.to(q.dtype)
         attn_output = torch.matmul(attn_weight, v)
@@ -110,6 +112,10 @@ def num_heads_k(self, request):
     def causal(self, request):
         yield request.param
 
+    @pytest.fixture
+    def with_attention_mask(self, request):
+        yield request.param
+
     @pytest.fixture
     def q_seqlens(self, request):
         yield torch.tensor(request.param, device='cuda')
@@ -182,8 +188,16 @@ def mask(self, q_seqlens, history_lens, causal):
         yield _make_bias(q_seqlens, history_lens, neg_val, causal)
 
     @pytest.fixture
-    def gt(self, batched_q, batched_kv, mask):
-        yield _naive_attention(batched_q, batched_kv, mask)
+    def rand_mask(self, mask, with_attention_mask):
+        neg_val = -1e30
+        if with_attention_mask:
+            yield torch.rand_like(mask).round() * neg_val
+        else:
+            yield None
+
+    @pytest.fixture
+    def gt(self, batched_q, batched_kv, mask, rand_mask):
+        yield _naive_attention(batched_q, batched_kv, mask, rand_mask)
 
     @pytest.fixture
     def conti_gt(self, gt, q_seqlens):
@@ -197,9 +211,11 @@ def conti_gt(self, gt, q_seqlens):
     @pytest.mark.parametrize(['q_seqlens', 'history_lens'],
                              [([30, 50, 70, 90], [50, 40, 30, 20])],
                              indirect=True)
+    @pytest.mark.parametrize('with_attention_mask', [True, False],
+                             indirect=True)
     def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens,
                              kv_start_loc, kv_seqlens, head_dim_v, causal,
-                             conti_gt):
+                             conti_gt, rand_mask):
         from lmdeploy.pytorch.kernels.cuda.flashattention import \
             flash_attention_fwd
         max_seq_len = q_seqlens.max().item()
@@ -210,6 +226,7 @@ def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens,
                                conti_k,
                                conti_v,
                                out,
+                               attention_mask=rand_mask,
                                q_start_loc=q_start_loc,
                                q_seqlens=q_seqlens,
                                kv_start_loc=kv_start_loc,
diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py
index 2e7c1e1a0f..c988e9e5bf 100644
--- a/tests/pytorch/paging/test_scheduler.py
+++ b/tests/pytorch/paging/test_scheduler.py
@@ -14,11 +14,11 @@ def block_size(self):
 
     @pytest.fixture
     def num_cpu_blocks(self):
-        yield 4
+        yield 12
 
     @pytest.fixture
     def num_gpu_blocks(self):
-        yield 4
+        yield 12
 
     @pytest.fixture
     def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks):
@@ -60,9 +60,9 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks):
         assert seq.status == MessageStatus.RUNNING
         assert seq in output.running
         assert len(block_tables) == 1
-        assert len(block_tables[0]) == num_blocks
-        assert block_manager.get_num_free_gpu_blocks(
-        ) == num_gpu_blocks - num_blocks
+        assert len(block_tables[0]) == num_blocks + 4  # medusa needs 4 more
+        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - (
+            num_blocks + 4)
         assert scheduler.has_unfinished()
 
@@ -99,7 +99,8 @@ def test_update(self, scheduler, block_size, num_gpu_blocks):
         assert session_id1 in scheduler.sessions
         assert seq1 not in scheduler.running
         assert seq1 not in scheduler.hanging
-        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2
+        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - (2 +
+                                                                            4)
 
         # stop session
         scheduler.stop_session(session_id2)
@@ -136,7 +137,8 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks,
         assert seq1.status == MessageStatus.RUNNING
         assert seq2.status == MessageStatus.RUNNING
         assert seq3.status == MessageStatus.WAITING
-        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3
+        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - (
+            3 + 4 * 2)
 
         # test: waiting alloc
         seq2.status = MessageStatus.STOPPED
@@ -180,4 +182,4 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks,
         # seq3: 3 nan
         assert seq1.status == MessageStatus.WAITING
         assert seq2.status == MessageStatus.RUNNING
-        assert block_manager.get_num_free_gpu_blocks() == 0
+        assert block_manager.get_num_free_gpu_blocks() == 4
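
Reviewer note: the accept/reject rule inside the new `RejectionSampler.forward` is easy to lose in the fused tensor indexing. Below is a minimal, CPU-only sketch of the same speculative-sampling recipe (accept draft token i with probability min(1, p_i / q_i); at the first rejection resample from norm(max(p - q, 0)); if every draft token is accepted, sample the bonus token from the last target distribution). The function and variable names are illustrative only and are not part of this patch.

```python
# Illustrative sketch only -- not the code added in
# lmdeploy/pytorch/nn/rejection_sampling.py; it mirrors the same rule on CPU.
import torch


def speculative_accept(target_probs, draft_probs, draft_token_ids):
    """target_probs: (B, K+1, V); draft_probs: (B, K, V); draft_token_ids: (B, K)."""
    bsz, k, _ = draft_probs.shape
    batch = torch.arange(bsz)[:, None]
    steps = torch.arange(k)
    # probability each model assigns to the drafted tokens
    p = target_probs[:, :-1][batch, steps, draft_token_ids]
    q = draft_probs[batch, steps, draft_token_ids]
    # accept draft token i with probability min(1, p_i / q_i)
    accept = torch.rand(bsz, k) < (p / q).clamp(max=1.0)
    rejected = ~accept
    all_accepted = ~rejected.any(-1)
    first_reject = torch.where(all_accepted,
                               torch.full((bsz,), k),
                               rejected.int().argmax(-1))
    # resample from norm(max(p - q, 0)) at the first rejected step,
    # or from the bonus (K+1-th) target distribution if nothing was rejected
    adjusted = (target_probs[:, :-1] - draft_probs).clamp_min(0)
    adjusted = adjusted / adjusted.sum(-1, keepdim=True).clamp_min(1e-5)
    resample_dist = torch.cat([adjusted, target_probs[:, -1:]], dim=1)
    resample_dist = resample_dist[torch.arange(bsz), first_reject]
    resampled = torch.multinomial(resample_dist, 1).squeeze(-1)
    return first_reject, resampled
```

Feeding it softmaxed random logits of the documented shapes returns, per sequence, the index of the first rejected draft position (`k` when all drafts are accepted) and the token resampled at that position, matching the `(output_token_ids, reject_idx)` contract above.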
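Reviewer note on the scheduler test changes: `schedule()` now clamps `prealloc_size` to at least 64 tokens so that tree-decode candidates always have KV slots preallocated. Assuming the 16-token `block_size` used by the existing fixture in `test_scheduler.py` (the value is not shown in this hunk), that is ceil(64 / 16) = 4 extra blocks per running sequence, which is where the recurring `+ 4` in the updated assertions comes from.

```python
# Back-of-the-envelope check for the "+ 4" in the scheduler tests.
# block_size = 16 is assumed from the existing test fixture.
import math

block_size = 16
medusa_prealloc_tokens = 64  # enforced by schedule(): max(prealloc_size, 64)
extra_blocks = math.ceil(medusa_prealloc_tokens / block_size)
assert extra_blocks == 4
```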
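Finally, a possible end-to-end usage of the new argument, based on the `serve()` signature changed in this patch. The model paths are placeholders, and the exact draft-checkpoint layout (a directory containing `medusa_lm_head.pt`, per the weight-loader change) is assumed rather than documented here.

```python
# Hypothetical usage sketch; paths are placeholders.
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.serve.openai.api_server import serve

serve(
    '/path/to/target-model',                   # main (target) model
    speculative_model='/path/to/draft-model',  # directory with medusa_lm_head.pt (assumed)
    backend='pytorch',                         # turbomind asserts speculative_model is None
    backend_config=PytorchEngineConfig(),
)
```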