From 229b75d5e0f3691f1a18fb5933ccba0a64c7a6a5 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 00:39:03 +0530 Subject: [PATCH 01/18] added initial version of SwiftKV for AI 100 Signed-off-by: Onkar Chougule --- QEfficient/transformers/cache_utils.py | 29 ++ .../llama_swiftkv/modeling_llama_swiftkv.py | 411 ++++++++++++++++++ exps/run_swiftkv.py | 28 ++ 3 files changed, 468 insertions(+) create mode 100644 QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py create mode 100644 exps/run_swiftkv.py diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index a5c375c6e..fe56b197c 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -36,6 +36,35 @@ class QEffDynamicCache(DynamicCache): """ + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + position_ids = cache_kwargs.get("position_ids") + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + def read_only(self, layer_idx, cache_kwargs): + position_ids = cache_kwargs.get("position_ids") + ctx_len = position_ids.shape[-1] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + def update( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py new file mode 100644 index 000000000..a33c83d3a --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -0,0 +1,411 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" + +import logging +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaRMSNorm, repeat_kv + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaDecoderLayer, + QEffLlamaRotaryEmbedding, + qeff_apply_rotary_pos_emb, +) + +logger = logging.get_logger(__name__) + + +class LlamaSwiftKVAttention(LlamaAttention): + def __init__(self, config, layer_idx) -> None: + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj_swiftkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask=None, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + query, _ = self.q_proj_swiftkv(hidden_states) + + # Reshape the query, key, and value tensors. + query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.shape[-1] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + key_states, value_states = past_key_value.read_only(self.layer_idx, position_ids=position_ids) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, _ = qeff_apply_rotary_pos_emb(query_states, torch.empty_like(key_states), cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + + +class LlamaSwiftKVDecoderLayer(nn.Module): + def __init__(self, config, layer_idx) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, past_key_values = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + past_key_value=past_key_values, + attention_mask=causal_mask, + ) + + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, past_key_values + + +class LlamaSwiftKVModel(nn.Module): + def __init__(self, config): + super().__init__() + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = nn.Embedding( + self.vocab_size, config.hidden_size, None + ) # TODO: Not sure if padding_idx shoudl eb NONE + self.layers = torch.nn.ModuleList( + [ + QEffLlamaDecoderLayer(config=config, layer_idx=idx) + if idx < config.num_key_value_layers + else LlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _run_swiftkv_layers( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + ) -> torch.Tensor: + for layer_idx in range(self.config.num_key_value_layers, 
self.config.num_hidden_layers): + layer = self.layers[layer_idx] + + hidden_states, past_key_values = layer(hidden_states, position_ids, past_key_values, causal_mask) + + return hidden_states, past_key_values + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + self.config._attn_implementation = "eager" + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + else: + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_values: List[torch.Tensor], + ): + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + use_cache = True + + if use_cache and not isinstance(past_key_values, Cache): + if past_key_values is None: + past_key_values = QEffDynamicCache() + else: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(None, inputs_embeds, cache_position, past_key_values, False) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + next_decoder_cache = None + + for layer_idx in range(self.config.num_key_value_layers): + layer = self.layers[layer_idx] + hidden_states, next_decoder_cache = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + swiftkv_hidden_states = self.norm_swiftkv(hidden_states) + + #################################### + ## THE MAGIC OF SWIFT KV BEGINS HERE + #################################### + for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): + self_attn = self.layers[layer_idx].self_attn + key_states = self_attn.k_proj_swiftkv(swiftkv_hidden_states) + value_states = self_attn.v_proj_swiftkv(swiftkv_hidden_states) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_values is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + _, key_states = qeff_apply_rotary_pos_emb( + torch.empty_like(swiftkv_hidden_states), key_states, cos, sin, position_ids + ) + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + + hidden_states, next_decoder_cache = self._run_swiftkv_layers( + hidden_states, position_ids, past_key_values, causal_mask + ) + #################################### + ## THE MAGIC OF SWIFT KV ENDS HERE + #################################### + + next_cache = next_decoder_cache.to_legacy_cache() + return hidden_states, next_cache + + +class LlamaSwiftKVForCausalLM(nn.Module): + """ + # packed_modules_mapping = { + # "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"], + # "qkv_proj": ["q_proj", "k_proj", "v_proj"], + # "gate_up_proj": ["gate_proj", "up_proj"], + # } + + # # BitandBytes specific attributes + # default_bitsandbytes_target_modules = [ + # ".gate_proj.", + # ".down_proj.", + # ".up_proj.", + # ".q_proj.", + # ".k_proj.", + # ".v_proj.", + # ".o_proj.", + # ".k_proj_swiftkv.", + # ".v_proj_swiftkv.", + # ] + + # # in TP, these weights are partitioned along the column dimension (dim=-1) + # column_parallel_weights_modules = [ + # ".q_proj_swiftkv.", + # ".down_proj.", + # ".o_proj.", + # ] + # bitsandbytes_stacked_params_mapping = { + # # shard_name, weight_name, index + # "k_proj_swiftkv": ("kv_proj_swiftkv", 1), + # "v_proj_swiftkv": ("kv_proj_swiftkv", 2), + # "q_proj": ("qkv_proj", 0), + # "k_proj": ("qkv_proj", 1), + # "v_proj": ("qkv_proj", 2), + # "gate_proj": ("gate_up_proj", 0), + # "up_proj": ("gate_up_proj", 1), + # } + """ + + def __init__(self, *, config): + super().__init__() + + self.model = LlamaSwiftKVModel( + config=config, + ) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + ): + hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values) + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + return logits, output_past_key_values diff --git a/exps/run_swiftkv.py b/exps/run_swiftkv.py new file mode 100644 index 000000000..cf180f609 --- /dev/null +++ b/exps/run_swiftkv.py @@ -0,0 +1,28 @@ +import json +import os + +from safetensors import safe_open + +from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import LlamaSwiftKVForCausalLM + +WEIGHTS = "/local/mnt/workspace/open-source/myown/efficient-transformers/cache_dir/swiftkv_model_weights" + + +def load_safetensors(path): + state_dict = {} + f = safe_open(path, framework="pt", device="cpu") + for key in f.keys(): + tensor = f.get_tensor(key) + state_dict[key] = tensor + return state_dict + + +config = json.load(open(os.path.join(WEIGHTS, "config.json"), "r")) + +config.num_hidden_layers = 1 + +model = LlamaSwiftKVForCausalLM(config=config) +state_dict_0 = load_safetensors(os.path.join(WEIGHTS, "model-00001-of-00009.safetensors")) + +for k in model.state_dict().keys() - state_dict_0.keys(): + del state_dict_0[k] From e91d8e2141de78623a3cd24b4daa4ff7c4783c14 Mon Sep 17 
00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 01:36:22 +0530 Subject: [PATCH 02/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index a33c83d3a..5b5fcd77f 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -22,7 +22,6 @@ # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -import logging import math from typing import List, Optional, Tuple, Union @@ -30,7 +29,7 @@ from torch import nn from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, LlamaRMSNorm, repeat_kv +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -40,12 +39,10 @@ qeff_apply_rotary_pos_emb, ) -logger = logging.get_logger(__name__) - -class LlamaSwiftKVAttention(LlamaAttention): +class LlamaSwiftKVAttention(nn.Module): def __init__(self, config, layer_idx) -> None: - super().__init__(config, layer_idx) + super().__init__() self.hidden_size = config.hidden_size self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -56,7 +53,7 @@ def __init__(self, config, layer_idx) -> None: self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - + self.layer_idx = layer_idx self.q_proj_swiftkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj_swiftkv = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias From 58b89098892db765906175cd277090fd564922b0 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 01:39:46 +0530 Subject: [PATCH 03/18] BUGFIX Signed-off-by: Onkar Chougule --- .../transformers/models/llama_swiftkv/modeling_llama_swiftkv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 5b5fcd77f..2022d2c9b 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -63,7 +63,7 @@ def __init__(self, config, layer_idx) -> None: ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) def forward( self, From 12ec8bb9afad0dda736063db520a98ddbdf585db Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 01:46:12 +0530 Subject: [PATCH 04/18] BUGFIX Signed-off-by: Onkar Chougule --- .../transformers/models/llama_swiftkv/modeling_llama_swiftkv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py 
b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 2022d2c9b..4f22e82e0 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -394,6 +394,7 @@ def __init__(self, *, config): ) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.config = config def forward( self, From 597dc187d64d12090000e3bc1a2d9597eaa044d5 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 02:07:57 +0530 Subject: [PATCH 05/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 4f22e82e0..24b88746a 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -286,7 +286,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(None, inputs_embeds, cache_position, past_key_values, False) + causal_mask = self._update_causal_mask( + None, inputs_embeds, cache_position, position_ids, past_key_values, False + ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers From b8b6dbc57500460ecdd50c805f304095ea062d52 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 02:16:52 +0530 Subject: [PATCH 06/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 24b88746a..8eaef4521 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -292,7 +292,7 @@ def forward( hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # position_embeddings = self.rotary_emb(hidden_states, position_ids) next_decoder_cache = None for layer_idx in range(self.config.num_key_value_layers): @@ -305,7 +305,7 @@ def forward( output_attentions=False, use_cache=True, cache_position=cache_position, - position_embeddings=position_embeddings, + position_embeddings=None, ) bsz, q_len, _ = hidden_states.size() From f5fd0bf8065ad796506c312f7fbb4fe7f4664541 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 02:23:24 +0530 Subject: [PATCH 07/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 8eaef4521..19887c77e 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -123,6 +123,8 @@ class LlamaSwiftKVDecoderLayer(nn.Module): def __init__(self, config, layer_idx) -> None: super().__init__() 
self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) @@ -318,8 +320,10 @@ def forward( self_attn = self.layers[layer_idx].self_attn key_states = self_attn.k_proj_swiftkv(swiftkv_hidden_states) value_states = self_attn.v_proj_swiftkv(swiftkv_hidden_states) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose( + 1, 2 + ) kv_seq_len = key_states.shape[-2] if past_key_values is not None: @@ -331,12 +335,12 @@ def forward( ) kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb( torch.empty_like(swiftkv_hidden_states), key_states, cos, sin, position_ids ) cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} - past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) hidden_states, next_decoder_cache = self._run_swiftkv_layers( hidden_states, position_ids, past_key_values, causal_mask From f0a80b9240a5d6db8877a6157360c007f7beb511 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 12:14:45 +0530 Subject: [PATCH 08/18] BUGFIX Signed-off-by: Onkar Chougule --- .../transformers/models/llama_swiftkv/modeling_llama_swiftkv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 19887c77e..20a91ef45 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -124,7 +124,7 @@ def __init__(self, config, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_heads) self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) From 7281ccb9c089849cec9be9df94dacfd6e05cc27d Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 12:33:39 +0530 Subject: [PATCH 09/18] BUGFIX Signed-off-by: Onkar Chougule --- .../transformers/models/llama_swiftkv/modeling_llama_swiftkv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 20a91ef45..b4160a312 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -124,7 +124,6 @@ def __init__(self, config, layer_idx) -> None: super().__init__() 
self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_heads) self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) From 10cc781cdae59bed16cff330d2b6a710f4389729 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 13:05:36 +0530 Subject: [PATCH 10/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index b4160a312..4d8bfb754 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -326,13 +326,13 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_values is not None: - if self.layer_idx is None: + if self_attn.layer_idx is None: raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + f"The cache structure has changed since version v4.36. If you are using {self_attn.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb( From 9b87d8f6d2bfdf4fae743b7af358427d88a47e74 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 13:08:51 +0530 Subject: [PATCH 11/18] BUGFIX Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 4d8bfb754..4015a6c95 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -335,9 +335,7 @@ def forward( kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) - _, key_states = qeff_apply_rotary_pos_emb( - torch.empty_like(swiftkv_hidden_states), key_states, cos, sin, position_ids - ) + _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) From 4c607123b68c4600bb6ad6d19e9a49d605a4f617 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 17 Dec 2024 13:18:16 +0530 Subject: [PATCH 12/18] BUGFIX Signed-off-by: Onkar Chougule --- .../transformers/models/llama_swiftkv/modeling_llama_swiftkv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 4015a6c95..8ba2ad78e 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ 
b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -73,7 +73,7 @@ def forward( attention_mask=None, ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() - query, _ = self.q_proj_swiftkv(hidden_states) + query = self.q_proj_swiftkv(hidden_states) # Reshape the query, key, and value tensors. query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) From 1fa0a3cbf945778b6086e2f807f4b64e4c451537 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 19 Dec 2024 16:21:20 +0530 Subject: [PATCH 13/18] all bugfixes in Signed-off-by: Onkar Chougule --- .../llama_swiftkv/modeling_llama_swiftkv.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 8ba2ad78e..d93d7cb44 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -90,7 +90,11 @@ def forward( key_states, value_states = past_key_value.read_only(self.layer_idx, position_ids=position_ids) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, _ = qeff_apply_rotary_pos_emb(query_states, torch.empty_like(key_states), cos, sin, position_ids) + position_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) + position_ids = position_ids[:, position_idx[0]] + query_states, _ = qeff_apply_rotary_pos_emb( + query_states, torch.empty_like(query_states), cos, sin, position_ids + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -160,9 +164,7 @@ def __init__(self, config): self.vocab_size = config.vocab_size self.config = config - self.embed_tokens = nn.Embedding( - self.vocab_size, config.hidden_size, None - ) # TODO: Not sure if padding_idx shoudl eb NONE + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None) self.layers = torch.nn.ModuleList( [ QEffLlamaDecoderLayer(config=config, layer_idx=idx) @@ -179,9 +181,9 @@ def _run_swiftkv_layers( ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states, past_key_values = layer(hidden_states, position_ids, past_key_values, causal_mask) + hidden_states = self.norm(hidden_states) return hidden_states, past_key_values def _update_causal_mask( @@ -339,15 +341,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) + last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) + orig_hidden_states = hidden_states + hidden_states = orig_hidden_states[:, last_pos_id[0], :] + causal_mask = causal_mask[:, :, last_pos_id[0], :] + hidden_states, next_decoder_cache = self._run_swiftkv_layers( hidden_states, position_ids, past_key_values, causal_mask ) + orig_hidden_states[:, last_pos_id[0], :] = hidden_states #################################### ## THE MAGIC OF SWIFT KV ENDS HERE #################################### next_cache = next_decoder_cache.to_legacy_cache() - return hidden_states, next_cache + return orig_hidden_states, next_cache class LlamaSwiftKVForCausalLM(nn.Module): From 4534ba6ea04a02e3516979b5a534c80dbe4acbf6 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 19 Dec 2024 16:21:56 +0530 Subject: [PATCH 
14/18] added init file Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/llama_swiftkv/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 QEfficient/transformers/models/llama_swiftkv/__init__.py diff --git a/QEfficient/transformers/models/llama_swiftkv/__init__.py b/QEfficient/transformers/models/llama_swiftkv/__init__.py new file mode 100644 index 000000000..e69de29bb From 3c95661915e73347f99d709409f4025a8404a083 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 9 Jan 2025 16:38:13 +0530 Subject: [PATCH 15/18] all changes except BQA are in with this --- QEfficient/transformers/cache_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index fe56b197c..2a07d9f10 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -47,8 +47,9 @@ def write_only(self, key_states, value_states, layer_idx, cache_kwargs): self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) def read_only(self, layer_idx, cache_kwargs): + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] position_ids = cache_kwargs.get("position_ids") - ctx_len = position_ids.shape[-1] + ctx_len = k_out.shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -59,7 +60,7 @@ def read_only(self, layer_idx, cache_kwargs): invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + k_out = CtxGatherFunc.apply(k_out, ctx_indices) v_out = CtxGatherFunc.apply(v_out, ctx_indices) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) From dd5633d8e56bc3baf96f4c131ea6a59543b16396 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 5 Feb 2025 09:20:06 +0530 Subject: [PATCH 16/18] more updates Signed-off-by: Onkar Chougule --- .../models/llama_swiftkv/__init__.py | 6 ++ .../llama_swiftkv/modeling_llama_swiftkv.py | 68 +++---------------- 2 files changed, 14 insertions(+), 60 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/__init__.py b/QEfficient/transformers/models/llama_swiftkv/__init__.py index e69de29bb..d259e435a 100644 --- a/QEfficient/transformers/models/llama_swiftkv/__init__.py +++ b/QEfficient/transformers/models/llama_swiftkv/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index d93d7cb44..365f0b6d2 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -1,25 +1,13 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# ----------------------------------------------------------------------------- # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ----------------------------------------------------------------------------- +# This file is adapted from vllm implementation by snowflake here: https://github.com/Snowflake-Labs/vllm/blob/swiftkv/vllm/model_executor/models/llama_swiftkv.py +# The Modules are updated as required by Cloud AI 100 HW requirements. + + """Inference-only LLaMA model compatible with HuggingFace weights.""" import math @@ -294,8 +282,6 @@ def forward( ) hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - # position_embeddings = self.rotary_emb(hidden_states, position_ids) next_decoder_cache = None for layer_idx in range(self.config.num_key_value_layers): @@ -359,44 +345,6 @@ def forward( class LlamaSwiftKVForCausalLM(nn.Module): - """ - # packed_modules_mapping = { - # "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"], - # "qkv_proj": ["q_proj", "k_proj", "v_proj"], - # "gate_up_proj": ["gate_proj", "up_proj"], - # } - - # # BitandBytes specific attributes - # default_bitsandbytes_target_modules = [ - # ".gate_proj.", - # ".down_proj.", - # ".up_proj.", - # ".q_proj.", - # ".k_proj.", - # ".v_proj.", - # ".o_proj.", - # ".k_proj_swiftkv.", - # ".v_proj_swiftkv.", - # ] - - # # in TP, these weights are partitioned along the column dimension (dim=-1) - # column_parallel_weights_modules = [ - # ".q_proj_swiftkv.", - # ".down_proj.", - # ".o_proj.", - # ] - # bitsandbytes_stacked_params_mapping = { - # # shard_name, weight_name, index - # "k_proj_swiftkv": ("kv_proj_swiftkv", 1), - # "v_proj_swiftkv": ("kv_proj_swiftkv", 2), - # "q_proj": ("qkv_proj", 0), - # "k_proj": ("qkv_proj", 1), - # "v_proj": ("qkv_proj", 2), - # "gate_proj": ("gate_up_proj", 0), - # "up_proj": ("gate_up_proj", 1), - # } - """ - def __init__(self, *, config): super().__init__() From a73da94ed86762cfa7b129ca0881b9e4c2c9d7d1 Mon Sep 17 00:00:00 2001 From: Hem Agnihotri Date: Thu, 27 Feb 2025 06:17:43 +0000 Subject: [PATCH 17/18] Enabling the SwiftKV model in the QEFF Infra Signed-off-by: Hem Agnihotri --- QEfficient/transformers/modeling_utils.py | 19 ++++++++ .../llama_swiftkv/config_llama_swiftkv.py | 45 +++++++++++++++++++ .../llama_swiftkv/modeling_llama_swiftkv.py | 17 ++++--- .../transformers/models/modeling_auto.py | 6 +++ QEfficient/utils/_utils.py | 2 +- 5 files changed, 82 insertions(+), 7 deletions(-) create mode 100644 QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py diff --git a/QEfficient/transformers/modeling_utils.py 
b/QEfficient/transformers/modeling_utils.py index ccad5e020..aec82e8cd 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -153,6 +153,9 @@ QEffWhisperPositionalEmbedding, ) +from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig +from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import LlamaSwiftKVForCausalLM + # Define a named tuple for ModelArchitectures # Required for the Automation tool ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"]) @@ -362,3 +365,19 @@ def _create_causal_mask( attention_mask = attention_mask.unsqueeze(1) return attention_mask + + +# Define a SwiftKV Model card name to Model type dictionary +# While onboarding new models make sure to add the new SwiftKV model card names to this dictionary. +SwiftKVModelCardNameToSwiftKVModelTypeDict: Dict[Type[str], Type[str]] = { + # LlamaSwiftKV Model + "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": "llama_swiftkv" +} + +# Define a SwiftKV Model type to ConfigClass and ModelArchitecture class dictionary +# While onboarding new models make sure to add the new SwiftKV model card names to this dictionary. +SwiftKVModelTypeToConfigClassAndModelArchClassDict = { + # LlamaSwiftKV Model + "llama_swiftkv" : [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM] +} + diff --git a/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py new file mode 100644 index 000000000..fa97388de --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py @@ -0,0 +1,45 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +# The Modules are updated as required by Cloud AI 100 HW requirements. + + +"""Inference-only LLaMA model compatible with HuggingFace weights.""" + + + +from typing import Optional +from transformers import LlamaConfig + + +class LlamaSwiftKVConfig(LlamaConfig): + """ + Args: + num_key_value_layers (int, optional): + The number of layers, from the first layer, that have keys and + values. If None, all layers have keys and values. + last_key_value_heads (int, optional): + The number of heads in the last layer that have keys and values. + If None, the number of heads in the last key-value layer is equal + to the number of heads in all the other key-value layers. 
+ """ + + model_type = "llama_swiftkv" + + def __init__( + self, + swiftkv: bool = False, + num_key_value_layers: Optional[int] = None, + key_value_group_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.swiftkv = swiftkv + self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers + self.key_value_group_size = key_value_group_size or 1 + assert ( + self.num_hidden_layers - self.num_key_value_layers + ) % self.key_value_group_size == 0 \ No newline at end of file diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 365f0b6d2..e2bd5a08a 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -18,6 +18,7 @@ from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv +from transformers.modeling_utils import PreTrainedModel from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -26,10 +27,10 @@ QEffLlamaRotaryEmbedding, qeff_apply_rotary_pos_emb, ) - +from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig class LlamaSwiftKVAttention(nn.Module): - def __init__(self, config, layer_idx) -> None: + def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.attention_dropout = config.attention_dropout @@ -112,7 +113,7 @@ def forward( class LlamaSwiftKVDecoderLayer(nn.Module): - def __init__(self, config, layer_idx) -> None: + def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads @@ -147,7 +148,9 @@ def forward( class LlamaSwiftKVModel(nn.Module): - def __init__(self, config): + config_class = LlamaSwiftKVConfig + + def __init__(self, config: LlamaSwiftKVConfig): super().__init__() self.vocab_size = config.vocab_size self.config = config @@ -344,8 +347,10 @@ def forward( return orig_hidden_states, next_cache -class LlamaSwiftKVForCausalLM(nn.Module): - def __init__(self, *, config): +class LlamaSwiftKVForCausalLM(PreTrainedModel): + config_class = LlamaSwiftKVConfig + + def __init__(self, *, config: LlamaSwiftKVConfig): super().__init__() self.model = LlamaSwiftKVModel( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b8b5981cd..c543da036 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,6 +7,7 @@ import hashlib import warnings + from pathlib import Path from time import perf_counter from typing import List, Optional, Union @@ -51,6 +52,7 @@ from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable from QEfficient.utils.logging_utils import logger +from QEfficient.utils._utils import QEFFLoadSwiftKVModels MODELS_WITH_ACCURACY_ISSUE_FOR_MXFP6 = ["MllamaForConditionalGeneration"] @@ -78,6 +80,10 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): 
+ + # Load the SwiftKV model if supported + QEFFLoadSwiftKVModels(pretrained_model_name_or_path) + if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 8344a053d..5b205ffbc 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -19,7 +19,7 @@ from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants from QEfficient.utils.logging_utils import logger - +from QEfficient.transformers.modeling_utils import SwiftKVModelCardNameToSwiftKVModelTypeDict, SwiftKVModelTypeToConfigClassAndModelArchClassDict class DownloadRetryLimitExceeded(Exception): """ From ed4e1a9dd069efb9c7c22002011dd0d6df3cf62b Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 27 Feb 2025 15:16:14 +0530 Subject: [PATCH 18/18] rebased Signed-off-by: Onkar Chougule --- QEfficient/transformers/modeling_utils.py | 3 +- .../llama_swiftkv/config_llama_swiftkv.py | 6 +- .../llama_swiftkv/modeling_llama_swiftkv.py | 1 + .../transformers/models/modeling_auto.py | 1 - QEfficient/utils/_utils.py | 78 ++++++++++++++++++- 5 files changed, 78 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index aec82e8cd..42244e288 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -378,6 +378,5 @@ def _create_causal_mask( # While onboarding new models make sure to add the new SwiftKV model card names to this dictionary. SwiftKVModelTypeToConfigClassAndModelArchClassDict = { # LlamaSwiftKV Model - "llama_swiftkv" : [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM] + "llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM] } - diff --git a/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py index fa97388de..77eeb61a3 100644 --- a/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py @@ -9,8 +9,6 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" - - from typing import Optional from transformers import LlamaConfig @@ -40,6 +38,4 @@ def __init__( self.swiftkv = swiftkv self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers self.key_value_group_size = key_value_group_size or 1 - assert ( - self.num_hidden_layers - self.num_key_value_layers - ) % self.key_value_group_size == 0 \ No newline at end of file + assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index e2bd5a08a..4d6888bc7 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -29,6 +29,7 @@ ) from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig + class LlamaSwiftKVAttention(nn.Module): def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: super().__init__() diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c543da036..feda125ef 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -80,7 
+80,6 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): - # Load the SwiftKV model if supported QEFFLoadSwiftKVModels(pretrained_model_name_or_path) diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 5b205ffbc..e9b58d209 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -8,6 +8,8 @@ import json import os import subprocess +import sys +import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -15,11 +17,21 @@ import torch from huggingface_hub import login, snapshot_download from requests.exceptions import HTTPError -from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast - +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + +from QEfficient.transformers.modeling_utils import ( + SwiftKVModelCardNameToSwiftKVModelTypeDict, + SwiftKVModelTypeToConfigClassAndModelArchClassDict, +) from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants from QEfficient.utils.logging_utils import logger -from QEfficient.transformers.modeling_utils import SwiftKVModelCardNameToSwiftKVModelTypeDict, SwiftKVModelTypeToConfigClassAndModelArchClassDict + class DownloadRetryLimitExceeded(Exception): """ @@ -442,3 +454,63 @@ class IOInfo: def __repr__(self): return f"input_name:{self.name}\tdatatype:{self.datatype}\tshape:{self.shape}" + + +def convert_str_to_class(className): + """ + Convert the string to class name + --------- + :className: `str`- Class name string. + Return: + Class Name + """ + return getattr(sys.modules[__name__], className) + + +def register_swiftKV_model(model_type, SwiftkvConfigCls, SwiftKVModelCls): + """ + Register the SwiftKV Models + --------------------------------------- + : model_type: str: name of the swiftKVModel for example llama_swiftkv + : SwiftkVConfigCls: SwiftKV Config class for example LlamaSwiftKVConfig + : SwiftKVModelCls: SwiftKV model class name for example LlamaSwiftKVForCausalLM + """ + + # Register the SwiftKV Config class using AutoConfig + AutoConfig.register(model_type, SwiftkvConfigCls) + + # Construct the AutoModel class name using SwiftKVModel Class name, this code is written to make things generic + swiftKvModelName = SwiftKVModelCls.__name__ + start_index = swiftKvModelName.find("SwiftKVFor") + + # Calculate the index after "SwiftKVFor" + substring_start = start_index + len("SwiftKVFor") + + # Get the substring after "SwiftKVFor" + swiftKVModel = swiftKvModelName[substring_start:] + + AutoModelName = "AutoModelFor" + swiftKVModel + + # Convert the string to class name + AutoModelClassName = convert_str_to_class(AutoModelName) + + # Register the SwiftKVModel Class and config class using AutoModelClass + AutoModelClassName.register(SwiftkvConfigCls, SwiftKVModelCls) + + +def QEFFLoadSwiftKVModels(pretrained_model_name_or_path): + """ + Load the SwiftKV Models + --------------------------------------- + : pretrained_model_name_or_path: str: name of the swiftKVModel for example Snowflake/Llama-3.1-SwiftKV-8B-Instruct + """ + try: + modelType = SwiftKVModelCardNameToSwiftKVModelTypeDict[pretrained_model_name_or_path] + + SwiftKVConfigCls = SwiftKVModelTypeToConfigClassAndModelArchClassDict[modelType][0] + SwiftKVModelArchCls = SwiftKVModelTypeToConfigClassAndModelArchClassDict[modelType][1] + + 
register_swiftKV_model(modelType, SwiftKVConfigCls, SwiftKVModelArchCls) + + except KeyError: + warnings.warn("Requested SwiftKVModel is currently not supported... stay tuned for future releases", Warning)
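
Usage sketch (supplementary; not part of the patch series above). It mirrors the registration path that register_swiftKV_model()/QEFFLoadSwiftKVModels() from PATCH 17-18 wire up, exercised directly instead of through QEfficient's from_pretrained wrapper, so the wiring can be sanity-checked without downloading Snowflake/Llama-3.1-SwiftKV-8B-Instruct. It assumes a checkout with this series applied plus a working transformers install; the toy hyperparameter values and the /tmp/llama_swiftkv_toy path are illustrative assumptions, not values taken from the patches.

from transformers import AutoConfig, AutoModelForCausalLM

from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import LlamaSwiftKVForCausalLM

# Register the custom config/model pair so the Auto* factories can resolve
# checkpoints whose config.json declares model_type == "llama_swiftkv".
AutoConfig.register("llama_swiftkv", LlamaSwiftKVConfig)
AutoModelForCausalLM.register(LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM)

# Toy config (illustrative values only): the first num_key_value_layers layers
# keep their own K/V projections; the remaining layers consume the SwiftKV
# K/V projections computed once from the shared hidden states.
cfg = LlamaSwiftKVConfig(
    swiftkv=True,
    num_hidden_layers=4,
    num_key_value_layers=2,
    key_value_group_size=1,
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,
    vocab_size=1024,
)
assert cfg.model_type == "llama_swiftkv"

# Round-trip through config.json, the same model_type lookup that
# AutoConfig.from_pretrained performs for a real SwiftKV checkpoint.
cfg.save_pretrained("/tmp/llama_swiftkv_toy")
reloaded = AutoConfig.from_pretrained("/tmp/llama_swiftkv_toy")
assert isinstance(reloaded, LlamaSwiftKVConfig)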