diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f54a2ac92..6b5deb8db 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1298,7 +1298,7 @@ def __init__( self, model: nn.Module, continuous_batching: bool = False, - is_tlm: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): model_class_name = model.__class__.__name__ @@ -1324,11 +1324,8 @@ def __init__( self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching - - if is_tlm: - # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch - self.model, transformed = SpDTransform.apply(self.model) - self.is_tlm = is_tlm + self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) + self.is_tlm = transformed @property def model_name(self) -> str: @@ -1343,7 +1340,12 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers def from_pretrained( - cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + *args, + **kwargs, ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. @@ -1388,6 +1390,8 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class @@ -1396,7 +1400,12 @@ def from_pretrained( model, kv_offload=kv_offload ) - return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching) + return cls( + model, + continuous_batching=continuous_batching, + qaic_config=qaic_config, + **kwargs, + ) @property def model_hash(self) -> str: @@ -1571,15 +1580,7 @@ def compile( raise TypeError("`prefill_only` must be a boolean.") if self.is_tlm: - if num_speculative_tokens is None: - raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.") - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: - raise ValueError("`num_speculative_tokens` must be an integer >= 2.") - if prefill_seq_len < (num_speculative_tokens + 1): - raise ValueError( - f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` " - f"({num_speculative_tokens + 1}), got {prefill_seq_len}." 
- ) + num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -1674,6 +1675,29 @@ def generate( else: raise NotImplementedError("Only AI_100 runtime is supported right now via generate API") + def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[int], prefill_seq_len: int): + if hasattr(self.model.config, "speculative_config"): + num_speculative_tokens_ = self.model.config.speculative_config["num_speculative_tokens"] + if num_speculative_tokens is not None: + logger.warning( + f"arg `num_speculative_tokens` is a fixed value of {num_speculative_tokens_} for this model." + f" Passed value of {num_speculative_tokens} will be ignored." + ) + num_speculative_tokens = num_speculative_tokens_ + elif num_speculative_tokens is None: + raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: + raise ValueError( + f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" + ) + num_logits_to_keep = num_speculative_tokens + 1 + if prefill_seq_len < num_logits_to_keep: + raise ValueError( + f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) + return num_speculative_tokens + class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin): """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2e94908c8..333c734ba 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- from types import MethodType -from typing import Tuple +from typing import Optional, Tuple from torch import nn from transformers.models.codegen.modeling_codegen import ( @@ -266,7 +266,10 @@ QEffWhisperModel, QEffWhisperPositionalEmbedding, ) -from QEfficient.transformers.spd.causal_lm_forward import tlm_forward +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" class CustomOpsTransform(ModuleMappingTransform): @@ -423,19 +426,33 @@ class SpDTransform: _module_mapping = { # Llama QEffLlamaForCausalLM, + QEffQwen2ForCausalLM, } @classmethod - def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: transformed = False - if (model_class := model.__class__) in cls._module_mapping: + if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None: + return model, transformed + elif speculative_model_type not in ( + supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys()) + ): + raise ValueError( + f"Speculative model type {speculative_model_type} is not supported. 
We currently only support {supported_spd_model_types}" + ) + elif (model_class := model.__class__) in cls._module_mapping: model.forward = MethodType(tlm_forward, model) + if speculative_model_type != SPD_TARGET: + # build and attach the draft MLP projections + pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"] + model = build_and_attach_mlp( + model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs + ) transformed = True else: raise NotImplementedError( f"model class {model_class} does not yet support returning multiple logits to keep." ) - return model, transformed diff --git a/QEfficient/transformers/post_processing.py b/QEfficient/transformers/post_processing.py new file mode 100644 index 000000000..40be6e2bd --- /dev/null +++ b/QEfficient/transformers/post_processing.py @@ -0,0 +1,28 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from QEfficient.transformers.spd.turbo import build_and_attach_turbo +from QEfficient.utils.spd_utils import get_speculative_config, get_speculative_weights + +model_type_registry = dict(turbo=build_and_attach_turbo) + + +def build_and_attach_mlp(model, pretrained_model_name_or_path, speculative_model_type: str, **kwargs): + speculative_config: dict = get_speculative_config(pretrained_model_name_or_path, **kwargs) + speculative_weights: str = get_speculative_weights(pretrained_model_name_or_path, **kwargs) + + if (model_type := speculative_config.get("model_type")) is None: + speculative_config["model_type"] = speculative_model_type + else: + if model_type != speculative_model_type: + raise ValueError( + f"`model_type` key from speculator config ({model_type}) does not match input model type ({speculative_model_type})." + ) + func = model_type_registry[speculative_model_type] + model = func(model, speculative_config, speculative_weights) + model.config.speculative_config = speculative_config + return model diff --git a/QEfficient/transformers/spd/causal_lm_forward.py b/QEfficient/transformers/spd/spd_transform_forward.py similarity index 80% rename from QEfficient/transformers/spd/causal_lm_forward.py rename to QEfficient/transformers/spd/spd_transform_forward.py index 46601c0c9..a3f1a83e5 100644 --- a/QEfficient/transformers/spd/causal_lm_forward.py +++ b/QEfficient/transformers/spd/spd_transform_forward.py @@ -21,7 +21,7 @@ def filter_hidden_states( Filter hidden states based on whether this is a TLM SpD model ``Mandatory`` Args: - :hidden_states (torch.Tensor): Hidden states tensor. + :hidden_states (torch.Tensor): Last hidden state tensor. :position_ids (torch.Tensor): Position ids tensor. ``Optional`` Args: :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model Returns: :torch.Tensor: Filtered hidden states. @@ -50,6 +50,26 @@ def filter_hidden_states( return hidden_states +def project_hidden_states(hidden_states: torch.Tensor, hidden_size_projections: torch.nn.ModuleList) -> torch.Tensor: + """ + Project the last hidden state through the attached speculative (turbo) projection heads + ``Mandatory`` Args: + :hidden_states (torch.Tensor): Last hidden state tensor. + :hidden_size_projections (torch.nn.ModuleList): Hidden-size projection blocks, one per speculative head. 
+ Each projection is applied to the last hidden state and the results are + stacked together with the unprojected hidden state along a new dimension. + Returns: + :torch.Tensor: Stacked hidden states of shape [bsz, seq_len, num_projs + 1, hidden_size]. + """ + proj_hidden_states = [hidden_states] + num_projs = len(hidden_size_projections) + for i in range(num_projs): + hidden_states_i = hidden_size_projections[i](hidden_states) + proj_hidden_states.append(hidden_states_i) + hidden_states = torch.stack(proj_hidden_states, dim=2) # shape: [bsz, seq_len, num_projs + 1, d_model] + return hidden_states + + def tlm_forward( self, input_ids: torch.LongTensor = None, @@ -113,7 +133,10 @@ ) hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep) - if self.config.pretraining_tp > 1: + hidden_size_projections = getattr(self, "projections", None) + if hidden_size_projections: + hidden_states = project_hidden_states(hidden_states, hidden_size_projections) + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) diff --git a/QEfficient/transformers/spd/turbo.py b/QEfficient/transformers/spd/turbo.py new file mode 100644 index 000000000..828f7ad56 --- /dev/null +++ b/QEfficient/transformers/spd/turbo.py @@ -0,0 +1,77 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch + +from QEfficient.utils.checkpoint_utils import load_checkpoint + + +class ResBlock(torch.nn.Module): + """ + A Residual Block module. + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = torch.nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Llama model + self.act = torch.nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output after the residual connection and activation. 
+ """ + return x + self.act(self.linear(x)) + + +def post_process_turbo_state_dict(state_dict: dict) -> dict: + """normalize turbo state dict keys + Args: + state_dict (dict): turbo state dict + Returns: + dict: normalized state dict + """ + new_state_dict = dict() + for name, weights in state_dict.items(): + new_name = name.replace("projections.", "") + new_state_dict[new_name] = weights + return new_state_dict + + +def build_and_attach_turbo(model, speculative_config: dict, speculative_weights: str): + """build and attach turbo projections + Args: + model: model to attach projections to + speculative_config (dict): speculative config used to build projections + Returns: + model: model with turbo projections + """ + hidden_size = model.config.hidden_size + num_layers = speculative_config["turbo_num_layers"] + num_heads = speculative_config["turbo_num_heads"] + projections = torch.nn.ModuleList( + [ + torch.nn.Sequential( + *([ResBlock(hidden_size)] * num_layers), + ) + for _ in range(num_heads) + ], + ) + load_checkpoint(projections, speculative_weights, strict=True, post_process_func=post_process_turbo_state_dict) + model.projections = projections + speculative_config["num_speculative_tokens"] = num_heads + return model diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 05cd63968..b6af66be5 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import inspect import json import os import subprocess @@ -626,3 +627,16 @@ def make_serializable(obj): qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config create_json(qconfig_file_path, qconfigs) + + +def filter_kwargs(func, kwargs): + """ + Filter a dictionary of keyword arguments to only include the valid arguments of a function. + Args: + func: The function to check the arguments for. + kwargs: The dictionary of keyword arguments to filter. + Returns: + A new dictionary containing only the valid keyword arguments. + """ + valid_args = inspect.signature(func).parameters + return {key: value for key, value in kwargs.items() if key in valid_args} diff --git a/QEfficient/utils/checkpoint_utils.py b/QEfficient/utils/checkpoint_utils.py new file mode 100644 index 000000000..a823ce8b1 --- /dev/null +++ b/QEfficient/utils/checkpoint_utils.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from safetensors.torch import load_file + + +def load_checkpoint(model, checkpoint: str, strict=False, post_process_func=None): + """load weights ending with `.safetensors` extension + Args: + model: model to load weights into + checkpoint (str): checkpoint path + strict (bool, optional): strictness of loading weights. Defaults to False. + post_process_func (optional): Optional post-processing of loaded state dict. Defaults to None. 
+ Returns: + model: model with applied weights + """ + state_dict: dict = load_file(checkpoint) + if post_process_func is not None: + state_dict = post_process_func(state_dict) + model.load_state_dict(state_dict, strict=strict) + return model diff --git a/QEfficient/utils/spd_utils.py b/QEfficient/utils/spd_utils.py new file mode 100644 index 000000000..aacec4f46 --- /dev/null +++ b/QEfficient/utils/spd_utils.py @@ -0,0 +1,41 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from pathlib import Path + +from huggingface_hub import hf_hub_download +from transformers import PretrainedConfig + +from QEfficient.utils._utils import filter_kwargs + + +def get_speculative_config(pretrained_model_name_or_path, **kwargs) -> dict: + if not isinstance(pretrained_model_name_or_path, (str, Path)): + raise ValueError( + f"`pretrained_model_name_or_path` must be a string or Path object but is of type {type(pretrained_model_name_or_path)}" + ) + try: + speculative_config, _ = PretrainedConfig.get_config_dict( + pretrained_model_name_or_path, _configuration_file="speculator_config.json", **kwargs + ) + except OSError as err: + raise OSError(f"{err}.\nFile 'speculator_config.json' is expected to exist to apply turbo projections.") + return speculative_config + + +def get_speculative_weights(pretrained_model_name_or_path, **kwargs) -> str: + turbo_weights_file = "speculator.safetensors" + hf_hub_kwargs = filter_kwargs(hf_hub_download, kwargs) + if (local_path := Path(pretrained_model_name_or_path)).exists(): + if not local_path.is_dir(): + raise ValueError(f"local model path {local_path} must point to an existing directory") + weights_path = local_path / turbo_weights_file + if not weights_path.exists(): + raise FileNotFoundError(f"weights path {weights_path} does not exist.") + else: + weights_path = hf_hub_download(pretrained_model_name_or_path, filename=turbo_weights_file, **hf_hub_kwargs) + return str(weights_path) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 33b9a03d7..abab4cfc3 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -258,16 +258,17 @@ End to End demo examples for various models are available in **notebooks** direc ### Draft-Based Speculative Decoding Draft-based speculative decoding is a technique where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. 
-To export and compile both DLM/TLM, add corresponding `is_tlm` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: +To export and compile both DLM/TLM, add corresponding `qaic_config` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: ```Python tlm_name = "meta-llama/Llama-2-70b-chat-hf" dlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" k = 3 # DLM will make `k` speculations -tlm = AutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True) +qaic_config = dict(speculative_model_type="target") +tlm = AutoModelForCausalLM.from_pretrained(tlm_name, qaic_config=qaic_config) dlm = AutoModelForCausalLM.from_pretrained(dlm_name) tlm.compile(num_speculative_tokens=k) dlm.compile() ``` -The `is_tlm` flag is fed during the instantiation of the model because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level. +The `qaic_config` dictionary is fed during the instantiation of the model because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level. diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 67eec2e50..efa2187b7 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -55,6 +55,7 @@ spd_test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "Qwen/Qwen2-0.5B", ] diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index e5d472734..c80fe5969 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -25,7 +25,7 @@ 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name + "JackFram/llama-68m", # target_model_name 1, # full_batch_size 3, # max_ngram_size id="CB llama", @@ -247,14 +247,14 @@ def test_pld_spec_decode_inference( # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None + qaic_config = dict(speculative_model_type="target") target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config=qaic_config ) - num_devices = len(device_group) target_model_qpc_path: str = target_model.compile( - num_cores=16, - num_devices=num_devices, + num_cores=8, + num_devices=1, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, aic_enable_depth_first=True, @@ -402,7 +402,7 @@ def test_pld_spec_decode_inference( all_ids[bi, prompt_plus_gen_idx[bi] : prompt_plus_gen_idx[bi] + num_tokens_to_append] = gen_ids prompt_plus_gen_idx[bi] += num_tokens_to_append generated_ids[bi].extend(gen_ids.tolist()) - if len(generated_ids[bi]) >= max_gen_len[bi]: + if len(generated_ids[bi]) + num_logits_to_keep >= max_gen_len[bi]: valid_batch_indices[bi] = False # check if all generations are done if not valid_batch_indices.any(): diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index 
b78afdc38..6f6bdb268 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -25,11 +25,22 @@ 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name + "JackFram/llama-160m", # draft_model_name + "JackFram/llama-160m", # target_model_name 1, # full_batch_size id="CB llama", ), + pytest.param( + Constants.INPUT_STR, # prompts + 4, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + "Qwen/Qwen2-0.5B", # draft_model_name + "Qwen/Qwen2-0.5B", # target_model_name + 1, # full_batch_size + id="CB qwen", + ), ] @@ -94,7 +105,6 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs): @pytest.mark.on_qaic -@pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model @pytest.mark.parametrize( "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size", configs, @@ -119,18 +129,20 @@ def test_spec_decode_inference( if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id vocab_size = len(tokenizer) + if target_model_name == "Qwen/Qwen2-0.5B": + vocab_size = 151936 # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None + qaic_config = dict(speculative_model_type="target") target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config=qaic_config ) draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching) - num_devices = len(device_group) target_model_qpc_path: str = target_model.compile( - num_cores=11, - num_devices=num_devices, + num_cores=6, + num_devices=1, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, aic_enable_depth_first=True, @@ -138,7 +150,7 @@ def test_spec_decode_inference( num_speculative_tokens=num_speculative_tokens, ) draft_model_qpc_path: str = draft_model.compile( - num_cores=5, + num_cores=2, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, aic_enable_depth_first=True, @@ -169,6 +181,7 @@ def test_spec_decode_inference( p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) p_tok["position_ids"] = position_ids + p_tok["num_logits_to_keep"] = np.zeros((1, 1), dtype=np.int64) prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths generated_ids = [[] for i in range(decode_batch_size)] @@ -181,13 +194,14 @@ def test_spec_decode_inference( np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1) ) # mock input key "logits" to store the first batch of output logits + num_logits_to_keep = num_speculative_tokens + 1 tlm_precode_inputs = dict( input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.zeros((num_logits_to_keep, 1), dtype=np.int64), ) max_gen_len = [ctx_len] * decode_batch_size - num_logits_to_keep = num_speculative_tokens + 1 # setup buffers tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) dlm_prefill_logits_ph = 
np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) @@ -267,7 +281,7 @@ def test_spec_decode_inference( accepted_tokens = num_tokens_selected[bi] num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi])) generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist()) - if len(generated_ids[bi]) >= max_gen_len[bi]: + if len(generated_ids[bi]) + num_logits_to_keep >= max_gen_len[bi]: valid_batch_indices[bi] = False # check if all generations are done if not valid_batch_indices.any(): diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index d7151aad7..7c57459bd 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -18,6 +18,7 @@ from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform +from QEfficient.transformers.spd.turbo import ResBlock from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.logging_utils import logger @@ -71,6 +72,11 @@ ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +] + +SpDTransformProjTestConfigs = [ + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), ] @@ -160,10 +166,27 @@ def run_kv_cache_transform_and_test( ) else: original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True) + hidden_size_projections = ( + hf_model.hidden_size_projections if hasattr(hf_model, "hidden_size_projections") else None + ) + if hidden_size_projections: + # compute projections + last_hidden_size = original_model_outputs.hidden_states[-1] # shape: [bsz, seq_len, d_model] + proj_hidden_sizes = [last_hidden_size] + for proj in hidden_size_projections: + proj_i = proj(last_hidden_size) + proj_hidden_sizes.append(proj_i) + proj_hidden_sizes = torch.stack(proj_hidden_sizes, dim=2) + logits = hf_model.lm_head(proj_hidden_sizes) + original_model_outputs.logits = logits # Apply transforms - is_tlm = "num_logits_to_keep" in qaic_model_inputs - hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model + qaic_config = None + if "num_logits_to_keep" in qaic_model_inputs: + qaic_config = dict(speculative_model_type="target") + hf_model = QEFFAutoModelForCausalLM(hf_model, qaic_config=qaic_config).model + if hidden_size_projections is not None: + hf_model.projections = hidden_size_projections # Run KV model with torch.inference_mode(): @@ -300,6 +323,62 @@ def test_spd_transform(config_class, num_hidden_layers, num_attention_heads, hid ) +@pytest.mark.parametrize( + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", + SpDTransformProjTestConfigs, +) +def test_spd_proj_transform( + config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance +): + config = AutoConfig.for_model( + config_class, + **kwargs, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + use_cache=True, + cache_position=None, + 
position_embeddings=None, + ) + hf_model = AutoModelForCausalLM.from_config(config=config, attn_implementation="eager") + proj_num_layers = 1 + num_speculative_tokens = 3 + hidden_size_projections = torch.nn.ModuleList( + [ + torch.nn.Sequential( + *([ResBlock(hidden_size)] * proj_num_layers), + ) + for _ in range(num_speculative_tokens) + ], + ) + hf_model.hidden_size_projections = hidden_size_projections + + kv_cache = None + if hasattr(config, "cache_implementation") and config.cache_implementation == "hybrid": + # Create a KV Cache from HybridCache class to pass as an object for models which use Hybrid KV Cache + # Refer https://github.com/huggingface/transformers/issues/32896 for more info + # This requires torch._dynamo present in torch>=2.3.0 + kv_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32) + + padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs( + input_len=8, + vocab_size=config.vocab_size, + padding_shape=padding_shape, + num_hidden_layers=num_hidden_layers, + is_tlm=True, + ) + + run_kv_cache_transform_and_test( + hf_model, + qaic_model_inputs=qaic_model_inputs, + logits_tolerance=logits_tolerance, + kv_cache=kv_cache, + ) + + @pytest.mark.parametrize("in_features", [2048, 4096]) @pytest.mark.parametrize("out_features", [2048, 4096]) @pytest.mark.skipif(platform.machine() == "aarch64", reason="Test skipped on aarch64 platform")
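
For reference, a minimal sketch of how the new `qaic_config`-driven turbo path is expected to be used end to end, based on the docs and tests in this diff. The checkpoint name below is a placeholder: per `QEfficient/utils/spd_utils.py`, the repo or local directory must provide a `speculator_config.json` (with `turbo_num_layers` and `turbo_num_heads`) and a `speculator.safetensors` alongside the base weights, and `AutoModelForCausalLM` is the same QEfficient class used in the quick-start snippet above.

```Python
# Hypothetical checkpoint name; it must ship `speculator_config.json` and
# `speculator.safetensors` as read by QEfficient/utils/spd_utils.py.
tlm_name = "org/llama-turbo-speculator"

# "turbo" routes through SpDTransform and build_and_attach_mlp, which attach
# the ResBlock projection heads defined in QEfficient/transformers/spd/turbo.py.
qaic_config = dict(speculative_model_type="turbo")
tlm = AutoModelForCausalLM.from_pretrained(tlm_name, qaic_config=qaic_config)

# `num_speculative_tokens` is taken from the attached speculative_config
# (set to `turbo_num_heads`), so it does not need to be passed to compile();
# an explicitly passed value would be ignored with a warning.
tlm.compile()
```

The existing `target` flow in the quick-start example is unchanged apart from replacing `is_tlm=True` with `qaic_config=dict(speculative_model_type="target")` and still takes `num_speculative_tokens` at compile time.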