Adding SwiftKV Infra changes in QEFF for execution of SwiftKV models #285


Closed · wants to merge 18 commits
30 changes: 30 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -36,6 +36,36 @@ class QEffDynamicCache(DynamicCache):

"""

def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
position_ids = cache_kwargs.get("position_ids")
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)

def read_only(self, layer_idx, cache_kwargs):
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
position_ids = cache_kwargs.get("position_ids")
ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
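For orientation, here is a minimal usage sketch of the two new helpers, with dummy tensors assumed to follow the usual [batch, n_heads, seq_len, head_dim] cache layout; the snippet is illustrative, not part of the diff:

import torch

from QEfficient.transformers.cache_utils import QEffDynamicCache

# Illustrative shapes (assumed): [batch=1, n_heads=8, seq_len=16, head_dim=64].
cache = QEffDynamicCache()
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)
cache_kwargs = {"position_ids": torch.arange(16).unsqueeze(0)}

# First call on a layer appends the buffers; later calls scatter into them
# at the given positions.
cache.write_only(key_states, value_states, layer_idx=0, cache_kwargs=cache_kwargs)

# A SwiftKV layer that skips its own K/V projections can read the shared
# cache without mutating it.
k, v = cache.read_only(layer_idx=0, cache_kwargs=cache_kwargs)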
18 changes: 18 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -5,154 +5,157 @@
#
# -----------------------------------------------------------------------------

from collections import namedtuple
from typing import Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
    CodeGenAttention,
    CodeGenBlock,
    CodeGenForCausalLM,
    CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
    FalconAttention,
    FalconForCausalLM,
    FalconModel,
)
from transformers.models.gemma.modeling_gemma import (
    GemmaAttention,
    GemmaDecoderLayer,
    GemmaForCausalLM,
    GemmaModel,
    GemmaRMSNorm,
)
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2Attention,
    Gemma2DecoderLayer,
    Gemma2ForCausalLM,
    Gemma2Model,
    Gemma2RMSNorm,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
    GPTBigCodeAttention,
    GPTBigCodeBlock,
    GPTBigCodeForCausalLM,
    GPTBigCodeModel,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    MistralDecoderLayer,
    MistralForCausalLM,
    MistralModel,
    MistralRMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralAttention,
    MixtralDecoderLayer,
    MixtralForCausalLM,
    MixtralModel,
    MixtralRMSNorm,
    MixtralSparseMoeBlock,
)
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
from transformers.models.starcoder2.modeling_starcoder2 import (
    Starcoder2Attention,
    Starcoder2DecoderLayer,
    Starcoder2ForCausalLM,
    Starcoder2Model,
)
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperDecoder,
    WhisperDecoderLayer,
    WhisperEncoder,
    WhisperForConditionalGeneration,
    WhisperModel,
    WhisperPositionalEmbedding,
)

from QEfficient.customop import CustomRMSNormAIC

from .models.codegen.modeling_codegen import (
    QEffCodeGenAttention,
    QeffCodeGenBlock,
    QEffCodeGenForCausalLM,
    QEffCodeGenModel,
)
from .models.falcon.modeling_falcon import (
    QEffFalconAttention,
    QEffFalconForCausalLM,
    QEffFalconModel,
)
from .models.gemma.modeling_gemma import QEffGemmaAttention, QEffGemmaDecoderLayer, QEffGemmaForCausalLM, QEffGemmaModel
from .models.gemma2.modeling_gemma2 import (
    QEffGemma2Attention,
    QEffGemma2DecoderLayer,
    QEffGemma2ForCausalLM,
    QEffGemma2Model,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gpt_bigcode.modeling_gpt_bigcode import (
    QEffGPTBigCodeAttention,
    QEffGPTBigCodeBlock,
    QEffGPTBigCodeForCausalLM,
    QEffGPTBigCodeModel,
)
from .models.gptj.modeling_gptj import QEffGPTJAttention, QEffGPTJForCausalLM, QEffGPTJModel
from .models.llama.modeling_llama import (
    QEffLlamaAttention,
    QEffLlamaDecoderLayer,
    QEffLlamaForCausalLM,
    QEffLlamaModel,
)
from .models.mistral.modeling_mistral import (
    QEffMistralAttention,
    QEffMistralDecoderLayer,
    QEffMistralForCausalLM,
    QEffMistralModel,
)
from .models.mixtral_moe.modeling_mixtral import (
    QEffMixtralAttention,
    QeffMixtralDecoderLayer,
    QEffMixtralForCausalLM,
    QEffMixtralModel,
    QEffMixtralSparseMoeBlock,
)
from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel
from .models.phi.modeling_phi import QEffPhiAttention, QEffPhiForCausalLM, QEffPhiModel
from .models.phi3.modeling_phi3 import QEffPhi3Attention, QEffPhi3ForCausalLM, QEffPhi3Model
from .models.qwen2.modeling_qwen2 import QEffQwen2Attention, QEffQwen2ForCausalLM, QEffQwen2Model
from .models.starcoder2.modeling_starcoder2 import (
    QEffStarcoder2Attention,
    QEFFStarcoder2DecoderLayer,
    QEffStarcoder2ForCausalLM,
    QEffStarcoder2Model,
)
from .models.whisper.modeling_whisper import (
    QEffWhisperAttention,
    QEffWhisperDecoder,
    QEffWhisperDecoderLayer,
    QEffWhisperEncoder,
    QEffWhisperForConditionalGeneration,
    QEffWhisperModel,
    QEffWhisperPositionalEmbedding,
)

from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import LlamaSwiftKVForCausalLM

Check failure on line 157 in QEfficient/transformers/modeling_utils.py
GitHub Actions / lint: Ruff (I001)
QEfficient/transformers/modeling_utils.py:8:1: I001 Import block is un-sorted or un-formatted

# Define a named tuple for ModelArchitectures
# Required for the Automation tool
ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"])
@@ -362,3 +365,18 @@
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


# Map a SwiftKV model card name to its SwiftKV model type.
# While onboarding new models, make sure to add the new SwiftKV model card name to this dictionary.
SwiftKVModelCardNameToSwiftKVModelTypeDict: Dict[str, str] = {
    # LlamaSwiftKV model
    "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": "llama_swiftkv"
}

# Map a SwiftKV model type to its config class and model-architecture class.
# While onboarding new models, make sure to add the new SwiftKV model type to this dictionary.
SwiftKVModelTypeToConfigClassAndModelArchClassDict = {
    # LlamaSwiftKV model
    "llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]
}
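For illustration, one possible way these registries could be consumed, using the standard transformers registration APIs; this wiring is a sketch, not code from this PR:

from transformers import AutoConfig, AutoModelForCausalLM

model_card = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"
model_type = SwiftKVModelCardNameToSwiftKVModelTypeDict[model_card]  # "llama_swiftkv"
config_cls, model_cls = SwiftKVModelTypeToConfigClassAndModelArchClassDict[model_type]

# Register the custom config and architecture so that from_pretrained can
# resolve the "llama_swiftkv" model_type.
AutoConfig.register(model_type, config_cls)
AutoModelForCausalLM.register(config_cls, model_cls)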
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama_swiftkv/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py
@@ -0,0 +1,41 @@
Contributor Author commented: Move this piece of code into the modeling file.

# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
# The modules are updated as required by the Cloud AI 100 hardware.


"""Inference-only LLaMA model compatible with HuggingFace weights."""

from typing import Optional
from transformers import LlamaConfig

Check failure on line 13 in QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py
GitHub Actions / lint: Ruff (I001)
QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py:12:1: I001 Import block is un-sorted or un-formatted


class LlamaSwiftKVConfig(LlamaConfig):
    """
    Args:
        swiftkv (bool, optional):
            Whether SwiftKV is enabled for this checkpoint. Defaults to False.
        num_key_value_layers (int, optional):
            The number of layers, from the first layer, that compute their own
            keys and values. If None, all layers do.
        key_value_group_size (int, optional):
            The number of consecutive remaining layers that share one set of
            cached keys and values. Defaults to 1.
    """

    model_type = "llama_swiftkv"

    def __init__(
        self,
        swiftkv: bool = False,
        num_key_value_layers: Optional[int] = None,
        key_value_group_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.swiftkv = swiftkv
        self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers
        self.key_value_group_size = key_value_group_size or 1
        # Layers without their own K/V must split evenly into sharing groups.
        assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0
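As a quick check of the constructor's invariant, a hypothetical configuration (values chosen for the example, not taken from any model card):

# 32 layers in total: the first 16 compute their own K/V, the remaining 16
# reuse cached K/V in groups of 2; (32 - 16) % 2 == 0, so the assert passes.
config = LlamaSwiftKVConfig(
    swiftkv=True,
    num_key_value_layers=16,
    key_value_group_size=2,
    num_hidden_layers=32,
)
print(config.model_type)  # "llama_swiftkv"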