
SpD, multiprojection heads #306


Open · wants to merge 44 commits into main

Commits (44)
60894c4
1st draft of adding proj heads
eplatero97 Feb 17, 2025
489fe85
multiproj head app
eplatero97 Mar 3, 2025
7e83364
giving user flexibility to define their own projection architecture
eplatero97 Mar 4, 2025
3d1c00c
take off unnecessary comments
eplatero97 Mar 4, 2025
2aeb313
rebasing
eplatero97 Mar 12, 2025
51f33a7
lint fix
eplatero97 Mar 12, 2025
e3da534
lint fixes
eplatero97 Mar 12, 2025
4ef24b0
add Tuple typhint
eplatero97 Mar 12, 2025
1eff3fc
lint fix
eplatero97 Mar 12, 2025
3a8e9bd
lint fix
eplatero97 Mar 12, 2025
fde155d
lint fix
eplatero97 Mar 12, 2025
3a4c3de
adding pytorch unit tests
eplatero97 Mar 14, 2025
9654b00
reformatting
eplatero97 Mar 14, 2025
1c9f766
refactoring a bit
eplatero97 Mar 18, 2025
56cabce
pytorch unit test
eplatero97 Mar 18, 2025
323d43e
adding vanilla spd support for qwen model
eplatero97 Mar 18, 2025
9f3c9ef
generalize to qwen
eplatero97 Mar 18, 2025
990e380
validated qwen vanilla spd unit tests
eplatero97 Mar 18, 2025
238acdb
lint fix
eplatero97 Mar 18, 2025
b15180e
adding back commented tests
eplatero97 Mar 20, 2025
9f1b618
adding ignore eos option
eplatero97 Mar 20, 2025
b4bed19
add ignore-eos argparse
eplatero97 Mar 20, 2025
4fa0e75
got rid of clone on hidde_states that caused issue when exporting pro…
eplatero97 Mar 25, 2025
6a9f6db
lint fix
eplatero97 Mar 25, 2025
c39e1a0
using tinyllama on llama tests
eplatero97 Mar 25, 2025
501377e
lint fix
eplatero97 Mar 25, 2025
799afd4
add back automatic chooser of qid
eplatero97 Mar 25, 2025
3d330ee
integrate turbo definition within QEff
eplatero97 Apr 9, 2025
47ea6e6
add variable
eplatero97 Apr 9, 2025
2c55faa
lint fixes
eplatero97 Apr 9, 2025
dcc528e
other lint
eplatero97 Apr 9, 2025
590a517
hopefully imports are sorted now
eplatero97 Apr 9, 2025
6ffde3f
cleaned inference app and unit tests to new flow
eplatero97 Apr 10, 2025
b89f0e6
lint fix
eplatero97 Apr 10, 2025
f251ecf
added explicit errors rather than asserts
eplatero97 Apr 10, 2025
1a299f6
changed based on 1st round of feedback
eplatero97 Apr 11, 2025
7d39975
passing from_pretrained kwargs to init
eplatero97 Apr 14, 2025
6418388
fixing 2nd round of revisions
eplatero97 Apr 21, 2025
8d743d7
Merge branch 'main' into turbo
eplatero97 Apr 21, 2025
1e8aab8
lint fix
eplatero97 Apr 21, 2025
1d51616
changing user-api to use to specify an tlm speculative decoding model
eplatero97 Apr 24, 2025
6e12b18
Merge branch 'main' into turbo
eplatero97 Apr 24, 2025
c786592
lint fix
eplatero97 Apr 24, 2025
a82f131
remove unnecessary comment
eplatero97 Apr 24, 2025
62 changes: 44 additions & 18 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1291,7 +1291,7 @@ def __init__(
self,
model: nn.Module,
continuous_batching: bool = False,
is_tlm: bool = False,
qaic_config: Optional[dict] = None,
**kwargs,
):
model_class_name = model.__class__.__name__
@@ -1317,11 +1317,9 @@ def __init__(
self.model.config.use_cache = True
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching

if is_tlm:
# TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
self.model, transformed = SpDTransform.apply(self.model)
self.is_tlm = is_tlm
# TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
self.is_tlm = transformed

@property
def model_name(self) -> str:
@@ -1336,7 +1334,12 @@ def __repr__(self) -> str:
@classmethod
@with_replaced_quantizers
def from_pretrained(
cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
cls,
pretrained_model_name_or_path,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
*args,
**kwargs,
):
"""
This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -1381,6 +1384,8 @@ def from_pretrained(

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

# This is to support models that should be classified into a different auto class but transformers loads them via this class

@@ -1389,7 +1394,12 @@
model, kv_offload=kv_offload
)

return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)
return cls(
model,
continuous_batching=continuous_batching,
qaic_config=qaic_config,
**kwargs,
)
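A minimal usage sketch of the reworked entry point, assuming the auto class is QEFFAutoModelForCausalLM and using a hypothetical model id; the qaic_config key "speculative_model_type" and its accepted values come from the SpDTransform changes later in this diff:

from QEfficient import QEFFAutoModelForCausalLM  # import path assumed

# "turbo" attaches multiprojection heads; "target" keeps a vanilla SpD target LM.
tlm = QEFFAutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # hypothetical model id
    continuous_batching=False,
    qaic_config={"speculative_model_type": "turbo"},
)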

@property
def model_hash(self) -> str:
@@ -1564,16 +1574,9 @@ def compile(
raise TypeError("`prefill_only` must be a boolean.")

if self.is_tlm:
if num_speculative_tokens is None:
raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.")
if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
raise ValueError("`num_speculative_tokens` must be an integer >= 2.")
if prefill_seq_len < (num_speculative_tokens + 1):
raise ValueError(
f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` "
f"({num_speculative_tokens + 1}), got {prefill_seq_len}."
)

num_speculative_tokens: int = self.check_and_get_num_speculative_tokens(
num_speculative_tokens, prefill_seq_len
)
if self.continuous_batching and full_batch_size is None:
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

@@ -1667,6 +1670,29 @@ def generate(
else:
raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")

def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[int], prefill_seq_len: int):
if hasattr(self.model.config, "speculative_config"):
num_speculative_tokens_ = self.model.config.speculative_config["num_speculative_tokens"]
if num_speculative_tokens is not None:
logger.warning(
f"arg `num_speculative_tokens` is a fixed value of {num_speculative_tokens_} for this model."
f" Passed value of {num_speculative_tokens} will be ignored."
)
num_speculative_tokens = num_speculative_tokens_
elif num_speculative_tokens is None:
raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")

if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
raise ValueError(
f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
)
num_logits_to_keep = num_speculative_tokens + 1
if prefill_seq_len < num_logits_to_keep:
raise ValueError(
f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
)
return num_speculative_tokens
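A hedged sketch of how this validation plays out at compile time (continuing the from_pretrained example above; values are illustrative):

# If the loaded speculator config pins num_speculative_tokens, a passed value is ignored
# with a warning; otherwise it is required, must be an int >= 2, and prefill_seq_len
# must be at least num_speculative_tokens + 1.
tlm.compile(
    prefill_seq_len=32,
    num_speculative_tokens=4,
)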


class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin):
"""
75 changes: 30 additions & 45 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -1,13 +1,14 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from types import MethodType
from typing import Tuple
from typing import Optional, Tuple

import transformers
from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
@@ -47,17 +48,6 @@
GraniteAttention,
GraniteForCausalLM,
GraniteModel,
GraniteRMSNorm,
)
from transformers.models.granitemoe.modeling_granitemoe import (
GraniteMoeAttention,
GraniteMoeForCausalLM,
GraniteMoeModel,
GraniteMoeMoE,
GraniteMoeParallelExperts,
GraniteMoeRMSNorm,
GraniteMoeRotaryEmbedding,
GraniteMoeTopKGating,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
@@ -69,9 +59,6 @@
from transformers.models.llava.modeling_llava import (
LlavaForConditionalGeneration,
)
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextForConditionalGeneration,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
@@ -133,6 +120,7 @@

from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
@@ -180,16 +168,10 @@
QEffGraniteForCausalLM,
QEffGraniteModel,
)
from QEfficient.transformers.models.granitemoe.modeling_granitemoe import (
QEffGraniteMoeAttention,
QEffGraniteMoeForCausalLM,
QEffGraniteMoeModel,
QEffGraniteMoeMoE,
QEffGraniteMoeParallelExperts,
QEffGraniteMoeRotaryEmbedding,
QEffGraniteMoeTopKGating,
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
from QEfficient.transformers.models.internvl.modeling_internvl import QEffInternVisionEmbeddings, QEffInternVLModel
from QEfficient.transformers.models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaDecoderLayer,
@@ -199,9 +181,6 @@
from QEfficient.transformers.models.llava.modeling_llava import (
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
QEffLlavaNextForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
QEffMistralAttention,
QEffMistralDecoderLayer,
@@ -266,7 +245,10 @@
QEffWhisperModel,
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.spd.causal_lm_forward import tlm_forward
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward

SPD_TARGET = "target"


class CustomOpsTransform(ModuleMappingTransform):
@@ -279,8 +261,6 @@ class CustomOpsTransform(ModuleMappingTransform):
Phi3RMSNorm: CustomRMSNormAIC,
Qwen2RMSNorm: CustomRMSNormAIC,
MllamaTextRMSNorm: CustomRMSNormAIC,
GraniteRMSNorm: CustomRMSNormAIC,
GraniteMoeRMSNorm: CustomRMSNormAIC,
}


@@ -313,8 +293,6 @@ class KVCacheTransform(ModuleMappingTransform):
LlamaForCausalLM: QEffLlamaForCausalLM,
# Llava
LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
# Llava Next
LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration,
# Gemma
GemmaAttention: QEffGemmaAttention,
GemmaDecoderLayer: QEffGemmaDecoderLayer,
@@ -329,14 +307,6 @@
GraniteModel: QEffGraniteModel,
GraniteForCausalLM: QEffGraniteForCausalLM,
GraniteAttention: QEffGraniteAttention,
# GraniteMoe
GraniteMoeModel: QEffGraniteMoeModel,
GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
GraniteMoeAttention: QEffGraniteMoeAttention,
GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding,
GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts,
GraniteMoeTopKGating: QEffGraniteMoeTopKGating,
GraniteMoeMoE: QEffGraniteMoeMoE,
# mllama
MllamaTextRMSNorm: CustomRMSNormAIC,
MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
@@ -401,6 +371,8 @@ class KVCacheTransform(ModuleMappingTransform):
@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
model, transformed = super().apply(model)
# FIXME: see if we can merge into _module_mapping dict
transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
return model, transformed


@@ -421,21 +393,34 @@ class SpDTransform:

# supported architectures
_module_mapping = {
# Llama
QEffLlamaForCausalLM,
QEffQwen2ForCausalLM,
}

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
transformed = False
if (model_class := model.__class__) in cls._module_mapping:
if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
return model, transformed
elif speculative_model_type not in (
supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())
):
raise ValueError(
f"Specualtive model type {speculative_model_type} is not supported. we currently only support {supported_spd_model_types}"
)
elif (model_class := model.__class__) in cls._module_mapping:
model.forward = MethodType(tlm_forward, model)
if speculative_model_type != SPD_TARGET:
# build and attach draft mlp
pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"]
model = build_and_attach_mlp(
model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs
)
transformed = True
else:
raise NotImplementedError(
f"model class {model_class} does not yet support returning multiple logits to keep."
)

return model, transformed
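For reference, a direct-call sketch of this transform (normally from_pretrained wires this up; the model id is hypothetical and hf_model stands for an already QEff-mapped Llama or Qwen2 causal LM):

qaic_config = {
    "speculative_model_type": "turbo",  # or "target" for a plain SpD target LM
    # from_pretrained() injects this key automatically; needed here only for a direct call
    "pretrained_model_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
model, transformed = SpDTransform.apply(hf_model, qaic_config)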


28 changes: 28 additions & 0 deletions QEfficient/transformers/post_processing.py
@@ -0,0 +1,28 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from QEfficient.transformers.spd.turbo import build_and_attach_turbo
from QEfficient.utils.spd_utils import get_speculative_config, get_speculative_weights

model_type_registry = dict(turbo=build_and_attach_turbo)


def build_and_attach_mlp(model, pretrained_model_name_or_path, speculative_model_type: str, **kwargs):
speculative_config: dict = get_speculative_config(pretrained_model_name_or_path, **kwargs)
speculative_weights: str = get_speculative_weights(pretrained_model_name_or_path, **kwargs)

if (model_type := speculative_config.get("model_type")) is None:
speculative_config["model_type"] = speculative_model_type
else:
if model_type != speculative_model_type:
raise ValueError(
f"`model_type` key from speculator config ({model_type} does not match input model type ({speculative_model_type})."
)
func = model_type_registry[speculative_model_type]
model = func(model, speculative_config, speculative_weights)
model.config.speculative_config = speculative_config
return model
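A hedged illustration of the speculator config this helper expects; the values are assumptions, while the keys are the ones read here and later in check_and_get_num_speculative_tokens:

# The speculator's own config ("model_type") must match the requested
# speculative_model_type; "turbo" is the only type registered above.
speculative_config = {"model_type": "turbo", "num_speculative_tokens": 3}   # illustrative values
builder = model_type_registry[speculative_config["model_type"]]             # -> build_and_attach_turbo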
1 change: 0 additions & 1 deletion QEfficient/transformers/spd/__init__.py
Original file line number Diff line number Diff line change
@@ -4,4 +4,3 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

QEfficient/transformers/spd/spd_transform_forward.py
@@ -21,7 +21,7 @@ def filter_hidden_states(
Filter hidden states based on whether this is a TLM SpD model

``Mandatory`` Args:
:hidden_states (torch.Tensor): Hidden states tensor.
:hidden_states (torch.Tensor): Last hidden state tensor.
:position_ids (torch.Tensor): Position ids tensor.
``Optional`` Args:
:num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model
@@ -50,6 +50,28 @@ def filter_hidden_states(
return hidden_states


def project_hidden_states(hidden_states, hidden_size_projections):
"""
Project the last hidden state through each projection head and stack the results.

``Mandatory`` Args:
:hidden_states (torch.Tensor): Last hidden state tensor.
:hidden_size_projections (torch.nn.ModuleList): Projection heads applied to the last hidden state.

Returns:
:torch.Tensor: Stacked hidden states of shape [bsz, seq_len, num_projs + 1, d_model].
"""
proj_hidden_states = [hidden_states]
num_projs = len(hidden_size_projections)
for i in range(num_projs):
hidden_states_i = hidden_size_projections[i](hidden_states)
proj_hidden_states.append(hidden_states_i)
hidden_states = torch.stack(proj_hidden_states, dim=2)  # shape: [bsz, seq_len, num_projs + 1, d_model] (original plus one per projection head)
return hidden_states
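A minimal shape check, with nn.Linear standing in for the real projection heads that the turbo builder attaches (import path assumed):

import torch
from torch import nn

# from QEfficient.transformers.spd.spd_transform_forward import project_hidden_states  # path assumed

hs = torch.randn(1, 8, 64)                                     # [bsz, seq_len, d_model]
projs = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])   # two stand-in projection heads
out = project_hidden_states(hs, projs)
assert out.shape == (1, 8, 3, 64)                              # original stream + one per head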


def tlm_forward(
self,
input_ids: torch.LongTensor = None,
@@ -113,7 +135,10 @@ def tlm_forward(
)

hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep)
if self.config.pretraining_tp > 1:
hidden_size_projections = getattr(self, "projections", None)
if hidden_size_projections:
hidden_states = project_hidden_states(hidden_states, hidden_size_projections)
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
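Hedged shape note for the lines above: with projection heads attached, hidden_states is [bsz, seq_len, num_projs + 1, d_model], so the lm_head yields one logit stream per head:

# logits: [bsz, seq_len, num_projs + 1, vocab_size]
base_logits = logits[:, :, 0]   # stream from the unprojected hidden state
head_logits = logits[:, :, 1:]  # streams from the projection heads (illustrative slicing)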