Commit 5c471b6

SpD, multiprojection heads (#306)
### Objective

This PR implements post-attention hidden size projections used to speculate tokens ahead of the base model. It contains three primary components:

1. extending the base model with multi-projections in `modeling_auto.py`
2. implementing the multi-projection forward pass
3. an app demo of the multi-projection model

### Initial Implementation

The initial implementation gives users the flexibility to define their own projection architecture and pass it to `QEffAutoModelForCausalLM`. QEfficient then simply attaches these projections to the model so they can be used during the forward pass. The attaching of these projections is done with the `accelerate` library, which has a robust implementation for attaching weights to an already existing model. We can implement our own abstraction if needed, but first we must agree on what the external API to the user will be.

> NOTE: Please keep in mind that integrating Medusa will require similar changes (instead of multiple hidden size projections, Medusa uses multiple `lm_heads` to speculate ahead of the base model).

---------

Signed-off-by: eplatero <[email protected]>
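A minimal usage sketch of the flow this PR enables (the model name is hypothetical and the top-level import path is assumed; `qaic_config` with `speculative_model_type` replaces the old `is_tlm` flag, per the `modeling_auto.py` diff below):

```python
# Hedged sketch, not an official example: enable the turbo multi-projection
# speculative path via qaic_config instead of the removed is_tlm flag.
from QEfficient import QEffAutoModelForCausalLM  # import path assumed

model = QEffAutoModelForCausalLM.from_pretrained(
    "JackFram/llama-160m",  # hypothetical base model
    qaic_config={"speculative_model_type": "turbo"},
)
# For turbo models, num_speculative_tokens is pinned by the speculator config,
# so compile() can resolve it via check_and_get_num_speculative_tokens().
model.compile(prefill_seq_len=128)
```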
1 parent 4695485 commit 5c471b6

File tree: 13 files changed, +389 −45 lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 41 additions & 17 deletions
@@ -1298,7 +1298,7 @@ def __init__(
         self,
         model: nn.Module,
         continuous_batching: bool = False,
-        is_tlm: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         model_class_name = model.__class__.__name__
@@ -1324,11 +1324,8 @@ def __init__(
         self.model.config.use_cache = True
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
-
-        if is_tlm:
-            # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
-            self.model, transformed = SpDTransform.apply(self.model)
-            self.is_tlm = is_tlm
+        self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
+        self.is_tlm = transformed

     @property
     def model_name(self) -> str:
@@ -1343,7 +1340,12 @@ def __repr__(self) -> str:
     @classmethod
     @with_replaced_quantizers
     def from_pretrained(
-        cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
+        cls,
+        pretrained_model_name_or_path,
+        continuous_batching: bool = False,
+        qaic_config: Optional[dict] = None,
+        *args,
+        **kwargs,
     ):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -1388,6 +1390,8 @@ def from_pretrained(

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        if qaic_config is not None:
+            qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

         # This is to support models that should be classified into a different auto class but transformers loads them via this class

@@ -1396,7 +1400,12 @@ def from_pretrained(
                 model, kv_offload=kv_offload
             )

-        return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)
+        return cls(
+            model,
+            continuous_batching=continuous_batching,
+            qaic_config=qaic_config,
+            **kwargs,
+        )

     @property
     def model_hash(self) -> str:
@@ -1571,15 +1580,7 @@ def compile(
             raise TypeError("`prefill_only` must be a boolean.")

         if self.is_tlm:
-            if num_speculative_tokens is None:
-                raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.")
-            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
-                raise ValueError("`num_speculative_tokens` must be an integer >= 2.")
-            if prefill_seq_len < (num_speculative_tokens + 1):
-                raise ValueError(
-                    f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` "
-                    f"({num_speculative_tokens + 1}), got {prefill_seq_len}."
-                )
+            num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)

         if self.continuous_batching and full_batch_size is None:
             raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
@@ -1674,6 +1675,29 @@ def generate(
         else:
             raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")

+    def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[int], prefill_seq_len: int):
+        if hasattr(self.model.config, "speculative_config"):
+            num_speculative_tokens_ = self.model.config.speculative_config["num_speculative_tokens"]
+            if num_speculative_tokens is not None:
+                logger.warning(
+                    f"arg `num_speculative_tokens` is a fixed value of {num_speculative_tokens_} for this model."
+                    f" Passed value of {num_speculative_tokens} will be ignored."
+                )
+            num_speculative_tokens = num_speculative_tokens_
+        elif num_speculative_tokens is None:
+            raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
+
+        if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+            raise ValueError(
+                f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
+            )
+        num_logits_to_keep = num_speculative_tokens + 1
+        if prefill_seq_len < num_logits_to_keep:
+            raise ValueError(
+                f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+            )
+        return num_speculative_tokens


 class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin):
     """

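The validation that `compile` now delegates to `check_and_get_num_speculative_tokens` reduces to a simple precedence rule. A runnable standalone sketch (hypothetical helper, not part of the diff):

```python
from typing import Optional


def resolve_num_speculative_tokens(
    config_value: Optional[int], user_value: Optional[int], prefill_seq_len: int
) -> int:
    """Mirror of the rule above: a speculator config pins the value, else the caller must supply it."""
    if config_value is not None:
        num = config_value  # the model's speculative_config wins; any user value is ignored
    elif user_value is None:
        raise TypeError("missing required argument `num_speculative_tokens`")
    else:
        num = user_value
    if not isinstance(num, int) or num < 2:
        raise ValueError(f"`num_speculative_tokens` must be an integer >= 2, got {num}")
    if prefill_seq_len < num + 1:
        # prefill must be long enough to produce num_speculative_tokens + 1 logits
        raise ValueError(f"prefill_seq_len ({prefill_seq_len}) must be >= {num + 1}")
    return num


assert resolve_num_speculative_tokens(None, 4, prefill_seq_len=128) == 4
assert resolve_num_speculative_tokens(3, 8, prefill_seq_len=128) == 3  # config pins it
```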
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 22 additions & 5 deletions
@@ -6,7 +6,7 @@
 # -----------------------------------------------------------------------------

 from types import MethodType
-from typing import Tuple
+from typing import Optional, Tuple

 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
@@ -266,7 +266,10 @@
     QEffWhisperModel,
     QEffWhisperPositionalEmbedding,
 )
-from QEfficient.transformers.spd.causal_lm_forward import tlm_forward
+from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
+from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
+
+SPD_TARGET = "target"


 class CustomOpsTransform(ModuleMappingTransform):
@@ -423,19 +426,33 @@ class SpDTransform:
     _module_mapping = {
         # Llama
         QEffLlamaForCausalLM,
+        QEffQwen2ForCausalLM,
     }

     @classmethod
-    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
         transformed = False
-        if (model_class := model.__class__) in cls._module_mapping:
+        if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
+            return model, transformed
+        elif speculative_model_type not in (
+            supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())
+        ):
+            raise ValueError(
+                f"Speculative model type {speculative_model_type} is not supported. We currently only support {supported_spd_model_types}"
+            )
+        elif (model_class := model.__class__) in cls._module_mapping:
             model.forward = MethodType(tlm_forward, model)
+            if speculative_model_type != SPD_TARGET:
+                # build and attach draft mlp
+                pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"]
+                model = build_and_attach_mlp(
+                    model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs
+                )
             transformed = True
         else:
             raise NotImplementedError(
                 f"model class {model_class} does not yet support returning multiple logits to keep."
             )
-
         return model, transformed


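For review, the gate added to `SpDTransform.apply` reduces to a small decision function. A self-contained re-creation (stand-in registry contents; not the real QEfficient objects):

```python
from typing import Optional

SPD_TARGET = "target"
model_type_registry = {"turbo": None}  # stand-in: real registry maps "turbo" to build_and_attach_turbo


def spd_gate(qaic_config: Optional[dict]) -> Optional[str]:
    """Return the speculative_model_type to apply, or None when the transform is a no-op."""
    if qaic_config is None or (spec_type := qaic_config.get("speculative_model_type")) is None:
        return None  # model untouched, is_tlm stays False
    supported = [SPD_TARGET] + list(model_type_registry.keys())
    if spec_type not in supported:
        raise ValueError(f"Speculative model type {spec_type} is not supported. Supported: {supported}")
    # "target" only patches tlm_forward; other types also attach the draft MLP.
    return spec_type


assert spd_gate(None) is None
assert spd_gate({"speculative_model_type": "target"}) == SPD_TARGET
assert spd_gate({"speculative_model_type": "turbo"}) == "turbo"
```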
QEfficient/transformers/post_processing.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from QEfficient.transformers.spd.turbo import build_and_attach_turbo
+from QEfficient.utils.spd_utils import get_speculative_config, get_speculative_weights
+
+model_type_registry = dict(turbo=build_and_attach_turbo)
+
+
+def build_and_attach_mlp(model, pretrained_model_name_or_path, speculative_model_type: str, **kwargs):
+    speculative_config: dict = get_speculative_config(pretrained_model_name_or_path, **kwargs)
+    speculative_weights: str = get_speculative_weights(pretrained_model_name_or_path, **kwargs)
+
+    if (model_type := speculative_config.get("model_type")) is None:
+        speculative_config["model_type"] = speculative_model_type
+    else:
+        if model_type != speculative_model_type:
+            raise ValueError(
+                f"`model_type` key from speculator config ({model_type}) does not match input model type ({speculative_model_type})."
+            )
+    func = model_type_registry[speculative_model_type]
+    model = func(model, speculative_config, speculative_weights)
+    model.config.speculative_config = speculative_config
+    return model

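The registry keeps builders keyed by `speculative_model_type`, each with the signature `func(model, speculative_config, speculative_weights)`. Following the Medusa note in the commit message, a new head type could plug in like this (purely hypothetical sketch; no `medusa` builder exists in this PR):

```python
def build_and_attach_medusa(model, speculative_config: dict, speculative_weights: str):
    # Hypothetical: build multiple lm_heads and load their weights here.
    ...
    return model


model_type_registry["medusa"] = build_and_attach_medusa  # illustrative only
```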
QEfficient/transformers/spd/causal_lm_forward.py renamed to QEfficient/transformers/spd/spd_transform_forward.py

Lines changed: 25 additions & 2 deletions
@@ -21,7 +21,7 @@ def filter_hidden_states(
     Filter hidden states based on whether this is a TLM SpD model

     ``Mandatory`` Args:
-        :hidden_states (torch.Tensor): Hidden states tensor.
+        :hidden_states (torch.Tensor): Last hidden state tensor.
         :position_ids (torch.Tensor): Position ids tensor.
     ``Optional`` Args:
         :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model
@@ -50,6 +50,26 @@ def filter_hidden_states(
     return hidden_states


+def project_hidden_states(hidden_states: torch.Tensor, hidden_size_projections: torch.nn.ModuleList) -> torch.Tensor:
+    """
+    Project the last hidden state through each hidden-size projection head and stack the results
+    ``Mandatory`` Args:
+        :hidden_states (torch.Tensor): Last hidden state tensor.
+        :hidden_size_projections (torch.nn.ModuleList): Post-attention hidden-size projection heads.
+    Returns:
+        :torch.Tensor: Stacked base and projected hidden states.
+    """
+    proj_hidden_states = [hidden_states]
+    num_projs = len(hidden_size_projections)
+    for i in range(num_projs):
+        hidden_states_i = hidden_size_projections[i](hidden_states)
+        proj_hidden_states.append(hidden_states_i)
+    hidden_states = torch.stack(proj_hidden_states, dim=2)  # shape: [bsz, seq_len, num_projs + 1, d_model]
+    return hidden_states
+
+
 def tlm_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -113,7 +133,10 @@ def tlm_forward(
     )

     hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep)
-    if self.config.pretraining_tp > 1:
+    hidden_size_projections = getattr(self, "projections", None)
+    if hidden_size_projections:
+        hidden_states = project_hidden_states(hidden_states, hidden_size_projections)
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
         lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
         logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
         logits = torch.cat(logits, dim=-1)

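A runnable shape check of the stacking behavior in `project_hidden_states` (dummy sizes; plain `torch.nn.Linear` heads stand in for the turbo projections):

```python
import torch

bsz, seq_len, d_model, num_projs = 2, 4, 8, 3
hidden_states = torch.randn(bsz, seq_len, d_model)
projections = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(num_projs)])

# Base hidden state plus one projected copy per head, stacked on a new axis:
stacked = torch.stack([hidden_states] + [p(hidden_states) for p in projections], dim=2)
assert stacked.shape == (bsz, seq_len, num_projs + 1, d_model)
```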
QEfficient/transformers/spd/turbo.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+
+from QEfficient.utils.checkpoint_utils import load_checkpoint
+
+
+class ResBlock(torch.nn.Module):
+    """
+    A Residual Block module.
+    This module performs a linear transformation followed by a SiLU activation,
+    and then adds the result to the original input, creating a residual connection.
+    Args:
+        hidden_size (int): The size of the hidden layers in the block.
+    """
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.linear = torch.nn.Linear(hidden_size, hidden_size)
+        # Initialize as an identity mapping
+        torch.nn.init.zeros_(self.linear.weight)
+        # Use SiLU activation to keep consistent with the Llama model
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        """
+        Forward pass of the ResBlock.
+        Args:
+            x (torch.Tensor): Input tensor.
+        Returns:
+            torch.Tensor: Output after the residual connection and activation.
+        """
+        return x + self.act(self.linear(x))
+
+
+def post_process_turbo_state_dict(state_dict: dict) -> dict:
+    """Normalize turbo state dict keys
+    Args:
+        state_dict (dict): turbo state dict
+    Returns:
+        dict: normalized state dict
+    """
+    new_state_dict = dict()
+    for name, weights in state_dict.items():
+        new_name = name.replace("projections.", "")
+        new_state_dict[new_name] = weights
+    return new_state_dict
+
+
+def build_and_attach_turbo(model, speculative_config: dict, speculative_weights: str):
+    """Build and attach turbo projections
+    Args:
+        model: model to attach projections to
+        speculative_config (dict): speculative config used to build projections
+        speculative_weights (str): path to the projection weights
+    Returns:
+        model: model with turbo projections
+    """
+    hidden_size = model.config.hidden_size
+    num_layers = speculative_config["turbo_num_layers"]
+    num_heads = speculative_config["turbo_num_heads"]
+    projections = torch.nn.ModuleList(
+        [
+            torch.nn.Sequential(
+                *([ResBlock(hidden_size)] * num_layers),
+            )
+            for _ in range(num_heads)
+        ],
+    )
+    load_checkpoint(projections, speculative_weights, strict=True, post_process_func=post_process_turbo_state_dict)
+    model.projections = projections
+    speculative_config["num_speculative_tokens"] = num_heads
+    return model

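A quick sanity sketch of the turbo head construction (toy sizes, no checkpoint loading), using the `ResBlock` defined in the new module above:

```python
import torch

from QEfficient.transformers.spd.turbo import ResBlock

hidden_size, num_layers, num_heads = 8, 2, 3
projections = torch.nn.ModuleList(
    [torch.nn.Sequential(*([ResBlock(hidden_size)] * num_layers)) for _ in range(num_heads)]
)
x = torch.randn(1, 4, hidden_size)
out = projections[0](x)
# Weights are zero-initialized, so each block adds only act(bias) to its input.
print((out - x).abs().max())
```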
QEfficient/utils/_utils.py

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import inspect
 import json
 import os
 import subprocess
@@ -626,3 +627,16 @@ def make_serializable(obj):
     qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config

     create_json(qconfig_file_path, qconfigs)
+
+
+def filter_kwargs(func, kwargs):
+    """
+    Filter a dictionary of keyword arguments to only include the valid arguments of a function.
+    Args:
+        func: The function to check the arguments for.
+        kwargs: The dictionary of keyword arguments to filter.
+    Returns:
+        A new dictionary containing only the valid keyword arguments.
+    """
+    valid_args = inspect.signature(func).parameters
+    return {key: value for key, value in kwargs.items() if key in valid_args}

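`filter_kwargs` in action (standalone demo of the new helper):

```python
import inspect


def filter_kwargs(func, kwargs):
    valid_args = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in valid_args}


def connect(host, port=80):
    return f"{host}:{port}"


kwargs = {"host": "localhost", "port": 8080, "timeout": 5}
print(connect(**filter_kwargs(connect, kwargs)))  # localhost:8080; `timeout` is dropped
```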
QEfficient/utils/checkpoint_utils.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from safetensors.torch import load_file
+
+
+def load_checkpoint(model, checkpoint: str, strict=False, post_process_func=None):
+    """Load weights ending with the `.safetensors` extension
+    Args:
+        model: model to load weights into
+        checkpoint (str): checkpoint path
+        strict (bool, optional): strictness of loading weights. Defaults to False.
+        post_process_func (optional): Optional post-processing of loaded state dict. Defaults to None.
+    Returns:
+        model: model with applied weights
+    """
+    state_dict: dict = load_file(checkpoint)
+    if post_process_func is not None:
+        state_dict = post_process_func(state_dict)
+    model.load_state_dict(state_dict, strict=strict)
+    return model

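A round-trip sketch for the new loader (toy module; the file path is illustrative):

```python
import torch
from safetensors.torch import save_file

from QEfficient.utils.checkpoint_utils import load_checkpoint

model = torch.nn.Linear(4, 4)
save_file(model.state_dict(), "/tmp/demo.safetensors")

# post_process_func lets callers rewrite keys before loading,
# as post_process_turbo_state_dict does for the turbo projections.
reloaded = load_checkpoint(torch.nn.Linear(4, 4), "/tmp/demo.safetensors", strict=True)
```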