Added Infra in QEfficient for execution of swiftkv models #314

Closed
wants to merge 5 commits
13 changes: 12 additions & 1 deletion QEfficient/__init__.py
@@ -1,6 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
@@ -12,8 +12,19 @@
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from transformers import AutoConfig

from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
from QEfficient.utils.logging_utils import logger

# Loop over the model types that are not present in transformers and register them.
for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
    # Register the config class for this model type (first element of the tuple).
    AutoConfig.register(model_type, model_cls[0])

    # Register the out-of-library model class and its config class with the
    # corresponding auto class (third element of the tuple).
    model_cls[2].register(model_cls[0], model_cls[1])


def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
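For context on how this registration gets used: once QEfficient is imported, the stock auto classes can resolve the custom model type. A minimal sketch, assuming a local checkpoint whose config declares model_type "llama_swiftkv" (the path is hypothetical):

from transformers import AutoConfig

import QEfficient  # noqa: F401  # importing the package runs the registration loop above

# AutoConfig now maps "llama_swiftkv" to QeffLlamaSwiftKVConfig.
config = AutoConfig.from_pretrained("/local/llama-swiftkv-checkpoint")  # hypothetical path
assert config.model_type == "llama_swiftkv"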
52 changes: 52 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -36,6 +36,58 @@ class QEffDynamicCache(DynamicCache):

"""

    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
        # Write the new key/value states into the cache for this layer without
        # reading the cached context back.
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            position_ids = cache_kwargs.get("position_ids")
            batch_index = cache_kwargs.get("batch_index", None)

            # Scatter the new states at position_ids; negative positions are
            # mapped to an out-of-range index so they do not overwrite valid cache slots.
            if batch_index is not None:
                invalid_scatter_index = torch.iinfo(torch.int32).max
                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)

                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
                )
                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
                )
            else:
                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
                self.value_cache[layer_idx] = CtxScatterFunc.apply(
                    self.value_cache[layer_idx], position_ids, value_states
                )

    def read_only(self, layer_idx, cache_kwargs):
        # Read the cached keys/values for this layer without writing new states.
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
        position_ids = cache_kwargs.get("position_ids")
        batch_index = cache_kwargs.get("batch_index", None)
        ctx_len = k_out.shape[2]
        ctx_indices = torch.arange(ctx_len)[None, None, ...]
        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
        invalid_mask = ctx_indices > gather_limit

        # During ONNX export, flag positions beyond the largest position id with
        # INT32_MAX; in eager mode fall back to index 0 so the gather stays in
        # bounds. The value cache is zeroed at invalid positions below.
        if torch.onnx.is_in_onnx_export():
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0

        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        if batch_index is not None:
            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
        else:
            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
            v_out = CtxGatherFunc.apply(v_out, ctx_indices)

        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
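A rough sketch of how these helpers complement the existing update() from an attention layer's point of view; the call sites below are illustrative, not taken from this PR (tensors are [batch, num_kv_heads, seq_len, head_dim]):

# Hypothetical call sites inside a decoder layer.
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}

# A layer that only precomputes KV for reuse by later layers writes without reading:
past_key_value.write_only(key_states, value_states, layer_idx, cache_kwargs)

# A layer that consumes previously written KV reads without updating:
key_states, value_states = past_key_value.read_only(layer_idx, cache_kwargs)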
12 changes: 12 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -10,6 +10,7 @@

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.models.codegen.modeling_codegen import (
    CodeGenAttention,
    CodeGenBlock,
@@ -88,6 +89,12 @@

from QEfficient.customop import CustomRMSNormAIC

# Models that are not part of the transformers library
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
    QeffLlamaSwiftKVConfig,
    QeffLlamaSwiftKVForCausalLM,
)

from .models.codegen.modeling_codegen import (
    QEffCodeGenAttention,
    QeffCodeGenBlock,
@@ -271,6 +278,11 @@
    WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
}

# Map of model type to its config class, QEfficient modeling class, and the
# transformers auto class used for registration.
MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
    "llama_swiftkv": [QeffLlamaSwiftKVConfig, QeffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
}


def _prepare_cross_attention_mask(
    cross_attention_mask: torch.Tensor,
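As a sketch of how this map would grow, a second out-of-library architecture would add one entry with the same [config class, modeling class, auto class] layout; the names in the commented entry are hypothetical:

MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
    "llama_swiftkv": [QeffLlamaSwiftKVConfig, QeffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
    # "my_model_type": [QeffMyModelConfig, QeffMyModelForCausalLM, AutoModelForCausalLM],
}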
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama_swiftkv/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------