diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 04f54047d..749fb5e01 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1291,7 +1291,7 @@ def __init__( self, model: nn.Module, continuous_batching: bool = False, - is_tlm: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): model_class_name = model.__class__.__name__ @@ -1317,11 +1317,9 @@ def __init__( self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching - - if is_tlm: - # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch - self.model, transformed = SpDTransform.apply(self.model) - self.is_tlm = is_tlm + # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch + self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) + self.is_tlm = transformed @property def model_name(self) -> str: @@ -1336,7 +1334,12 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers def from_pretrained( - cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + *args, + **kwargs, ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. @@ -1381,6 +1384,8 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class @@ -1389,7 +1394,12 @@ def from_pretrained( model, kv_offload=kv_offload ) - return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching) + return cls( + model, + continuous_batching=continuous_batching, + qaic_config=qaic_config, + **kwargs, + ) @property def model_hash(self) -> str: @@ -1564,16 +1574,9 @@ def compile( raise TypeError("`prefill_only` must be a boolean.") if self.is_tlm: - if num_speculative_tokens is None: - raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.") - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: - raise ValueError("`num_speculative_tokens` must be an integer >= 2.") - if prefill_seq_len < (num_speculative_tokens + 1): - raise ValueError( - f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` " - f"({num_speculative_tokens + 1}), got {prefill_seq_len}." 
- ) - + num_speculative_tokens: int = self.check_and_get_num_speculative_tokens( + num_speculative_tokens, prefill_seq_len + ) if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -1667,6 +1670,29 @@ def generate( else: raise NotImplementedError("Only AI_100 runtime is supported right now via generate API") + def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[int], prefill_seq_len: int): + if hasattr(self.model.config, "speculative_config"): + num_speculative_tokens_ = self.model.config.speculative_config["num_speculative_tokens"] + if num_speculative_tokens is not None: + logger.warning( + f"arg `num_speculative_tokens` is a fixed value of {num_speculative_tokens_} for this model." + f" Passed value of {num_speculative_tokens} will be ignored." + ) + num_speculative_tokens = num_speculative_tokens_ + elif num_speculative_tokens is None: + raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: + raise ValueError( + f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" + ) + num_logits_to_keep = num_speculative_tokens + 1 + if prefill_seq_len < num_logits_to_keep: + raise ValueError( + f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) + return num_speculative_tokens + class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin): """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2e94908c8..f5573cc5e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1,13 +1,14 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- from types import MethodType -from typing import Tuple +from typing import Optional, Tuple +import transformers from torch import nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -47,17 +48,6 @@ GraniteAttention, GraniteForCausalLM, GraniteModel, - GraniteRMSNorm, -) -from transformers.models.granitemoe.modeling_granitemoe import ( - GraniteMoeAttention, - GraniteMoeForCausalLM, - GraniteMoeModel, - GraniteMoeMoE, - GraniteMoeParallelExperts, - GraniteMoeRMSNorm, - GraniteMoeRotaryEmbedding, - GraniteMoeTopKGating, ) from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -69,9 +59,6 @@ from transformers.models.llava.modeling_llava import ( LlavaForConditionalGeneration, ) -from transformers.models.llava_next.modeling_llava_next import ( - LlavaNextForConditionalGeneration, -) from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, @@ -133,6 +120,7 @@ from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC +from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.models.codegen.modeling_codegen import ( QEffCodeGenAttention, QeffCodeGenBlock, @@ -180,16 +168,10 @@ QEffGraniteForCausalLM, QEffGraniteModel, ) -from QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( - QEffGraniteMoeAttention, - QEffGraniteMoeForCausalLM, - QEffGraniteMoeModel, - QEffGraniteMoeMoE, - QEffGraniteMoeParallelExperts, - QEffGraniteMoeRotaryEmbedding, - QEffGraniteMoeTopKGating, +from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternVisionEmbeddings, + QEffInternVLModel, ) -from QEfficient.transformers.models.internvl.modeling_internvl import QEffInternVisionEmbeddings, QEffInternVLModel from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, @@ -199,9 +181,6 @@ from QEfficient.transformers.models.llava.modeling_llava import ( QEffLlavaForConditionalGeneration, ) -from QEfficient.transformers.models.llava_next.modeling_llava_next import ( - QEffLlavaNextForConditionalGeneration, -) from QEfficient.transformers.models.mistral.modeling_mistral import ( QEffMistralAttention, QEffMistralDecoderLayer, @@ -266,7 +245,10 @@ QEffWhisperModel, QEffWhisperPositionalEmbedding, ) -from QEfficient.transformers.spd.causal_lm_forward import tlm_forward +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" class CustomOpsTransform(ModuleMappingTransform): @@ -279,8 +261,6 @@ class CustomOpsTransform(ModuleMappingTransform): Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, - GraniteRMSNorm: CustomRMSNormAIC, - GraniteMoeRMSNorm: CustomRMSNormAIC, } @@ -313,8 +293,6 @@ class KVCacheTransform(ModuleMappingTransform): LlamaForCausalLM: QEffLlamaForCausalLM, # Llava LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration, - # Llava Next - LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration, # Gemma GemmaAttention: QEffGemmaAttention, GemmaDecoderLayer: QEffGemmaDecoderLayer, @@ -329,14 +307,6 @@ class KVCacheTransform(ModuleMappingTransform): GraniteModel: 
QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, GraniteAttention: QEffGraniteAttention, - # GraniteMoe - GraniteMoeModel: QEffGraniteMoeModel, - GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM, - GraniteMoeAttention: QEffGraniteMoeAttention, - GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding, - GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts, - GraniteMoeTopKGating: QEffGraniteMoeTopKGating, - GraniteMoeMoE: QEffGraniteMoeMoE, # mllama MllamaTextRMSNorm: CustomRMSNormAIC, MllamaTextSelfAttention: QEffMllamaTextSelfAttention, @@ -401,6 +371,8 @@ class KVCacheTransform(ModuleMappingTransform): @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: model, transformed = super().apply(model) + # FIXME: see if we can merge into _module_mapping dict + transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update return model, transformed @@ -421,21 +393,34 @@ class SpDTransform: # supported architectures _module_mapping = { - # Llama QEffLlamaForCausalLM, + QEffQwen2ForCausalLM, } @classmethod - def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: transformed = False - if (model_class := model.__class__) in cls._module_mapping: + if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None: + return model, transformed + elif speculative_model_type not in ( + supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys()) + ): + raise ValueError( + f"Speculative model type {speculative_model_type} is not supported. We currently only support {supported_spd_model_types}" + ) + elif (model_class := model.__class__) in cls._module_mapping: model.forward = MethodType(tlm_forward, model) + if speculative_model_type != SPD_TARGET: + # build and attach draft mlp + pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"] + model = build_and_attach_mlp( + model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs + ) transformed = True else: raise NotImplementedError( f"model class {model_class} does not yet support returning multiple logits to keep." ) - return model, transformed diff --git a/QEfficient/transformers/post_processing.py b/QEfficient/transformers/post_processing.py new file mode 100644 index 000000000..04c279287 --- /dev/null +++ b/QEfficient/transformers/post_processing.py @@ -0,0 +1,28 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from QEfficient.transformers.spd.turbo import build_and_attach_turbo +from QEfficient.utils.spd_utils import get_speculative_config, get_speculative_weights + +model_type_registry = dict(turbo=build_and_attach_turbo) + + +def build_and_attach_mlp(model, pretrained_model_name_or_path, speculative_model_type: str, **kwargs): + speculative_config: dict = get_speculative_config(pretrained_model_name_or_path, **kwargs) + speculative_weights: str = get_speculative_weights(pretrained_model_name_or_path, **kwargs) + + if (model_type := speculative_config.get("model_type")) is None: + speculative_config["model_type"] = speculative_model_type + else: + if model_type != speculative_model_type: + raise ValueError( + f"`model_type` key from speculator config ({model_type}) does not match input model type ({speculative_model_type})." + ) + func = model_type_registry[speculative_model_type] + model = func(model, speculative_config, speculative_weights) + model.config.speculative_config = speculative_config + return model diff --git a/QEfficient/transformers/spd/__init__.py b/QEfficient/transformers/spd/__init__.py index da26921c5..d259e435a 100644 --- a/QEfficient/transformers/spd/__init__.py +++ b/QEfficient/transformers/spd/__init__.py @@ -4,4 +4,3 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- - diff --git a/QEfficient/transformers/spd/causal_lm_forward.py b/QEfficient/transformers/spd/spd_transform_forward.py similarity index 81% rename from QEfficient/transformers/spd/causal_lm_forward.py rename to QEfficient/transformers/spd/spd_transform_forward.py index 46601c0c9..486e911fc 100644 --- a/QEfficient/transformers/spd/causal_lm_forward.py +++ b/QEfficient/transformers/spd/spd_transform_forward.py @@ -21,7 +21,7 @@ def filter_hidden_states( Filter hidden states based on whether this is a TLM SpD model ``Mandatory`` Args: - :hidden_states (torch.Tensor): Hidden states tensor. + :hidden_states (torch.Tensor): Last hidden state tensor. :position_ids (torch.Tensor): Position ids tensor. ``Optional`` Args: :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model @@ -50,6 +50,28 @@ def filter_hidden_states( return hidden_states + +def project_hidden_states(hidden_states, hidden_size_projections): + """ + Project the last hidden state through each attached speculative projection head and stack the outputs + + ``Mandatory`` Args: + :hidden_states (torch.Tensor): Last hidden state tensor. + :hidden_size_projections (torch.nn.ModuleList): Hidden-size projection modules, one per speculative head. + + Returns: + :torch.Tensor: Stacked hidden states of shape [bsz, seq_len, num_projections + 1, d_model].
+ """ + proj_hidden_states = [hidden_states] + num_projs = len(hidden_size_projections) + for i in range(num_projs): + hidden_states_i = hidden_size_projections[i](hidden_states) + proj_hidden_states.append(hidden_states_i) + hidden_states = torch.stack(proj_hidden_states, dim=2) # shape: [bsz, seq_len, num_projs, d_model] + return hidden_states + + def tlm_forward( self, input_ids: torch.LongTensor = None, @@ -113,7 +135,10 @@ def tlm_forward( ) hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep) - if self.config.pretraining_tp > 1: + hidden_size_projections = getattr(self, "projections", None) + if hidden_size_projections: + hidden_states = project_hidden_states(hidden_states, hidden_size_projections) + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) diff --git a/QEfficient/transformers/spd/turbo.py b/QEfficient/transformers/spd/turbo.py new file mode 100644 index 000000000..b0f87caef --- /dev/null +++ b/QEfficient/transformers/spd/turbo.py @@ -0,0 +1,85 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch + +from QEfficient.utils.checkpoint_utils import load_checkpoint + + +class ResBlock(torch.nn.Module): + """ + A Residual Block module. + + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = torch.nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Llama model + self.act = torch.nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output after the residual connection and activation. 
+ """ + return x + self.act(self.linear(x)) + + +def post_process_turbo_state_dict(state_dict: dict) -> dict: + """normaize turbo state dict keys + + Args: + state_dict (dict): turbo state dict + + Returns: + dict: normalized state dict + """ + new_state_dict = dict() + for name, weights in state_dict.items(): + new_name = name.replace("projections.", "") + new_state_dict[new_name] = weights + return new_state_dict + + +def build_and_attach_turbo(model, speculative_config: dict, speculative_weights: str): + """build and attach turbo projections + + Args: + model: model to attach projections to + speculative_config (dict): speculative config file used to build projections + + Returns: + model: model with turbo projections + """ + hidden_size = model.config.hidden_size + num_layers = speculative_config["turbo_num_layers"] + num_heads = speculative_config["turbo_num_heads"] + projections = torch.nn.ModuleList( + [ + torch.nn.Sequential( + *([ResBlock(hidden_size)] * num_layers), + ) + for _ in range(num_heads) + ], + ) + load_checkpoint(projections, speculative_weights, strict=True, post_process_func=post_process_turbo_state_dict) + model.projections = projections + speculative_config["num_speculative_tokens"] = num_heads + return model diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index ea09e97d7..ab52c32c2 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import inspect import json import os import subprocess @@ -580,3 +581,39 @@ def make_serializable(obj): qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config create_json(qconfig_file_path, qconfigs) + + +def filter_kwargs(func, kwargs): + """ + Filter a dictionary of keyword arguments to only include the valid arguments of a function. + + Args: + func: The function to check the arguments for. + kwargs: The dictionary of keyword arguments to filter. + + Returns: + A new dictionary containing only the valid keyword arguments. + """ + valid_args = inspect.signature(func).parameters + return {key: value for key, value in kwargs.items() if key in valid_args} + + +def check_and_get_num_logits_to_keep(self, num_speculative_tokens, prefill_seq_len): + if hasattr(self.model.config, "speculative_config"): + if num_speculative_tokens is not None: + logger.warning( + f"arg `num_speculative_tokens` is a fixed value of {self.num_speculative_tokens} for this model." + " Passed value will be ignored." + ) + num_speculative_tokens = self.config.speculative_config.num_speculative_tokens + elif num_speculative_tokens is None: + raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + + if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: + ValueError(f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}") + num_logits_to_keep = num_speculative_tokens + 1 + if prefill_seq_len < num_logits_to_keep: + raise ValueError( + f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) + return num_logits_to_keep diff --git a/QEfficient/utils/checkpoint_utils.py b/QEfficient/utils/checkpoint_utils.py new file mode 100644 index 000000000..7bd234b20 --- /dev/null +++ b/QEfficient/utils/checkpoint_utils.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. 
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from safetensors.torch import load_file + + +def load_checkpoint(model, checkpoint: str, strict=False, post_process_func=None): + """load weights ending with `.safetensors` extension + + Args: + model: model to load weights into + checkpoint (str): checkpoint path + strict (bool, optional): strictness of loading weights. Defaults to False. + post_process_func (optional): Optional post-processing of loaded state dict. Defaults to None. + + Returns: + model: model with applied weights + """ + state_dict: dict = load_file(checkpoint) + if post_process_func is not None: + state_dict = post_process_func(state_dict) + model.load_state_dict(state_dict, strict=strict) + return model diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index c45cfec41..006c8e7c8 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -98,10 +98,11 @@ def update_pytorch_inputs(self, inputs, pt_outputs): :Dict: Updated input_ids, position_ids and past_key_values """ updated_inputs = {} + logits = pt_outputs.logits.detach() + input_ids = logits.argmax(-1) if self.full_batch_size: batch_index = torch.arange(1).view(-1, 1) - input_ids = pt_outputs.logits.detach().argmax(2) updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) updated_inputs["input_ids"][batch_index.view(-1)] = input_ids @@ -112,7 +113,7 @@ def update_pytorch_inputs(self, inputs, pt_outputs): updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) else: - updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) + updated_inputs["input_ids"] = input_ids.reshape(-1, 1) updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["past_key_values"] = tuple( diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 267b2bb9e..70d269316 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -134,7 +134,8 @@ def run_kv_model_on_pytorch(self, model): pt_outputs = model(**inputs) for _ in range(1, self.gen_len): - generated_ids.append(pt_outputs["logits"].argmax(-1).reshape(-1, 1)) + logits = pt_outputs["logits"].detach() + generated_ids.append(logits.argmax(-1).reshape(-1, 1)) inputs = self.input_handler.update_pytorch_inputs(inputs, pt_outputs) pt_outputs = model(**inputs) diff --git a/QEfficient/utils/spd_utils.py b/QEfficient/utils/spd_utils.py new file mode 100644 index 000000000..763af5a45 --- /dev/null +++ b/QEfficient/utils/spd_utils.py @@ -0,0 +1,41 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from pathlib import Path + +from huggingface_hub import hf_hub_download +from transformers import PretrainedConfig + +from QEfficient.utils._utils import filter_kwargs + + +def get_speculative_config(pretrained_model_name_or_path, **kwargs) -> dict: + if not isinstance(pretrained_model_name_or_path, (str, Path)): + raise ValueError( + f"`pretrained_model_name_or_path` must be a string or Path object but is of type {type(pretrained_model_name_or_path)}" + ) + try: + speculative_config, _ = PretrainedConfig.get_config_dict( + pretrained_model_name_or_path, _configuration_file="speculator_config.json", **kwargs + ) + except OSError as err: + raise OSError(f"{err}.\nFile 'speculator_config.json' is expected to exist to apply turbo projections.") + return speculative_config + + +def get_speculative_weights(pretrained_model_name_or_path, **kwargs) -> str: + turbo_weights_file = "speculator.safetensors" + hf_hub_kwargs = filter_kwargs(hf_hub_download, kwargs) + if (local_path := Path(pretrained_model_name_or_path)).exists(): + if not local_path.is_dir(): + raise ValueError(f"local model path {local_path} must point to an existing directory") + weights_path = local_path / turbo_weights_file + if not weights_path.exists(): + raise FileNotFoundError(f"weights path {weights_path} does not exist.") + else: + weights_path = hf_hub_download(pretrained_model_name_or_path, filename=turbo_weights_file, **hf_hub_kwargs) + return str(weights_path) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 33b9a03d7..abab4cfc3 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -258,16 +258,17 @@ End to End demo examples for various models are available in **notebooks** direc ### Draft-Based Speculative Decoding Draft-based speculative decoding is a technique where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. -To export and compile both DLM/TLM, add corresponding `is_tlm` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: +To export and compile both DLM/TLM, add corresponding `qaic_config` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: ```Python tlm_name = "meta-llama/Llama-2-70b-chat-hf" dlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" k = 3 # DLM will make `k` speculations -tlm = AutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True) +qaic_config = dict(speculative_model_type="target") +tlm = AutoModelForCausalLM.from_pretrained(tlm_name, qaic_config=qaic_config) dlm = AutoModelForCausalLM.from_pretrained(dlm_name) tlm.compile(num_speculative_tokens=k) dlm.compile() ``` -The `is_tlm` flag is fed during the instantiation of the model because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase.
As for the DLM, no new changes are required at the ONNX or compile level. +The `qaic_config` dictionary is passed when the model is instantiated because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level. diff --git a/examples/multiprojs_spd_inference.py b/examples/multiprojs_spd_inference.py new file mode 100644 index 000000000..a80bc955d --- /dev/null +++ b/examples/multiprojs_spd_inference.py @@ -0,0 +1,423 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +from dataclasses import dataclass +from time import perf_counter +from typing import List, Optional, Union + +import numpy as np +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.constants import Constants + + +@dataclass +class SpDPerfMetrics: + """ + Holds all performance metrics + + Args: + :mean_ttft (float): Average TLM+DLM TTFT. + :batch_ttft (float): Total TLM+DLM Batch TTFT. + :decode_throughput (float): Decode throughput. + :e2e_throughput (float): E2E throughput. + :mean_num_accepted_tokens (float): Average number of accepted tokens. + :max_gen_len (int): Max generation length. + :generated_tokens_per_prompt (List[int]): Total generated tokens per prompt. + """ + + mean_ttft: float + batch_ttft: float + decode_throughput: float + e2e_throughput: float + mean_num_accepted_tokens: float + max_gen_len: int + generated_tokens_per_prompt: List[int] + e2e_time: float + decode_time: float + decode_target_time: float + decode_iterations: int + + +@dataclass +class CloudAI100ExecInfo: + """ + Holds all the information about Cloud AI 100 execution + + Args: + :prompts (List[str]): Prompts to perform inference on. + :batch_size (int): Batch size of the QPC compilation. + :generated_texts (Union[List[List[str]], List[str]]): Generated text(s). + :generated_ids (Union[List[np.ndarray], np.ndarray]): Generated IDs. + :perf_metrics (SpDPerfMetrics): Performance metrics. + :num_speculative_tokens (int): Number of speculative tokens. + :prefill_seq_len (int): Prefill sequence length. + :ctx_len (int): Context length. + :prefill_bsz (int): Prefill batch size. + :model_name (str): Name of the multi-projection model. + :full_batch_size (Optional[int]): Full batch size.
+ """ + + prompts: List[str] + batch_size: int + generated_texts: Union[List[str], List[List[str]]] + generated_ids: Union[List[np.ndarray], np.ndarray] + perf_metrics: SpDPerfMetrics + num_speculative_tokens: int + prefill_seq_len: int + ctx_len: int + prefill_bsz: int + model_name: str + full_batch_size: Optional[int] + + def __repr__(self): + return ( + f"Avg TLM+DLM TTFT = {round(self.perf_metrics.mean_ttft, 2)}\n" + f"Total TLM+DLM Batch TTFT = {round(self.perf_metrics.batch_ttft, 2)}\n" + f"Decode Throughput = {round(self.perf_metrics.decode_throughput, 2)}\n" + f"E2E Throughput = {round(self.perf_metrics.e2e_throughput, 2)}\n" + f"Avg number of accepted tokens = {round(self.perf_metrics.mean_num_accepted_tokens, 2)}\n" + f"Max generation len = {self.perf_metrics.max_gen_len}\n" + f"Total Generated Tokens per Prompt: = {self.perf_metrics.generated_tokens_per_prompt}" + ) + + +def run_prefill( + session: QAICInferenceSession, + inputs: dict, + prefill_seq_len: int, + slot_idx: int, +) -> np.ndarray: + input_len = inputs.input_ids.shape[1] + num_chunks = input_len // prefill_seq_len + cache_index = np.array([[0]], np.int64) + batch_index = np.array([[slot_idx]], np.int64) + inputs["batch_index"] = batch_index + + # Run chunked prefill + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][ + :, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len + ] + + outputs = session.run(chunk_inputs) + cache_index += prefill_seq_len + + logits = outputs["logits"] + return logits + + +def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int): + """return padded input length (must be factor of `prefill_seq_len`) + + Args: + input_len (int): prompt length + prefill_seq_len (int): prefill sequence length + ctx_len (int): context length + + Returns: + input_len_padded (int): padded input length + """ + num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float + input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len + assert input_len_padded <= ctx_len, ( + "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len" + ) + return input_len_padded + + +def split_dlm_bonus_token_inputs(dlm_decode_inputs): + bonus_token_inputs = dict() + bonus, regular = np.hsplit(dlm_decode_inputs["input_ids"], 2) + bonus_token_inputs["input_ids"] = bonus + dlm_decode_inputs["input_ids"] = regular + bonus, regular = np.hsplit(dlm_decode_inputs["position_ids"], 2) + bonus_token_inputs["position_ids"] = bonus + dlm_decode_inputs["position_ids"] = regular + bonus_token_inputs["batch_index"] = dlm_decode_inputs["batch_index"] + return bonus_token_inputs, dlm_decode_inputs + + +def multiprojs_spec_decode_inference( + prompts: List[str], + num_speculative_tokens: int, + prefill_seq_len: int, + ctx_len: int, + prefill_bsz: int, + pretrained_model_name_or_path: str, + full_batch_size: Optional[int], + session: QAICInferenceSession, + ignore_eos_token: bool = False, +) -> CloudAI100ExecInfo: + """ + Perform draft speculative decode inference on the given prompts. + + Args: + prompts (List[str]): List of prompts to perform inference on. + num_speculative_tokens (int): Number of speculative tokens. + prefill_seq_len (int): Prefill sequence length. + ctx_len (int): Context length. + prefill_bsz (int): Prefill batch size. 
+ pretrained_model_name_or_path (str): Name of multiprojection model + full_batch_size (Optional[int]): Full batch size. + device_group (List[int]): List of device IDs. + + Returns: + CloudAI100ExecInfo: Execution information, including performance metrics and generated text. + """ + # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="right") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + vocab_size = len(tokenizer) + # skip inputs/outputs buffers + session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) + session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) + + is_cb = full_batch_size is not None + decode_batch_size = full_batch_size if is_cb else prefill_bsz + if len(prompts) < decode_batch_size: + prompts_exp = prompts * decode_batch_size + prompts = prompts_exp[:decode_batch_size] + # tokenize the prompts + prompts_tokenized: List[dict] = [] + for p in prompts: + input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1] + input_len_padded: int = get_padded_input_len(input_len, prefill_seq_len, ctx_len) + p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) + position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) + p_tok["position_ids"] = position_ids + p_tok["num_logits_to_keep"] = np.zeros((1, 1), dtype=np.int64) + prompts_tokenized.append(p_tok) + # create caches to hold generated ids and input prompt lengths + generated_ids = [[] for i in range(decode_batch_size)] + input_lengths = [0] * decode_batch_size + # mock input key "logits" to store the first batch of output logits + num_logits_to_keep = num_speculative_tokens + 1 # number of logits to keep + precode_inputs = dict( + input_ids=np.zeros((decode_batch_size, num_logits_to_keep), dtype=np.int64), + position_ids=np.zeros((decode_batch_size, num_logits_to_keep), dtype=np.int64), + batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.zeros((num_logits_to_keep, 1), dtype=np.int64), + ) + max_gen_len = [ctx_len] * decode_batch_size + # setup buffers + prefill_logits_ph = np.zeros((prefill_bsz, 1, num_logits_to_keep, vocab_size), dtype=np.float32) + session.set_buffers({"logits": prefill_logits_ph}) + e2e_start = perf_counter() + ttfts = [] + for bi in range(decode_batch_size): + # assumes that prefill queue will always be popped from the front + start = perf_counter() + logits = run_prefill( # shape: [1, 1, num_logits_to_keep, vocab_size] + session=session, + inputs=prompts_tokenized[bi], + prefill_seq_len=prefill_seq_len, + slot_idx=bi, + ) + ttft = perf_counter() - start + ttfts.append(ttft) + input_ids = logits.argmax(-1).astype(np.int64) # shape: [1, 1, num_logits_to_keep] + generated_ids[bi].append(input_ids[0, 0, 0].item()) + precode_inputs["input_ids"][bi] = input_ids.flatten() + input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1 + precode_inputs["position_ids"][bi] = np.arange(input_len, input_len + num_logits_to_keep, dtype=np.int64) + # assumes that prefill queue will always be popped from the front + input_lengths[bi] = input_len + max_gen_len[bi] -= input_lengths[bi] + batch_ttft = perf_counter() - e2e_start + + # set decode logits buffers + precode_logits_ph = np.zeros( + 
(decode_batch_size, num_logits_to_keep, num_logits_to_keep, vocab_size), dtype=np.float32 + ) + session.set_buffers({"logits": precode_logits_ph}) + # start decode phase + valid_batch_indices = np.full(decode_batch_size, True, dtype=bool) + seq_batch_indices = np.arange(decode_batch_size, dtype=np.int64) + it = 0 + mean_num_accepted_tokens = 0 + decode_target_time = 0.0 + decode_start = perf_counter() + while True: + it += 1 + # run precode + target_start = perf_counter() + tlm_outputs = session.run(precode_inputs) + target_logits = tlm_outputs[ + "logits" + ] # shape: [decode_batch_size, num_logits_to_keep, num_logits_to_keep, vocab_size] + # greedy sampling from target model + target_tokens = target_logits[:, :, 0].argmax(-1) # shape: [decode_batch_size, num_logits_to_keep] + target_end = perf_counter() - target_start + decode_target_time += target_end + # exact matching between draft and target tokens + draft_tokens = precode_inputs["input_ids"][:, 1:] # shape: [decode_batch_size, num_speculative_tokens] + matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens] + num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1 # shape: [decode_batch_size] + mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item() + # append selected tokens to the generated_ids + for bi, valid in enumerate(valid_batch_indices): + if not valid: + continue + accepted_tokens = num_tokens_selected[bi] + num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) + accepted_tokens_arr = target_tokens[bi, :num_tokens_to_append] + generated_ids[bi].extend(accepted_tokens_arr.tolist()) + if len(generated_ids[bi]) >= max_gen_len[bi] or ( + (not ignore_eos_token) and (accepted_tokens_arr == tokenizer.eos_token_id).any() + ): + valid_batch_indices[bi] = False + # check if all generations are done + if not valid_batch_indices.any(): + break + # prepare decode inputs for next decode iteration + next_input_ids = ( + target_logits[seq_batch_indices, num_tokens_selected - 1].argmax(-1).astype(np.int64) + ) # shape: [decode_batch_size, num_logits_to_keep] + next_position_ids = precode_inputs["position_ids"] + num_tokens_selected[:, np.newaxis] + next_position_ids[~valid_batch_indices] = -1 + precode_inputs["input_ids"] = next_input_ids + precode_inputs["position_ids"] = next_position_ids + end = perf_counter() + # calculate performance metrics + decode_end = end - decode_start + e2e_end = end - e2e_start + mean_ttft = sum(ttfts) / len(ttfts) + generated_tokens_per_prompt = [len(gid) + 1 for gid in generated_ids] + decode_throughput = sum(generated_tokens_per_prompt) / decode_end + e2e_throughput = (sum(generated_tokens_per_prompt) + decode_batch_size) / e2e_end + batch_decode = tokenizer.batch_decode(generated_ids) + mean_num_accepted_tokens /= it + perf_metrics = SpDPerfMetrics( + mean_ttft, + batch_ttft, + decode_throughput, + e2e_throughput, + mean_num_accepted_tokens, + max_gen_len, + generated_tokens_per_prompt, + e2e_end, + decode_end, + decode_target_time, + it, + ) + exec_info = CloudAI100ExecInfo( + prompts, + decode_batch_size, + batch_decode, + generated_ids, + perf_metrics, + num_speculative_tokens, + prefill_seq_len, + ctx_len, + prefill_bsz, + pretrained_model_name_or_path, + full_batch_size, + ) + return exec_info + + +def optional_int(x): + if x is None: + return None + return int(x) + + +def comma_separated_ints(x: str): + return [int(qid) for qid in x.split(",")] + + +def arg_parse(): + parser = 
argparse.ArgumentParser(description="Draft-based SpD Inference") + parser.add_argument("--prompts", action="append", default=None, help="Input prompt(s)") + parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=128, help="Context length") + parser.add_argument("--prefill-bsz", type=int, default=1, help="Prefill batch size") + parser.add_argument( + "--pretrained-model-name-or-path", + type=str, + default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + help="Target model name", + ) + parser.add_argument("--full-batch-size", type=optional_int, default=None, help="Full batch size") + parser.add_argument("--device-group", type=comma_separated_ints, default="0", help="device QIDs") + parser.add_argument("--ignore-eos-token", action="store_true") + args = parser.parse_args() + return args + + +def get_session( + pretrained_model_name_or_path, + device_group, + prefill_seq_len, + ctx_len, + full_batch_size=None, +): + is_cb = full_batch_size is not None + qaic_config = dict(speculative_model_type="turbo") + qeff_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + continuous_batching=is_cb, + qaic_config=qaic_config, + ) + num_devices = len(device_group) + model_qpc_path: str = qeff_model.compile( + num_cores=16, + num_devices=num_devices, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + aic_enable_depth_first=True, + full_batch_size=full_batch_size, + ) + print(f"{model_qpc_path=}") + # init qaic session + session = QAICInferenceSession(model_qpc_path, device_ids=device_group) + num_speculative_tokens = qeff_model.model.config.speculative_config["num_speculative_tokens"] + return session, num_speculative_tokens + + +def main(): + args = arg_parse() + if args.prompts is None: + args.prompts = Constants.INPUT_STR + + session, num_speculative_tokens = get_session( + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + device_group=args.device_group, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + full_batch_size=args.full_batch_size, + ) + args.session = session + exec_info = multiprojs_spec_decode_inference( + args.prompts, + num_speculative_tokens, + args.prefill_seq_len, + args.ctx_len, + args.prefill_bsz, + args.pretrained_model_name_or_path, + args.full_batch_size, + args.session, + args.ignore_eos_token, + ) + print(exec_info) + prompts = exec_info.prompts + generated_texts = exec_info.generated_texts + for prompt, generation in zip(prompts, generated_texts): + print(f"{prompt=} {generation=}") + + +if __name__ == "__main__": + main() diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 67eec2e50..ff55aa8ba 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -55,6 +55,7 @@ spd_test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "Qwen/Qwen2-0.5B", ] @@ -122,7 +123,10 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) is_tlm = False if num_speculative_tokens is None else True - qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) + qaic_config = None + if is_tlm: + qaic_config = dict(speculative_model_type="target") + qeff_model = QEFFAutoModelForCausalLM(model_hf, qaic_config=qaic_config) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) @@ -280,7 +284,6 @@ def 
test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name): ) -@pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", spd_test_models) def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index e5d472734..8d0ea7b1c 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -25,7 +25,7 @@ 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name 1, # full_batch_size 3, # max_ngram_size id="CB llama", @@ -247,8 +247,9 @@ def test_pld_spec_decode_inference( # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None + qaic_config = dict(speculative_model_type="target") target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config=qaic_config ) num_devices = len(device_group) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index b78afdc38..3bc9201b2 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -30,6 +30,17 @@ 1, # full_batch_size id="CB llama", ), + pytest.param( + Constants.INPUT_STR, # prompts + 4, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + "Qwen/Qwen2-0.5B", # draft_model_name + "Qwen/Qwen2-0.5B", # target_model_name + 1, # full_batch_size + id="CB qwen", + ), ] @@ -94,7 +105,6 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs): @pytest.mark.on_qaic -@pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model @pytest.mark.parametrize( "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size", configs, @@ -119,11 +129,14 @@ def test_spec_decode_inference( if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id vocab_size = len(tokenizer) + if target_model_name == "Qwen/Qwen2-0.5B": + vocab_size = 151936 # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None + qaic_config = dict(speculative_model_type="target") target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config=qaic_config ) draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching) @@ -169,6 +182,7 @@ def test_spec_decode_inference( p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) p_tok["position_ids"] = position_ids + p_tok["num_logits_to_keep"] = np.zeros((1, 1), dtype=np.int64) prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths generated_ids = [[] for i in range(decode_batch_size)] @@ -181,13 +195,14 @@ def test_spec_decode_inference( np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1) ) # mock input key "logits" to store the first batch of output logits + num_logits_to_keep = num_speculative_tokens 
+ 1 tlm_precode_inputs = dict( input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.zeros((num_logits_to_keep, 1), dtype=np.int64), ) max_gen_len = [ctx_len] * decode_batch_size - num_logits_to_keep = num_speculative_tokens + 1 # setup buffers tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index d7151aad7..7c57459bd 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -18,6 +18,7 @@ from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform +from QEfficient.transformers.spd.turbo import ResBlock from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.logging_utils import logger @@ -71,6 +72,11 @@ ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +] + +SpDTransformProjTestConfigs = [ + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), ] @@ -160,10 +166,27 @@ def run_kv_cache_transform_and_test( ) else: original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True) + hidden_size_projections = ( + hf_model.hidden_size_projections if hasattr(hf_model, "hidden_size_projections") else None + ) + if hidden_size_projections: + # compute projections + last_hidden_size = original_model_outputs.hidden_states[-1] # shape: [bsz, seq_len, d_model] + proj_hidden_sizes = [last_hidden_size] + for proj in hidden_size_projections: + proj_i = proj(last_hidden_size) + proj_hidden_sizes.append(proj_i) + proj_hidden_sizes = torch.stack(proj_hidden_sizes, dim=2) + logits = hf_model.lm_head(proj_hidden_sizes) + original_model_outputs.logits = logits # Apply transforms - is_tlm = "num_logits_to_keep" in qaic_model_inputs - hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model + qaic_config = None + if "num_logits_to_keep" in qaic_model_inputs: + qaic_config = dict(speculative_model_type="target") + hf_model = QEFFAutoModelForCausalLM(hf_model, qaic_config=qaic_config).model + if hidden_size_projections is not None: + hf_model.projections = hidden_size_projections # Run KV model with torch.inference_mode(): @@ -300,6 +323,62 @@ def test_spd_transform(config_class, num_hidden_layers, num_attention_heads, hid ) +@pytest.mark.parametrize( + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", + SpDTransformProjTestConfigs, +) +def test_spd_proj_transform( + config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance +): + config = AutoConfig.for_model( + config_class, + **kwargs, + num_hidden_layers=num_hidden_layers, + 
num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + use_cache=True, + cache_position=None, + position_embeddings=None, + ) + hf_model = AutoModelForCausalLM.from_config(config=config, attn_implementation="eager") + proj_num_layers = 1 + num_speculative_tokens = 3 + hidden_size_projections = torch.nn.ModuleList( + [ + torch.nn.Sequential( + *([ResBlock(hidden_size)] * proj_num_layers), + ) + for _ in range(num_speculative_tokens) + ], + ) + hf_model.hidden_size_projections = hidden_size_projections + + kv_cache = None + if hasattr(config, "cache_implementation") and config.cache_implementation == "hybrid": + # Create a KV Cache from HybridCache class to pass as an object for models which use Hybrid KV Cache + # Refer https://github.com/huggingface/transformers/issues/32896 for more info + # This requires torch._dynamo present in torch>=2.3.0 + kv_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32) + + padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs( + input_len=8, + vocab_size=config.vocab_size, + padding_shape=padding_shape, + num_hidden_layers=num_hidden_layers, + is_tlm=True, + ) + + run_kv_cache_transform_and_test( + hf_model, + qaic_model_inputs=qaic_model_inputs, + logits_tolerance=logits_tolerance, + kv_cache=kv_cache, + ) + + @pytest.mark.parametrize("in_features", [2048, 4096]) @pytest.mark.parametrize("out_features", [2048, 4096]) @pytest.mark.skipif(platform.machine() == "aarch64", reason="Test skipped on aarch64 platform")