diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 47c462979..a0120b3ff 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
@@ -12,8 +12,19 @@
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+from transformers import AutoConfig
+
+from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
 from QEfficient.utils.logging_utils import logger

+# Loop over all the model types that are not present in transformers and register them
+for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
+    # Register the model config class based on the model type; this is the first element in the tuple.
+    AutoConfig.register(model_type, model_cls[0])
+
+    # Register the non-transformers model class against its config class via the Auto class (third element in the tuple).
+    model_cls[2].register(model_cls[0], model_cls[1])
+

 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index a5c375c6e..db1cb7ea4 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -36,6 +36,58 @@ class QEffDynamicCache(DynamicCache):

     """

+    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            batch_index = cache_kwargs.get("batch_index", None)
+
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+    def read_only(self, layer_idx, cache_kwargs):
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        ctx_len = k_out.shape[2]
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
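
Reviewer note on the cache split above: `read_only` gathers cached keys/values without touching the cache, while `write_only` scatters new keys/values without returning anything, which is what lets the SwiftKV layers later consume a cache they never populate themselves. Below is a minimal sketch of the gather-limit masking in `read_only`, with `torch.gather` standing in for the `CtxGatherFunc`/`CtxGatherFuncCB` custom ops and with made-up tensor sizes and padding convention; it is illustration only, not part of the change.

```python
import torch

bsz, n_kv_heads, ctx_len, head_dim = 1, 2, 8, 4
v_cache = torch.randn(bsz, n_kv_heads, ctx_len, head_dim)
position_ids = torch.tensor([[0, 1, 2, 3, -1, -1, -1, -1]])  # hypothetical padded prefill positions

ctx_indices = torch.arange(ctx_len)[None, None, ...]                  # (1, 1, ctx_len)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # highest valid position id
invalid_mask = ctx_indices > gather_limit                             # cache rows not written yet

# Outside ONNX export the invalid index is clamped to 0, and the gathered values are zeroed afterwards.
ctx_indices = torch.where(invalid_mask, 0, ctx_indices)
idx = ctx_indices.unsqueeze(-1).expand(bsz, n_kv_heads, ctx_len, head_dim).contiguous()
v_out = torch.gather(v_cache, 2, idx)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), v_out)
print(v_out.shape)  # torch.Size([1, 2, 8, 4])
```
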
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
+
     def update(
         self,
         key_states: torch.Tensor,
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index ccad5e020..666fd7973 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -10,6 +10,7 @@

 import torch
 import torch.nn as nn
+from transformers import AutoModelForCausalLM
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
     CodeGenBlock,
@@ -88,6 +89,12 @@

 from QEfficient.customop import CustomRMSNormAIC

+# Imports for models that are not present in the transformers library
+from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
+    QeffLlamaSwiftKVConfig,
+    QeffLlamaSwiftKVForCausalLM,
+)
+
 from .models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -271,6 +278,11 @@
     WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
 }

+# Map of model type to its config class, its QEff modeling class, and the transformers Auto class used to register it
+MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
+    "llama_swiftkv": [QeffLlamaSwiftKVConfig, QeffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
+}
+

 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
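
The mapping just added is what the new registration loop in `QEfficient/__init__.py` consumes at import time: the first tuple element is registered with `AutoConfig`, and the third element (here `AutoModelForCausalLM`) registers the second element against that config class. A small sketch of the effect, assuming `QEfficient` imports cleanly in your environment (the registration itself already happens inside the package, so only the lookup is shown):

```python
import QEfficient  # noqa: F401  # importing the package runs the registration loop in QEfficient/__init__.py

from transformers import AutoConfig

# After registration, the "llama_swiftkv" model_type resolves to the QEff config class, so a checkpoint
# whose config.json declares that model_type loads through the regular Auto* entry points
# (AutoModelForCausalLM.from_config / from_pretrained resolve to QeffLlamaSwiftKVForCausalLM the same way).
cfg = AutoConfig.for_model("llama_swiftkv", num_hidden_layers=2, num_key_value_layers=1)
print(type(cfg).__name__)  # QeffLlamaSwiftKVConfig
```
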
+ + +"""Inference-only LLaMA model compatible with HuggingFace weights.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import LlamaConfig +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, repeat_kv + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaDecoderLayer, + QEffLlamaRotaryEmbedding, + qeff_apply_rotary_pos_emb, +) + + +class QeffLlamaSwiftKVConfig(LlamaConfig): + """ + Args: + num_key_value_layers (int, optional): + The number of layers, from the first layer, that have keys and + values. If None, all layers have keys and values. + last_key_value_heads (int, optional): + The number of heads in the last layer that have keys and values. + If None, the number of heads in the last key-value layer is equal + to the number of heads in all the other key-value layers. + """ + + model_type = "llama_swiftkv" + + def __init__( + self, + swiftkv: bool = False, + num_key_value_layers: Optional[int] = None, + key_value_group_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.swiftkv = swiftkv + self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers + self.key_value_group_size = key_value_group_size or 1 + assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0 + + +class QeffLlamaSwiftKVAttention(nn.Module): + def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.layer_idx = layer_idx + self.q_proj_swiftkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj_swiftkv = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask=None, + batch_index: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + query = self.q_proj_swiftkv(hidden_states) + # Reshape the query, key, and value tensors. 
+        query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = position_ids.shape[-1]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+        key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1)
+        query_states, _ = qeff_apply_rotary_pos_emb(
+            query_states, torch.empty_like(query_states), cos, sin, position_ids
+        )
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:  # no matter the length, we just slice it
+            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, past_key_value
+
+
+class QeffLlamaSwiftKVDecoderLayer(nn.Module):
+    def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_key_value_heads = config.num_key_value_heads
+        self.self_attn = QeffLlamaSwiftKVAttention(config=config, layer_idx=layer_idx)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values,
+        causal_mask,
+        batch_index: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, past_key_values = self.self_attn(
+            hidden_states=hidden_states,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            attention_mask=causal_mask,
+            batch_index=batch_index,
+        )
+
+        hidden_states = residual + hidden_states
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, past_key_values
+
+
+class QeffLlamaSwiftKVModel(nn.Module):
+    config_class = QeffLlamaSwiftKVConfig
+
+    def __init__(self, config: QeffLlamaSwiftKVConfig):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.config = config
+
+        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None)
+        self.layers = torch.nn.ModuleList(
+            [
+                QEffLlamaDecoderLayer(config=config, layer_idx=idx)
+                if idx < config.num_key_value_layers
+                else QeffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx)
+                for idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def _run_swiftkv_layers(
+        self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index
+    ) -> torch.Tensor:
+        for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers):
+            layer = self.layers[layer_idx]
+            hidden_states, past_key_values = layer(
+                hidden_states, position_ids, past_key_values, causal_mask, batch_index
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return hidden_states, past_key_values
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        position_ids: torch.Tensor,
+        past_key_values: List[torch.Tensor],
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        use_cache = True
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
+
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
+        hidden_states = inputs_embeds
+
+        next_decoder_cache = None
+
+        for layer_idx in range(self.config.num_key_value_layers):
+            layer = self.layers[layer_idx]
+            hidden_states, next_decoder_cache = layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                batch_index=batch_index,
+                output_attentions=False,
+                use_cache=True,
+                cache_position=cache_position,
+                position_embeddings=None,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+        swiftkv_hidden_states = self.norm_swiftkv(hidden_states)
+        ####################################
+        ## THE MAGIC OF SWIFT KV BEGINS HERE
+        ####################################
+        for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers):
+            self_attn = self.layers[layer_idx].self_attn
+            key_states = self_attn.k_proj_swiftkv(swiftkv_hidden_states)
+            value_states = self_attn.v_proj_swiftkv(swiftkv_hidden_states)
+            key_states = key_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(
+                1, 2
+            )
+
+            kv_seq_len = key_states.shape[-2]
+            if past_key_values is not None:
+                if self_attn.layer_idx is None:
+                    raise ValueError(
+                        f"The cache structure has changed since version v4.36. If you are using {self_attn.__class__.__name__} "
+                        "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                        "with a layer index."
+                    )
+                kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx)
+
+            cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len)
+            _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids)
+            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids, "batch_index": batch_index}
+            past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs)
+
+        last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True)
+        orig_hidden_states = hidden_states
+
+        # Extract only the last valid position id; only it is processed by the self-attention of the remaining
+        # layers, since their KV cache has already been filled above.
+        if batch_index is not None:
+            hidden_states = orig_hidden_states[batch_index, last_pos_id, :]
+            causal_mask = causal_mask[batch_index, :, last_pos_id, :]
+        else:
+            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+
+        hidden_states, next_decoder_cache = self._run_swiftkv_layers(
+            hidden_states, position_ids, past_key_values, causal_mask, batch_index
+        )
+        # We could write the processed hidden_states back into orig_hidden_states here, but it is not needed: for
+        # next-token prediction we only need the hidden_states at the last valid position ids.
+        # Hence the shape of hidden_states is [batch_size, 1, hidden_dim] instead of [batch_size, seq_len, hidden_dim].
+        # This saves unnecessary data movement on devices.
+        ####################################
+        ## THE MAGIC OF SWIFT KV ENDS HERE
+        ####################################
+
+        next_cache = next_decoder_cache.to_legacy_cache()
+        return hidden_states, next_cache
+
+
+class QeffLlamaSwiftKVForCausalLM(PreTrainedModel):
+    config_class = QeffLlamaSwiftKVConfig
+
+    def __init__(self, config: QeffLlamaSwiftKVConfig):
+        super().__init__(config=config)
+
+        self.model = QeffLlamaSwiftKVModel(
+            config=config,
+        )
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.config = config
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
+        hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index)
+        logits = self.lm_head(hidden_states)
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=output_past_key_values,
+            hidden_states=None,
+            attentions=None,
+        )
\ No newline at end of file
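
To make the layer partition in `QeffLlamaSwiftKVModel` concrete: layers below `num_key_value_layers` run as ordinary Llama decoder layers and fill their own KV cache, while the remaining layers get their keys/values projected once from `norm_swiftkv(hidden_states)` and, during prefill, only process the last valid position. A toy sketch (the sizes are invented for illustration; a real SwiftKV checkpoint ships these values in its `config.json`):

```python
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import QeffLlamaSwiftKVConfig

cfg = QeffLlamaSwiftKVConfig(
    swiftkv=True,
    num_hidden_layers=8,
    num_key_value_layers=4,
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,
)

# Indices [0, num_key_value_layers) become QEffLlamaDecoderLayer; the rest become QeffLlamaSwiftKVDecoderLayer.
regular_layers = list(range(cfg.num_key_value_layers))
swiftkv_layers = list(range(cfg.num_key_value_layers, cfg.num_hidden_layers))
print(regular_layers)  # [0, 1, 2, 3]
print(swiftkv_layers)  # [4, 5, 6, 7]
```
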
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index ea9044e2c..8ba5e2c18 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -17,7 +17,12 @@
 import yaml
 from huggingface_hub import login, snapshot_download
 from requests.exceptions import HTTPError
-from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

 from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
 from QEfficient.utils.logging_utils import logger
diff --git a/README.md b/README.md
index 2185c9f64..724717874 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
 ---

 *Latest news* :fire:

+- [03/2025] Added support for the SwiftKV model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)
 - [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
 - [01/2025] [FP8 models support](https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127) Added support for inference of FP8 models.
diff --git a/docs/source/validate.md b/docs/source/validate.md
index acd4c11da..7f1690d2d 100644
--- a/docs/source/validate.md
+++ b/docs/source/validate.md
@@ -33,6 +33,7 @@
 | **Phi3ForCausalLM** | Phi-3, Phi-3.5 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | ✔️ |
 | **QwenForCausalLM** | DeepSeek-R1-Distill-Qwen | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | ✔️ |
 | | Qwen2, Qwen2.5 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✔️ |
+| **LlamaSwiftKVForCausalLM** | Llama-3.1-SwiftKV | [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | ✔️ |

 ## Embedding Models
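
For context on what the README entry and the new validation row translate to in user code, the high-level flow would look roughly like the sketch below. It mirrors the API calls exercised by the tests that follow (`compile`, `generate`); the `QEFFAutoModelForCausalLM.from_pretrained` entry point, the top-level import, and the compile argument values are assumptions for illustration and require a machine with the Cloud AI 100 SDK.

```python
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model_card = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"

# The registration loop in QEfficient/__init__.py lets the "llama_swiftkv" model_type resolve to the QEff classes.
model = QEFFAutoModelForCausalLM.from_pretrained(model_card)
tokenizer = AutoTokenizer.from_pretrained(model_card)

model.compile(
    prefill_seq_len=32,  # illustrative values; choose them to match the deployment
    ctx_len=128,
    num_cores=14,
    mxfp6=False,
)
model.generate(tokenizer, prompts=["What is SwiftKV?"])
```
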
+ """ + replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_causal_lm_model(model_config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + + is_tlm = False if num_speculative_tokens is None else True + + qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + onnx_model_path = qeff_model.export() + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) + + assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, + ) + + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size + gen_len = ort_tokens.shape[-1] + + assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), ( + "Tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + # testing for CB models + model_hf, _ = load_causal_lm_model(model_config) + config = model_hf.config + full_batch_size = 4 + fbs_prompts = Constants.INPUT_STR * 4 + + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) + onnx_model_path = qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + + assert all( + [ + all(pt_token[:24] == cloud_token[:24]) + for pt_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids) + ] + ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + # FIXME: there should be a CB test here @pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x) def test_causal_lm_export_with_deprecated_api(model_name): @@ -233,6 +333,22 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", swiftkv_test_models) +def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
@@ -233,6 +333,22 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)


+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", swiftkv_test_models)
+def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    if model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct":
+        n_layer = 32
+    else:
+        n_layer = 2
+
+    check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
+
+
 @pytest.mark.skip()  # remove when the SDK 1.20.0 issue solved for compiling this model
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("model_name", spd_test_models)
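
To exercise only the new test locally, the usual pytest selectors apply; a sketch of an equivalent programmatic invocation (the `on_qaic` marker still requires an available Cloud AI 100 device, and the path assumes the repository root as the working directory):

```python
import pytest

# Select just the new SwiftKV test; equivalent to passing the same flags on the pytest command line.
pytest.main(
    [
        "tests/transformers/models/test_causal_lm_models.py",
        "-k", "test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100",
        "-m", "on_qaic",
    ]
)
```
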