Support Medusa speculative decoding #2859

Open · wants to merge 25 commits into base: main
3 changes: 3 additions & 0 deletions lmdeploy/api.py
@@ -11,6 +11,7 @@ def pipeline(model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
speculative_model: str = None,
log_level: str = 'WARNING',
max_log_len: int = None,
**kwargs):
@@ -33,6 +34,7 @@ def pipeline(model_path: str,
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
speculative_model (str): the path of the draft model.
log_level(str): set log level whose value among [CRITICAL, ERROR,
WARNING, INFO, DEBUG]
max_log_len(int): Max number of prompt characters or prompt tokens
@@ -86,6 +88,7 @@ def pipeline(model_path: str,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
speculative_model=speculative_model,
max_log_len=max_log_len,
**kwargs)

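For context, a minimal usage sketch of the new speculative_model argument; the model and draft-model paths below are illustrative placeholders, not taken from this diff:

from lmdeploy import pipeline, PytorchEngineConfig

# The draft (Medusa) model is passed alongside the base model.
pipe = pipeline('lmsys/vicuna-7b-v1.3',
                backend_config=PytorchEngineConfig(),
                speculative_model='FasterDecoding/medusa-vicuna-7b-v1.3')
print(pipe(['Hello! How are you today?']))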
3 changes: 3 additions & 0 deletions lmdeploy/cli/cli.py
@@ -43,6 +43,7 @@ def add_parser_convert():
help='The directory path of the model')
ArgumentHelper.model_format(parser)
ArgumentHelper.tp(parser)
ArgumentHelper.speculative_model(parser)
# other args
ArgumentHelper.revision(parser)
ArgumentHelper.download_dir(parser)
@@ -121,6 +122,7 @@ def add_parser_chat():
# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.adapters(pt_group)
ArgumentHelper.speculative_model(pt_group)
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)
# common engine args
@@ -270,6 +272,7 @@ def chat(args):
quant_policy=args.quant_policy)
run_chat(args.model_path,
engine_config,
speculative_model=args.speculative_model,
chat_template_config=chat_template_config)
else:
from lmdeploy.turbomind.chat import main as run_chat
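With the flag registered on the PyTorch engine argument group, an interactive chat with a draft model can be launched along the lines of: lmdeploy chat <model_path> --backend pytorch --speculative-model <draft_model_path> (flag spelling per the helper added in lmdeploy/cli/utils.py below).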
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -159,6 +159,7 @@ def add_parser_api_server():
pt_group = parser.add_argument_group('PyTorch engine arguments')

ArgumentHelper.adapters(pt_group)
ArgumentHelper.speculative_model(pt_group)
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)

@@ -338,6 +339,7 @@ def api_server(args):
vision_config = VisionConfig(args.vision_max_batch_size)
run_api_server(args.model_path,
model_name=args.model_name,
speculative_model=args.speculative_model,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
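The serving entry point accepts the same flag, e.g. lmdeploy serve api_server <model_path> --backend pytorch --speculative-model <draft_model_path>.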
9 changes: 9 additions & 0 deletions lmdeploy/cli/utils.py
@@ -100,6 +100,15 @@ def model_name(parser):
'by the RESTful API `/v1/models`. If it is not specified, '
'`model_path` will be adopted')

@staticmethod
def speculative_model(parser):
"""Add argument speculative_model to parser."""

return parser.add_argument('--speculative-model',
type=str,
default=None,
help='The path of the speculative model.')

@staticmethod
def dtype(parser, default: str = 'auto'):
return parser.add_argument(
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -21,6 +21,7 @@ class TritonAttentionMetadata(AttentionMetadata):
fill_seqlens: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0
kv_flatten_size: int = None
medusa_attn_mask: torch.Tensor = None


def _cdiv(a, b):
@@ -106,6 +107,9 @@ def forward(
fill_seqlens = attn_metadata.fill_seqlens
fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
attention_mask = None
if attn_metadata.medusa_attn_mask is not None:
attention_mask = attn_metadata.medusa_attn_mask

# fill kv cache
if key is not None and value is not None:
Expand Down Expand Up @@ -167,6 +171,7 @@ def forward(
flatten_k,
flatten_v,
attn_output,
attention_mask=attention_mask,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
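The mask threaded into the attention call here is additive. A reference sketch of its effect (plain dense attention, not the Triton kernel; shapes assumed for illustration):

import torch

def masked_attention_reference(q, k, v, attn_mask, scale):
    # q: [heads, q_len, dim]; k, v: [heads, kv_len, dim]
    # attn_mask: [q_len, kv_len] additive bias -- 0 where attention is
    # allowed, a large negative value where the Medusa tree mask blocks it.
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    scores = scores + attn_mask        # blocked entries go to ~-inf
    probs = scores.softmax(dim=-1)     # ...and receive ~zero weight
    return torch.matmul(probs, v)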
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -139,7 +139,11 @@ def get_graph_key(self, input_ids: torch.Tensor,
is_decoding = context.is_decoding
num_tokens = input_ids.numel()
new_num_tokens = next_power_of_2(num_tokens)
return (new_num_tokens, is_decoding)
mask_input = attn_metadata.medusa_attn_mask is not None
seq_num = 0
if is_decoding is False:
seq_num = next_power_of_2(attn_metadata.q_seqlens.shape[0])
return (new_num_tokens, is_decoding, mask_input, seq_num)

def __call__(self, **kwargs):
"""call."""
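The CUDA-graph cache key now also distinguishes masked (Medusa) steps and buckets the sequence count. A rough illustration, with next_power_of_2 re-sketched here only for the example:

def next_power_of_2(n: int) -> int:
    # smallest power of two >= n (assumed behaviour of the helper)
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# e.g. a prefill step with 19 tokens across 3 sequences and a Medusa mask
# would map to the key (32, False, True, 4), so later batches with the
# same bucketed shape can replay the same captured graph.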
17 changes: 17 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -116,6 +116,22 @@ def update_step_context(cls, step_context):
if not step_context.is_decoding:
kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
kv_flatten_size = kv_seqlens.sum().item()
if step_context.medusa_attn_mask is not None:
max_q_seqlen = q_seqlens.max()
max_kv_seqlen = kv_seqlens.max()
bs = q_seqlens.shape[0]
medusa_len = step_context.medusa_attn_mask.shape[-1]
dtype = step_context.medusa_attn_mask.dtype
device = step_context.medusa_attn_mask.device
attention_mask = torch.zeros((bs, max_q_seqlen, max_kv_seqlen),
dtype=dtype,
device=device)
masked_value = (1 - step_context.medusa_attn_mask) * (-1e30)
for i in range(bs):
attention_mask[i, :, kv_seqlens[i] -
medusa_len:kv_seqlens[i]] = masked_value # noqa
step_context.medusa_attn_mask = attention_mask

attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
@@ -125,6 +141,7 @@ def update_step_context(cls, step_context):
kv_seqlens=kv_seqlens,
kv_flatten_size=kv_flatten_size,
quant_policy=step_context.kv_quant_policy,
medusa_attn_mask=step_context.medusa_attn_mask,
)

cross_seqlens = step_context.cross_seqlens
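Conceptually, the new block converts the 0/1 Medusa tree mask into a per-sequence additive bias covering only the last medusa_len cached positions; everything earlier in the prefix stays visible. A standalone sketch under assumed shapes:

import torch

def build_additive_medusa_mask(tree_mask: torch.Tensor,
                               kv_len: int) -> torch.Tensor:
    # tree_mask: [q_len, q_len], 1 = attend, 0 = block, covering the
    # q_len tree/draft tokens at the tail of the KV cache.
    q_len = tree_mask.shape[-1]
    mask = torch.zeros((q_len, kv_len), dtype=torch.float32)
    mask[:, kv_len - q_len:] = (1.0 - tree_mask.float()) * (-1e30)
    return mask  # prefix positions keep a bias of 0 (fully visible)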
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/chat.py
@@ -52,6 +52,7 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):

def run_chat(model_path: str,
engine_config: PytorchEngineConfig,
speculative_model: Optional[str] = None,
gen_config: GenerationConfig = None,
session_id: int = 1,
trust_remote_code: bool = True,
@@ -62,12 +63,14 @@ def run_chat(model_path: str,
Args:
model_path (str): the huggingface model path.
engine_config (PytorchEngineConfig): Config of engine.
speculative_model (str): the path of the speculative model.
gen_config (GenerationConfig): Config of generation.
session_id (int): the identical id of a session.
trust_remote_code (bool): trust remote code.
"""
from lmdeploy.pytorch.engine import Engine
tm_model = Engine.from_pretrained(model_path,
speculative_model=speculative_model,
engine_config=engine_config,
trust_remote_code=trust_remote_code)
tokenizer = tm_model.tokenizer
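A hypothetical way to start this interactive chat with a draft model (paths are placeholders):

from lmdeploy import PytorchEngineConfig
from lmdeploy.pytorch.chat import run_chat

run_chat('lmsys/vicuna-7b-v1.3',
         PytorchEngineConfig(),
         speculative_model='FasterDecoding/medusa-vicuna-7b-v1.3')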
23 changes: 18 additions & 5 deletions lmdeploy/pytorch/config.py
@@ -112,6 +112,9 @@ class ModelConfig:
hf_config: Any = None
cogvlm_style: bool = False
custom_module_map: Dict[str, setattr] = None
# medusa config
medusa_num_heads: int = None
medusa_num_layers: int = None

def get_head_size(self):
"""get head size."""
@@ -134,12 +137,22 @@ def from_pretrained(cls,
activations. Refer to `PyTorchEngineConfig` for details
"""
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
try:
hf_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path)
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code)
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
hf_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path)
except Exception as e: # noqa
from transformers import PretrainedConfig
hf_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code)
# medusa model config
if getattr(hf_config, 'medusa_num_heads', None) is not None:
setattr(hf_config, 'architectures', ['MedusaModel'])
return cls.from_hf_config(hf_config,
pretrained_model_name_or_path,
dtype=dtype,
42 changes: 42 additions & 0 deletions lmdeploy/pytorch/configurations/medusa.py
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class MedusaModelConfigBuilder(AutoModelConfigBuilder):

@classmethod
def condition(cls, hf_config):
"""config."""
return hf_config.architectures[0] == 'MedusaModel'

@classmethod
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
from transformers import AutoConfig
base_config = AutoConfig.from_pretrained(
hf_config.base_model_name_or_path)
head_dim = base_config.hidden_size // base_config.num_attention_heads
# the released checkpoint's config.json declares the wrong medusa_num_heads,
# so override it here:
# https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3/blob/main/config.json#L3
hf_config.medusa_num_heads = 5
medusa_num_heads = hf_config.medusa_num_heads
medusa_num_layers = hf_config.medusa_num_layers
if getattr(hf_config, 'hidden_size', None) is None:
setattr(hf_config, 'hidden_size', base_config.hidden_size)
if getattr(hf_config, 'vocab_size', None) is None:
setattr(hf_config, 'vocab_size', base_config.vocab_size)
return ModelConfig(
hidden_size=base_config.hidden_size,
num_attention_heads=base_config.num_attention_heads,
num_layers=base_config.num_hidden_layers,
num_key_value_heads=base_config.num_key_value_heads,
bos_token_id=base_config.bos_token_id,
eos_token_id=base_config.eos_token_id,
head_dim=head_dim,
vocab_size=base_config.vocab_size,
hf_config=hf_config,
medusa_num_heads=medusa_num_heads,
medusa_num_layers=medusa_num_layers,
)
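For reference, a sketch of exercising this builder by hand; normally AutoModelConfigBuilder dispatches to it automatically once architectures has been set to ['MedusaModel'] as config.py does above. The checkpoint name comes from the link in the comment, and local access to it and its base model is assumed:

from transformers import PretrainedConfig
from lmdeploy.pytorch.configurations.medusa import MedusaModelConfigBuilder

hf_config = PretrainedConfig.from_pretrained(
    'FasterDecoding/medusa-vicuna-7b-v1.3')
hf_config.architectures = ['MedusaModel']
assert MedusaModelConfigBuilder.condition(hf_config)
model_config = MedusaModelConfigBuilder.build(
    hf_config, 'FasterDecoding/medusa-vicuna-7b-v1.3')
# model_config mirrors the base model's geometry (hidden size, heads,
# layers, vocab) plus medusa_num_heads / medusa_num_layers for the draft.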