From 04bfb9d3ea6ab706e0e7282bb4fd6446a5fb41ea Mon Sep 17 00:00:00 2001 From: Hem Agnihotri Date: Thu, 13 Mar 2025 01:55:34 +0000 Subject: [PATCH 1/5] Added Infra in QEfficient for execution of models whose modelling file is not present at hugging face and checkpoint like swiftkv Signed-off-by: Hem Agnihotri --- QEfficient/__init__.py | 20 +- QEfficient/transformers/cache_utils.py | 29 ++ QEfficient/transformers/modeling_utils.py | 65 +++ .../models/llama_swiftkv/__init__.py | 6 + .../llama_swiftkv/modeling_llama_swiftkv.py | 402 ++++++++++++++++++ QEfficient/utils/_utils.py | 7 +- README.md | 1 + docs/source/validate.md | 1 + 8 files changed, 529 insertions(+), 2 deletions(-) create mode 100644 QEfficient/transformers/models/llama_swiftkv/__init__.py create mode 100644 QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 47c462979..60aba0d74 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -1,6 +1,6 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- @@ -12,8 +12,26 @@ # hf_transfer is imported (will happen on line 15 via leading imports) os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +from transformers import AutoConfig + +from QEfficient.transformers.modeling_utils import ( + MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS, + get_auto_model_class, + get_model_class_type_from_model_type, +) from QEfficient.utils.logging_utils import logger +# loop over all the model types which are not present in transformers and register them +for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items(): + # Register the model config class based on the model type. This will be first element in the tuple + AutoConfig.register(model_type, model_cls[0]) + + model_class_type = get_model_class_type_from_model_type(model_type) + AutoModelClassName = get_auto_model_class(model_class_type, model_cls[1]) + + # Register the non transformer library Class and config class using AutoModelClass + AutoModelClassName.register(model_cls[0], model_cls[1]) + def check_qaic_sdk(): """Check if QAIC SDK is installed""" diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index a5c375c6e..e7d6e8275 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -36,6 +36,35 @@ class QEffDynamicCache(DynamicCache): """ + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + position_ids = cache_kwargs.get("position_ids") + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + def read_only(self, layer_idx, **cache_kwargs): + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + position_ids = cache_kwargs.get("position_ids") + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
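+        # Positions beyond the largest position id in this batch have not been written to the cache yet;
+        # they are marked invalid below so the gather reads a harmless sentinel index (int32 max during
+        # ONNX export, 0 otherwise) and the corresponding value states are zeroed out afterwards.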
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + def update( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index ccad5e020..e70542ff7 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -88,6 +88,12 @@ from QEfficient.customop import CustomRMSNormAIC +# Placeholder for all non-transformer models +from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import ( + LlamaSwiftKVConfig, + LlamaSwiftKVForCausalLM, +) + from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, QeffCodeGenBlock, @@ -271,6 +277,17 @@ WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration, } +# Map of model type to config class and Model architecture class +# While onboarding new models make sure to add the new model card names to this dictionary. +# Developers are expected to follow the naming conventions like ForCausalLM while defining the class names +MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {"llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]} + +# list of sub-strings representing the model type, this is typically taken from llama-swiftkv +LIST_OF_MODEL_TYPES = {"swiftkv"} + +# list of sub-strings used for representing the model Architecture class name, for example LlamaSwiftKVForCausalLM +MODEL_TYPE_TO_MODEL_CLASS_TYPE = {"swiftkv": "SwiftKVFor"} + def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, @@ -362,3 +379,51 @@ def _create_causal_mask( attention_mask = attention_mask.unsqueeze(1) return attention_mask + + +def convert_str_to_class(className): + """ + Convert the string to class name + --------- + :className: `str`- Class name string. + Return: + Class Name + """ + module = __import__("transformers") + return getattr(module, className) + + +def get_auto_model_class(model_type, NonTransformerModelCls): + """ + Register the Non Transformer Models like swiftkv + --------------------------------------- + : model_type: str: name of the Non Transformer model for example llama_swiftkv + : NonTransformerModelCls: SwiftKV model class name for example LlamaSwiftKVForCausalLM + """ + + # Construct the AutoModel class name using NonTransformerModel class e.g. 
SwiftKVModel Class name, this code is written to make things generic + nonTransformerModelClsName = NonTransformerModelCls.__name__ + start_index = nonTransformerModelClsName.find(model_type) + + # Calculate the index after model_type example "SwiftKVFor" + substring_start = start_index + len(model_type) + + # Get the substring after model_type example "SwiftKVFor" + nonTransformerModel = nonTransformerModelClsName[substring_start:] + + autoModelName = "AutoModelFor" + nonTransformerModel + + # Convert the string to class name + autoModelClassName = convert_str_to_class(autoModelName) + + return autoModelClassName + + +def get_model_class_type_from_model_type(model_type): + for substring in LIST_OF_MODEL_TYPES: + if substring in model_type: + model_class_type = substring + break + + model_class_name = MODEL_TYPE_TO_MODEL_CLASS_TYPE[model_class_type] + return model_class_name diff --git a/QEfficient/transformers/models/llama_swiftkv/__init__.py b/QEfficient/transformers/models/llama_swiftkv/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py new file mode 100644 index 000000000..26931fced --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -0,0 +1,402 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +# This file is adapted from vllm implementation by snowflake here: https://github.com/Snowflake-Labs/vllm/blob/swiftkv/vllm/model_executor/models/llama_swiftkv.py +# The Modules are updated as required by Cloud AI 100 HW requirements. + + +"""Inference-only LLaMA model compatible with HuggingFace weights.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import LlamaConfig +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaDecoderLayer, + QEffLlamaRotaryEmbedding, + qeff_apply_rotary_pos_emb, +) + + +class LlamaSwiftKVConfig(LlamaConfig): + """ + Args: + num_key_value_layers (int, optional): + The number of layers, from the first layer, that have keys and + values. If None, all layers have keys and values. + last_key_value_heads (int, optional): + The number of heads in the last layer that have keys and values. + If None, the number of heads in the last key-value layer is equal + to the number of heads in all the other key-value layers. 
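+        swiftkv (bool, optional):
+            Whether SwiftKV is enabled for this checkpoint. Defaults to False.
+        key_value_group_size (int, optional):
+            Defaults to 1; ``num_hidden_layers - num_key_value_layers`` must be
+            divisible by this value.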
+ """ + + model_type = "llama_swiftkv" + + def __init__( + self, + swiftkv: bool = False, + num_key_value_layers: Optional[int] = None, + key_value_group_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.swiftkv = swiftkv + self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers + self.key_value_group_size = key_value_group_size or 1 + assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 + + +class LlamaSwiftKVAttention(nn.Module): + def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.layer_idx = layer_idx + self.q_proj_swiftkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask=None, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + query = self.q_proj_swiftkv(hidden_states) + + # Reshape the query, key, and value tensors. + query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.shape[-1] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + key_states, value_states = past_key_value.read_only(self.layer_idx, position_ids=position_ids) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + position_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) + position_ids = position_ids[:, position_idx[0]] + query_states, _ = qeff_apply_rotary_pos_emb( + query_states, torch.empty_like(query_states), cos, sin, position_ids + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + + +class LlamaSwiftKVDecoderLayer(nn.Module): + def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + + self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, past_key_values = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + past_key_value=past_key_values, + attention_mask=causal_mask, + ) + + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, past_key_values + + +class LlamaSwiftKVModel(nn.Module): + config_class = LlamaSwiftKVConfig + + def __init__(self, config: LlamaSwiftKVConfig): + super().__init__() + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None) + self.layers = torch.nn.ModuleList( + [ + QEffLlamaDecoderLayer(config=config, layer_idx=idx) + if idx < config.num_key_value_layers + else LlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def 
_run_swiftkv_layers( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + ) -> torch.Tensor: + for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): + layer = self.layers[layer_idx] + hidden_states, past_key_values = layer(hidden_states, position_ids, past_key_values, causal_mask) + + hidden_states = self.norm(hidden_states) + return hidden_states, past_key_values + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + self.config._attn_implementation = "eager" + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + else: + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_values: List[torch.Tensor], + ): + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + use_cache = True + + if use_cache and not isinstance(past_key_values, Cache): + if past_key_values is None: + past_key_values = QEffDynamicCache() + else: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + None, inputs_embeds, cache_position, position_ids, past_key_values, False + ) + hidden_states = inputs_embeds + + next_decoder_cache = None + + for layer_idx in range(self.config.num_key_value_layers): + layer = self.layers[layer_idx] + hidden_states, next_decoder_cache = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + cache_position=cache_position, + position_embeddings=None, + ) + + bsz, q_len, _ = hidden_states.size() + swiftkv_hidden_states = self.norm_swiftkv(hidden_states) + + #################################### + ## THE MAGIC OF SWIFT KV BEGINS HERE + #################################### + for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): + self_attn = self.layers[layer_idx].self_attn + key_states = self_attn.k_proj_swiftkv(swiftkv_hidden_states) + value_states = self_attn.v_proj_swiftkv(swiftkv_hidden_states) + key_states = key_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose( + 1, 2 + ) + + kv_seq_len = key_states.shape[-2] + if past_key_values is not None: + if self_attn.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self_attn.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) + + cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) + _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) + + last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) + orig_hidden_states = hidden_states + hidden_states = orig_hidden_states[:, last_pos_id[0], :] + causal_mask = causal_mask[:, :, last_pos_id[0], :] + + hidden_states, next_decoder_cache = self._run_swiftkv_layers( + hidden_states, position_ids, past_key_values, causal_mask + ) + orig_hidden_states[:, last_pos_id[0], :] = hidden_states + #################################### + ## THE MAGIC OF SWIFT KV ENDS HERE + #################################### + + next_cache = next_decoder_cache.to_legacy_cache() + return orig_hidden_states, next_cache + + +class LlamaSwiftKVForCausalLM(PreTrainedModel): + config_class = LlamaSwiftKVConfig + + def __init__(self, config: LlamaSwiftKVConfig): + super().__init__(config=config) + + self.model = LlamaSwiftKVModel( + config=config, + ) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.config = config + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + ): + hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values) + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + return logits, output_past_key_values diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index ea9044e2c..8ba5e2c18 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -17,7 +17,12 @@ import yaml from huggingface_hub import login, snapshot_download from requests.exceptions import HTTPError -from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants from QEfficient.utils.logging_utils import logger diff --git a/README.md b/README.md index 2185c9f64..724717874 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ --- *Latest news* :fire:
+- [03/2025] Added support for swiftkv model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) - [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) - [01/2025] [FP8 models support](https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127) Added support for inference of FP8 models. diff --git a/docs/source/validate.md b/docs/source/validate.md index acd4c11da..7f1690d2d 100644 --- a/docs/source/validate.md +++ b/docs/source/validate.md @@ -33,6 +33,7 @@ | **Phi3ForCausalLM** | Phi-3, Phi-3.5 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | ✔️ | | **QwenForCausalLM** | DeepSeek-R1-Distill-Qwen | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | ✔️ | | | Qwen2, Qwen2.5 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✔️ | +| **LlamaSwiftKVForCausalLM** | swiftkv | [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | ✔️ | ## Embedding Models From abf9099fb81901b703323bdfb1c42a94cc3aad71 Mon Sep 17 00:00:00 2001 From: Hem Agnihotri Date: Sat, 22 Mar 2025 02:55:06 +0000 Subject: [PATCH 2/5] Support for continous batching and unit test Signed-off-by: Hem Agnihotri --- QEfficient/__init__.py | 11 +--- QEfficient/transformers/cache_utils.py | 31 +++++++-- QEfficient/transformers/modeling_utils.py | 63 ++----------------- .../llama_swiftkv/modeling_llama_swiftkv.py | 50 ++++++++++----- .../models/test_causal_lm_models.py | 45 +++++-------- 5 files changed, 84 insertions(+), 116 deletions(-) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 60aba0d74..a0120b3ff 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -14,11 +14,7 @@ from transformers import AutoConfig -from QEfficient.transformers.modeling_utils import ( - MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS, - get_auto_model_class, - get_model_class_type_from_model_type, -) +from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS from QEfficient.utils.logging_utils import logger # loop over all the model types which are not present in transformers and register them @@ -26,11 +22,8 @@ # Register the model config class based on the model type. 
This will be first element in the tuple AutoConfig.register(model_type, model_cls[0]) - model_class_type = get_model_class_type_from_model_type(model_type) - AutoModelClassName = get_auto_model_class(model_class_type, model_cls[1]) - # Register the non transformer library Class and config class using AutoModelClass - AutoModelClassName.register(model_cls[0], model_cls[1]) + model_cls[2].register(model_cls[0], model_cls[1]) def check_qaic_sdk(): diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index e7d6e8275..89b2dbab3 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -43,12 +43,29 @@ def write_only(self, key_states, value_states, layer_idx, cache_kwargs): self.value_cache.append(value_states) else: position_ids = cache_kwargs.get("position_ids") - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + batch_index = cache_kwargs.get("batch_index", None) + + # Scatter + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], position_ids, value_states + ) def read_only(self, layer_idx, **cache_kwargs): k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) ctx_len = k_out.shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) @@ -60,8 +77,14 @@ def read_only(self, layer_idx, **cache_kwargs): invalid_idx_value = 0 ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index e70542ff7..8d758dd2e 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from transformers import AutoModelForCausalLM from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, CodeGenBlock, @@ -277,16 +278,10 @@ WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration, } -# Map of model type to config class and Model architecture class -# While onboarding new models make sure to add the new model card names to this dictionary. 
-# Developers are expected to follow the naming conventions like ForCausalLM while defining the class names -MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {"llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]} - -# list of sub-strings representing the model type, this is typically taken from llama-swiftkv -LIST_OF_MODEL_TYPES = {"swiftkv"} - -# list of sub-strings used for representing the model Architecture class name, for example LlamaSwiftKVForCausalLM -MODEL_TYPE_TO_MODEL_CLASS_TYPE = {"swiftkv": "SwiftKVFor"} +# Map of model type to config class, Modelling class and transformer model architecture class +MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = { + "llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM, AutoModelForCausalLM], +} def _prepare_cross_attention_mask( @@ -379,51 +374,3 @@ def _create_causal_mask( attention_mask = attention_mask.unsqueeze(1) return attention_mask - - -def convert_str_to_class(className): - """ - Convert the string to class name - --------- - :className: `str`- Class name string. - Return: - Class Name - """ - module = __import__("transformers") - return getattr(module, className) - - -def get_auto_model_class(model_type, NonTransformerModelCls): - """ - Register the Non Transformer Models like swiftkv - --------------------------------------- - : model_type: str: name of the Non Transformer model for example llama_swiftkv - : NonTransformerModelCls: SwiftKV model class name for example LlamaSwiftKVForCausalLM - """ - - # Construct the AutoModel class name using NonTransformerModel class e.g. SwiftKVModel Class name, this code is written to make things generic - nonTransformerModelClsName = NonTransformerModelCls.__name__ - start_index = nonTransformerModelClsName.find(model_type) - - # Calculate the index after model_type example "SwiftKVFor" - substring_start = start_index + len(model_type) - - # Get the substring after model_type example "SwiftKVFor" - nonTransformerModel = nonTransformerModelClsName[substring_start:] - - autoModelName = "AutoModelFor" + nonTransformerModel - - # Convert the string to class name - autoModelClassName = convert_str_to_class(autoModelName) - - return autoModelClassName - - -def get_model_class_type_from_model_type(model_type): - for substring in LIST_OF_MODEL_TYPES: - if substring in model_type: - model_class_type = substring - break - - model_class_name = MODEL_TYPE_TO_MODEL_CLASS_TYPE[model_class_type] - return model_class_name diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 26931fced..badf76cce 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -18,6 +18,7 @@ from transformers import LlamaConfig from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv @@ -89,6 +90,7 @@ def forward( position_ids, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_mask=None, + batch_index: Optional[torch.LongTensor] = None, ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() query = self.q_proj_swiftkv(hidden_states) @@ -106,10 +108,11 @@ def forward( ) kv_seq_len = 
past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - key_states, value_states = past_key_value.read_only(self.layer_idx, position_ids=position_ids) + key_states, value_states = past_key_value.read_only( + self.layer_idx, position_ids=position_ids, batch_index=batch_index + ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - position_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) - position_ids = position_ids[:, position_idx[0]] + position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( query_states, torch.empty_like(query_states), cos, sin, position_ids ) @@ -134,7 +137,6 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -153,7 +155,12 @@ def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + causal_mask, + batch_index: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -164,6 +171,7 @@ def forward( position_ids=position_ids, past_key_value=past_key_values, attention_mask=causal_mask, + batch_index=batch_index, ) hidden_states = residual + hidden_states @@ -197,11 +205,13 @@ def __init__(self, config: LlamaSwiftKVConfig): self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _run_swiftkv_layers( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states, past_key_values = layer(hidden_states, position_ids, past_key_values, causal_mask) + hidden_states, past_key_values = layer( + hidden_states, position_ids, past_key_values, causal_mask, batch_index + ) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -285,6 +295,7 @@ def forward( input_ids: Optional[torch.Tensor], position_ids: torch.Tensor, past_key_values: List[torch.Tensor], + batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.embed_tokens(input_ids) @@ -323,6 +334,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + batch_index=batch_index, output_attentions=False, use_cache=True, cache_position=cache_position, @@ -356,18 +368,21 @@ def forward( cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) - cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) orig_hidden_states = hidden_states - hidden_states = orig_hidden_states[:, last_pos_id[0], :] - causal_mask = causal_mask[:, :, last_pos_id[0], :] + + hidden_states = 
orig_hidden_states[torch.arange(bsz), last_pos_id, :] + + causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( - hidden_states, position_ids, past_key_values, causal_mask + hidden_states, position_ids, past_key_values, causal_mask, batch_index ) - orig_hidden_states[:, last_pos_id[0], :] = hidden_states + + orig_hidden_states[torch.arange(bsz), last_pos_id, :] = hidden_states #################################### ## THE MAGIC OF SWIFT KV ENDS HERE #################################### @@ -394,9 +409,16 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, ): - hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values) + hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states) - return logits, output_past_key_values + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=output_past_key_values, + hidden_states=None, + attentions=None, + ) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 418386780..7894a63e0 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -8,7 +8,6 @@ import os from typing import Optional -import numpy as np import pytest from transformers import AutoModelForCausalLM @@ -22,28 +21,7 @@ from QEfficient.utils.run_utils import ApiRunner test_models = [ - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "gpt2", - "Salesforce/codegen-350M-mono", - "microsoft/Phi-3-mini-4k-instruct", - "tiiuae/falcon-7b", - "Qwen/Qwen2-0.5B", - "bigcode/starcoder2-3b", - "Felladrin/Minueza-32M-Base", - "wtang06/mpt-125m-c4", - "hakurei/gpt-j-random-tinier", - "mistralai/Mixtral-8x7B-Instruct-v0.1", - "meta-llama/Llama-3.2-1B", - "unsloth/gemma-2b", - "unsloth/gemma-2-2b", - "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model - "TheBloke/Llama-2-7B-GPTQ", # GPTQ model - "ibm-granite/granite-20b-code-base", - # "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", # naive-quantized compressed-tensor FP8 model per-channel weight, per-token activations - "neuralmagic/Llama-3.2-3B-Instruct-FP8", # float quantized compressed-tensor per tensor both weight and activations - "neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored - "ibm-granite/granite-3.1-2b-instruct", - "ibm-granite/granite-guardian-3.1-2b", + "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model ] spd_test_models = [ @@ -109,16 +87,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( Constants.CTX_LEN, ) - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - + # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) is_tlm = False if num_speculative_tokens is None else True qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( - "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - ) + # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + # "Tokens don't match for HF 
PyTorch model output and KV PyTorch model output" + # ) onnx_model_path = qeff_model.export() ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) @@ -158,8 +135,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( full_batch_size, ) - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) - pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) + # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) + # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) onnx_model_path = qeff_model.export() @@ -176,8 +153,11 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, ) - exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + # exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + qeff_model.generate(tokenizer, prompts=fbs_prompts) + +""" assert all( [ all(pt_token[:24] == cloud_token[:24]) @@ -185,6 +165,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ] ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) +""" # FIXME: there should be a CB test here @@ -227,6 +208,8 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ if model_name == "microsoft/Phi-3-mini-4k-instruct": n_layer = 2 # test only 2 layer models + elif model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": + n_layer = 32 else: n_layer = 1 From 1b1af48cdb3a3b17801b1e5a2c0aa69d563a29c4 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 25 Mar 2025 22:22:59 +0530 Subject: [PATCH 3/5] Fixed CB bug for SwiftKV Signed-off-by: Onkar Chougule --- QEfficient/transformers/cache_utils.py | 2 +- .../llama_swiftkv/modeling_llama_swiftkv.py | 32 +++++++++---------- .../models/test_causal_lm_models.py | 21 ++++++------ 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 89b2dbab3..db1cb7ea4 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -62,7 +62,7 @@ def write_only(self, key_states, value_states, layer_idx, cache_kwargs): self.value_cache[layer_idx], position_ids, value_states ) - def read_only(self, layer_idx, **cache_kwargs): + def read_only(self, layer_idx, cache_kwargs): k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index badf76cce..d582bcba8 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -94,7 +94,6 @@ def forward( ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() query = self.q_proj_swiftkv(hidden_states) - # Reshape the query, key, and value tensors. query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -107,10 +106,9 @@ def forward( "with a layer index." 
) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - key_states, value_states = past_key_value.read_only( - self.layer_idx, position_ids=position_ids, batch_index=batch_index - ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( @@ -121,10 +119,8 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) @@ -148,7 +144,6 @@ def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads - self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -343,7 +338,6 @@ def forward( bsz, q_len, _ = hidden_states.size() swiftkv_hidden_states = self.norm_swiftkv(hidden_states) - #################################### ## THE MAGIC OF SWIFT KV BEGINS HERE #################################### @@ -374,24 +368,30 @@ def forward( last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) orig_hidden_states = hidden_states - hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :] - - causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :] + # Extracting only the last valid position id to be processed by self-attn of half of the layers, as KV cache is already filled. + if batch_index is not None: + hidden_states = orig_hidden_states[batch_index, last_pos_id, :] + causal_mask = causal_mask[batch_index, :, last_pos_id, :] + else: + hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :] + causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( hidden_states, position_ids, past_key_values, causal_mask, batch_index ) - - orig_hidden_states[torch.arange(bsz), last_pos_id, :] = hidden_states + # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction + # we only need the last valid pos_indices hidden_states. + # Here the shape of hiden_states is [batch_size, 1, hidden_dim] instead of [batch_size, seq_len, hidden_dim] + # This saves un-necessary data movement on devices. 
#################################### ## THE MAGIC OF SWIFT KV ENDS HERE #################################### next_cache = next_decoder_cache.to_legacy_cache() - return orig_hidden_states, next_cache + return hidden_states, next_cache -class LlamaSwiftKVForCausalLM(PreTrainedModel): +class LlamaSwiftKVForCausalLM(PreTrainedModel): # config_class = LlamaSwiftKVConfig def __init__(self, config: LlamaSwiftKVConfig): @@ -412,8 +412,6 @@ def forward( batch_index: Optional[torch.LongTensor] = None, ): hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) - logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states) return CausalLMOutputWithPast( loss=None, diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 7894a63e0..83439ce34 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -8,6 +8,8 @@ import os from typing import Optional +import numpy as np + import pytest from transformers import AutoModelForCausalLM @@ -123,17 +125,18 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( # testing for CB models model_hf, _ = load_causal_lm_model(model_config) + config = model_hf.config full_batch_size = 4 fbs_prompts = Constants.INPUT_STR * 4 - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - fbs_prompts, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - full_batch_size, - ) + # api_runner = ApiRunner( + # batch_size, + # tokenizer, + # config, + # fbs_prompts, + # Constants.PROMPT_LEN, + # Constants.CTX_LEN, + # full_batch_size, + # ) # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) From 68304456590b4ee8aa9e963f911e15f5265bbdf3 Mon Sep 17 00:00:00 2001 From: Hem Agnihotri Date: Mon, 31 Mar 2025 09:17:37 +0000 Subject: [PATCH 4/5] Added unit test for non HF models like swiftkv Signed-off-by: Hem Agnihotri --- QEfficient/transformers/modeling_utils.py | 6 +- .../llama_swiftkv/modeling_llama_swiftkv.py | 28 +-- .../models/test_causal_lm_models.py | 180 +++++++++++++++--- 3 files changed, 172 insertions(+), 42 deletions(-) diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 8d758dd2e..666fd7973 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -91,8 +91,8 @@ # Placeholder for all non-transformer models from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import ( - LlamaSwiftKVConfig, - LlamaSwiftKVForCausalLM, + QeffLlamaSwiftKVConfig, + QeffLlamaSwiftKVForCausalLM, ) from .models.codegen.modeling_codegen import ( @@ -280,7 +280,7 @@ # Map of model type to config class, Modelling class and transformer model architecture class MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = { - "llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM, AutoModelForCausalLM], + "llama_swiftkv": [QeffLlamaSwiftKVConfig, QeffLlamaSwiftKVForCausalLM, AutoModelForCausalLM], } diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index d582bcba8..d909272c7 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ 
b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -31,7 +31,7 @@ ) -class LlamaSwiftKVConfig(LlamaConfig): +class QeffLlamaSwiftKVConfig(LlamaConfig): """ Args: num_key_value_layers (int, optional): @@ -59,8 +59,8 @@ def __init__( assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 -class LlamaSwiftKVAttention(nn.Module): - def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: +class QeffLlamaSwiftKVAttention(nn.Module): + def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.attention_dropout = config.attention_dropout @@ -139,12 +139,12 @@ def forward( return attn_output, past_key_value -class LlamaSwiftKVDecoderLayer(nn.Module): - def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None: +class QeffLlamaSwiftKVDecoderLayer(nn.Module): + def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads - self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx) + self.self_attn = QeffLlamaSwiftKVAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -179,10 +179,10 @@ def forward( return hidden_states, past_key_values -class LlamaSwiftKVModel(nn.Module): - config_class = LlamaSwiftKVConfig +class QeffLlamaSwiftKVModel(nn.Module): + config_class = QeffLlamaSwiftKVConfig - def __init__(self, config: LlamaSwiftKVConfig): + def __init__(self, config: QeffLlamaSwiftKVConfig): super().__init__() self.vocab_size = config.vocab_size self.config = config @@ -192,7 +192,7 @@ def __init__(self, config: LlamaSwiftKVConfig): [ QEffLlamaDecoderLayer(config=config, layer_idx=idx) if idx < config.num_key_value_layers - else LlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + else QeffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) for idx in range(config.num_hidden_layers) ] ) @@ -391,13 +391,13 @@ def forward( return hidden_states, next_cache -class LlamaSwiftKVForCausalLM(PreTrainedModel): # - config_class = LlamaSwiftKVConfig +class QeffLlamaSwiftKVForCausalLM(PreTrainedModel): # + config_class = QeffLlamaSwiftKVConfig - def __init__(self, config: LlamaSwiftKVConfig): + def __init__(self, config: QeffLlamaSwiftKVConfig): super().__init__(config=config) - self.model = LlamaSwiftKVModel( + self.model = QeffLlamaSwiftKVModel( config=config, ) self.vocab_size = config.vocab_size diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 83439ce34..8f997418c 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -9,7 +9,6 @@ from typing import Optional import numpy as np - import pytest from transformers import AutoModelForCausalLM @@ -23,9 +22,33 @@ from QEfficient.utils.run_utils import ApiRunner test_models = [ - "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "gpt2", + "Salesforce/codegen-350M-mono", + "microsoft/Phi-3-mini-4k-instruct", + "tiiuae/falcon-7b", + "Qwen/Qwen2-0.5B", + "bigcode/starcoder2-3b", + "Felladrin/Minueza-32M-Base", + "wtang06/mpt-125m-c4", + "hakurei/gpt-j-random-tinier", + 
"mistralai/Mixtral-8x7B-Instruct-v0.1", + "meta-llama/Llama-3.2-1B", + "unsloth/gemma-2b", + "unsloth/gemma-2-2b", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model + "TheBloke/Llama-2-7B-GPTQ", # GPTQ model + "ibm-granite/granite-20b-code-base", + # "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", # naive-quantized compressed-tensor FP8 model per-channel weight, per-token activations + "neuralmagic/Llama-3.2-3B-Instruct-FP8", # float quantized compressed-tensor per tensor both weight and activations + "neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored + "ibm-granite/granite-3.1-2b-instruct", + "ibm-granite/granite-guardian-3.1-2b", ] +swiftkv_test_models = [ + "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model +] spd_test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ] @@ -89,15 +112,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( Constants.CTX_LEN, ) - # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) is_tlm = False if num_speculative_tokens is None else True qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( - # "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - # ) + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + ) onnx_model_path = qeff_model.export() ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) @@ -128,18 +151,18 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( config = model_hf.config full_batch_size = 4 fbs_prompts = Constants.INPUT_STR * 4 - # api_runner = ApiRunner( - # batch_size, - # tokenizer, - # config, - # fbs_prompts, - # Constants.PROMPT_LEN, - # Constants.CTX_LEN, - # full_batch_size, - # ) - - # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) - # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + fbs_prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + full_batch_size, + ) + + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) + pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) onnx_model_path = qeff_model.export() @@ -156,11 +179,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, ) - # exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) - qeff_model.generate(tokenizer, prompts=fbs_prompts) - + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) -""" assert all( [ all(pt_token[:24] == cloud_token[:24]) @@ -168,7 +188,103 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ] ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) -""" + + +def check_non_hf_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, +): + """ + Validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
+ ``Mandatory`` Args:
+ :model_name (str): Hugging Face Model Card name, Example: ``Snowflake/Llama-3.1-SwiftKV-8B-Instruct``
+ :prompt_len (int): Prompt length for the model to compile.
+ :ctx_len (int): Maximum context length to compile the model.
+ :n_layer (int): Number of layers for the model.
+ """
+ replace_transformers_quantizers()
+ model_config = {"model_name": model_name}
+ model_config["n_layer"] = n_layer
+
+ model_hf, _ = load_causal_lm_model(model_config)
+
+ tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
+ config = model_hf.config
+ batch_size = len(Constants.INPUT_STR)
+ api_runner = ApiRunner(
+ batch_size,
+ tokenizer,
+ config,
+ Constants.INPUT_STR,
+ Constants.PROMPT_LEN,
+ Constants.CTX_LEN,
+ )
+
+ is_tlm = False if num_speculative_tokens is None else True
+
+ qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)
+ pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
+
+ onnx_model_path = qeff_model.export()
+ ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
+
+ assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output."
+
+ if not get_available_device_id():
+ pytest.skip("No available devices to run model on Cloud AI 100")
+
+ qpc_path = qeff_model.compile(
+ prefill_seq_len=prompt_len,
+ ctx_len=ctx_len,
+ num_cores=14,
+ mxfp6=False,
+ aic_enable_depth_first=False,
+ num_speculative_tokens=num_speculative_tokens,
+ )
+
+ exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
+ cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
+ gen_len = ort_tokens.shape[-1]
+
+ assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
+ "Tokens don't match for ONNXRT output and Cloud AI 100 output."
+ )
+ assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
+
+ # testing for CB models
+ model_hf, _ = load_causal_lm_model(model_config)
+ config = model_hf.config
+ full_batch_size = 4
+ fbs_prompts = Constants.INPUT_STR * 4
+
+ qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
+ onnx_model_path = qeff_model.export()
+
+ if not get_available_device_id():
+ pytest.skip("No available devices to run model on Cloud AI 100")
+
+ qpc_path = qeff_model.compile(
+ prefill_seq_len=prompt_len,
+ ctx_len=ctx_len,
+ num_cores=14,
+ mxfp6=False,
+ aic_enable_depth_first=False,
+ full_batch_size=full_batch_size,
+ num_speculative_tokens=num_speculative_tokens,
+ )
+
+ exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
+
+ assert all(
+ [
+ all(pt_token[:24] == cloud_token[:24])
+ for pt_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids)
+ ]
+ ), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
+ assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) # FIXME: there should be a CB test here @@ -211,14 +327,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ if model_name == "microsoft/Phi-3-mini-4k-instruct": n_layer = 2 # test only 2 layer models - elif model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": - n_layer = 32 else: n_layer = 1 check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", swiftkv_test_models) +def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": + n_layer = 32 + else: + n_layer = 2 + + check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) + + @pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", spd_test_models) From 10d73b82a5d6c624419cf3337db29120a9a2b59e Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Wed, 16 Apr 2025 10:31:18 +0000 Subject: [PATCH 5/5] Changes to modeling file Signed-off-by: Amit Raj --- .../llama_swiftkv/modeling_llama_swiftkv.py | 95 +------------------ 1 file changed, 5 insertions(+), 90 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index d909272c7..5ac67b39a 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -16,11 +16,10 @@ import torch from torch import nn from transformers import LlamaConfig -from transformers.cache_utils import Cache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, repeat_kv from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask @@ -211,80 +210,6 @@ def _run_swiftkv_layers( hidden_states = self.norm(hidden_states) return hidden_states, past_key_values - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - self.config._attn_implementation = "eager" - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens - - if attention_mask is not None and attention_mask.dim() == 4: - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - if attention_mask.max() != 0: - raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - else: - causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - def forward( self, input_ids: Optional[torch.Tensor], @@ -298,15 +223,7 @@ def forward( use_cache = True if use_cache and not isinstance(past_key_values, Cache): - if past_key_values is None: - past_key_values = QEffDynamicCache() - else: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -315,9 +232,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - None, inputs_embeds, cache_position, position_ids, past_key_values, False - ) + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) hidden_states = inputs_embeds next_decoder_cache = None @@ -419,4 +334,4 @@ def forward( past_key_values=output_past_key_values, hidden_states=None, attentions=None, - ) + ) \ No newline at end of file
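
Reviewer note (not part of the patch): the sketch below is a minimal, hedged way to exercise the SwiftKV path end to end outside of pytest, mirroring the steps in check_non_hf_kv_vs_ort_vs_ai100. It assumes a Cloud AI 100 device is available, that the full Snowflake checkpoint fits in host memory, and that importing QEfficient is enough to make the checkpoint resolvable through the standard transformers Auto classes (which is what the non-HF test path relies on). The prefill/context lengths and the prompt are illustrative values, not the Constants used by the test; AutoTokenizer stands in for the test's load_hf_tokenizer helper; the compile flags are copied from the test.

    # Minimal sketch, not part of this patch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Importing QEfficient is assumed to make the SwiftKV checkpoint loadable
    # even though no modeling file for it exists on Hugging Face.
    from QEfficient import QEFFAutoModelForCausalLM

    model_name = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"
    model_hf = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)  # test uses load_hf_tokenizer instead

    qeff_model = QEFFAutoModelForCausalLM(model_hf)  # KV-cache transform, as in the test
    onnx_model_path = qeff_model.export()            # ONNX export
    qpc_path = qeff_model.compile(                   # Cloud AI 100 compile, flags copied from the test
        prefill_seq_len=32,                          # illustrative, not Constants.PROMPT_LEN
        ctx_len=128,                                 # illustrative, not Constants.CTX_LEN
        num_cores=14,
        mxfp6=False,
        aic_enable_depth_first=False,
    )

    exec_info = qeff_model.generate(tokenizer, prompts=["My name is"])
    print(exec_info.generated_ids[0])

Alternatively, the new parametrized test added above can be run directly with something like: pytest -m on_qaic tests/transformers/models/test_causal_lm_models.py -k non_hf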