Skip to content

Commit bbaaf61

Browse files
committed
Added unit test for non HF models like swiftkv
Signed-off-by: Hem Agnihotri <[email protected]>
1 parent 013e4b7 commit bbaaf61

File tree

3 files changed

+172
-42
lines changed

3 files changed

+172
-42
lines changed

QEfficient/transformers/modeling_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@
9191

9292
# Placeholder for all non-transformer models
9393
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import (
94-
LlamaSwiftKVConfig,
95-
LlamaSwiftKVForCausalLM,
94+
QeffLlamaSwiftKVConfig,
95+
QeffLlamaSwiftKVForCausalLM,
9696
)
9797

9898
from .models.codegen.modeling_codegen import (
@@ -280,7 +280,7 @@
280280

281281
# Map of model type to config class, Modelling class and transformer model architecture class
282282
MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = {
283-
"llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM, AutoModelForCausalLM],
283+
"llama_swiftkv": [QeffLlamaSwiftKVConfig, QeffLlamaSwiftKVForCausalLM, AutoModelForCausalLM],
284284
}
285285

286286

QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232

3333

34-
class LlamaSwiftKVConfig(LlamaConfig):
34+
class QeffLlamaSwiftKVConfig(LlamaConfig):
3535
"""
3636
Args:
3737
num_key_value_layers (int, optional):
@@ -59,8 +59,8 @@ def __init__(
5959
assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0
6060

6161

62-
class LlamaSwiftKVAttention(nn.Module):
63-
def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None:
62+
class QeffLlamaSwiftKVAttention(nn.Module):
63+
def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None:
6464
super().__init__()
6565
self.hidden_size = config.hidden_size
6666
self.attention_dropout = config.attention_dropout
@@ -139,12 +139,12 @@ def forward(
139139
return attn_output, past_key_value
140140

141141

142-
class LlamaSwiftKVDecoderLayer(nn.Module):
143-
def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None:
142+
class QeffLlamaSwiftKVDecoderLayer(nn.Module):
143+
def __init__(self, config: QeffLlamaSwiftKVConfig, layer_idx) -> None:
144144
super().__init__()
145145
self.hidden_size = config.hidden_size
146146
self.num_key_value_heads = config.num_key_value_heads
147-
self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx)
147+
self.self_attn = QeffLlamaSwiftKVAttention(config=config, layer_idx=layer_idx)
148148
self.mlp = LlamaMLP(config)
149149
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
150150
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -179,10 +179,10 @@ def forward(
179179
return hidden_states, past_key_values
180180

181181

182-
class LlamaSwiftKVModel(nn.Module):
183-
config_class = LlamaSwiftKVConfig
182+
class QeffLlamaSwiftKVModel(nn.Module):
183+
config_class = QeffLlamaSwiftKVConfig
184184

185-
def __init__(self, config: LlamaSwiftKVConfig):
185+
def __init__(self, config: QeffLlamaSwiftKVConfig):
186186
super().__init__()
187187
self.vocab_size = config.vocab_size
188188
self.config = config
@@ -192,7 +192,7 @@ def __init__(self, config: LlamaSwiftKVConfig):
192192
[
193193
QEffLlamaDecoderLayer(config=config, layer_idx=idx)
194194
if idx < config.num_key_value_layers
195-
else LlamaSwiftKVDecoderLayer(config=config, layer_idx=idx)
195+
else QeffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx)
196196
for idx in range(config.num_hidden_layers)
197197
]
198198
)
@@ -391,13 +391,13 @@ def forward(
391391
return hidden_states, next_cache
392392

393393

394-
class LlamaSwiftKVForCausalLM(PreTrainedModel): #
395-
config_class = LlamaSwiftKVConfig
394+
class QeffLlamaSwiftKVForCausalLM(PreTrainedModel): #
395+
config_class = QeffLlamaSwiftKVConfig
396396

397-
def __init__(self, config: LlamaSwiftKVConfig):
397+
def __init__(self, config: QeffLlamaSwiftKVConfig):
398398
super().__init__(config=config)
399399

400-
self.model = LlamaSwiftKVModel(
400+
self.model = QeffLlamaSwiftKVModel(
401401
config=config,
402402
)
403403
self.vocab_size = config.vocab_size

tests/transformers/models/test_causal_lm_models.py

Lines changed: 155 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import Optional
1010

1111
import numpy as np
12-
1312
import pytest
1413
from transformers import AutoModelForCausalLM
1514

@@ -23,9 +22,33 @@
2322
from QEfficient.utils.run_utils import ApiRunner
2423

2524
test_models = [
26-
"Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model
25+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
26+
"gpt2",
27+
"Salesforce/codegen-350M-mono",
28+
"microsoft/Phi-3-mini-4k-instruct",
29+
"tiiuae/falcon-7b",
30+
"Qwen/Qwen2-0.5B",
31+
"bigcode/starcoder2-3b",
32+
"Felladrin/Minueza-32M-Base",
33+
"wtang06/mpt-125m-c4",
34+
"hakurei/gpt-j-random-tinier",
35+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
36+
"meta-llama/Llama-3.2-1B",
37+
"unsloth/gemma-2b",
38+
"unsloth/gemma-2-2b",
39+
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model
40+
"TheBloke/Llama-2-7B-GPTQ", # GPTQ model
41+
"ibm-granite/granite-20b-code-base",
42+
# "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic", # naive-quantized compressed-tensor FP8 model per-channel weight, per-token activations
43+
"neuralmagic/Llama-3.2-3B-Instruct-FP8", # float quantized compressed-tensor per tensor both weight and activations
44+
"neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored
45+
"ibm-granite/granite-3.1-2b-instruct",
46+
"ibm-granite/granite-guardian-3.1-2b",
2747
]
2848

49+
swiftkv_test_models = [
50+
"Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model
51+
]
2952
spd_test_models = [
3053
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
3154
]
@@ -89,15 +112,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
89112
Constants.CTX_LEN,
90113
)
91114

92-
# pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
115+
pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
93116
is_tlm = False if num_speculative_tokens is None else True
94117
qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)
95118

96119
pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
97120

98-
# assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
99-
# "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
100-
# )
121+
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
122+
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
123+
)
101124

102125
onnx_model_path = qeff_model.export()
103126
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
@@ -128,18 +151,18 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
128151
config = model_hf.config
129152
full_batch_size = 4
130153
fbs_prompts = Constants.INPUT_STR * 4
131-
# api_runner = ApiRunner(
132-
# batch_size,
133-
# tokenizer,
134-
# config,
135-
# fbs_prompts,
136-
# Constants.PROMPT_LEN,
137-
# Constants.CTX_LEN,
138-
# full_batch_size,
139-
# )
140-
141-
# pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
142-
# pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
154+
api_runner = ApiRunner(
155+
batch_size,
156+
tokenizer,
157+
config,
158+
fbs_prompts,
159+
Constants.PROMPT_LEN,
160+
Constants.CTX_LEN,
161+
full_batch_size,
162+
)
163+
164+
pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
165+
pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
143166

144167
qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
145168
onnx_model_path = qeff_model.export()
@@ -156,19 +179,112 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
156179
full_batch_size=full_batch_size,
157180
num_speculative_tokens=num_speculative_tokens,
158181
)
159-
# exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
160-
qeff_model.generate(tokenizer, prompts=fbs_prompts)
161-
182+
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
162183

163-
"""
164184
assert all(
165185
[
166186
all(pt_token[:24] == cloud_token[:24])
167187
for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids)
168188
]
169189
), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output."
170190
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
171-
"""
191+
192+
193+
def check_non_hf_kv_vs_ort_vs_ai100(
194+
model_name: str,
195+
prompt_len: int = Constants.PROMPT_LEN,
196+
ctx_len: int = Constants.CTX_LEN,
197+
n_layer: int = 1,
198+
num_speculative_tokens: Optional[int] = None,
199+
):
200+
"""
201+
Validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
202+
``Mandatory`` Args:
203+
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
204+
:prompt_len (int): Prompt length for the model to compile.
205+
:ctx_len (int): Maximum context length to compile the model.
206+
:n_layers (int): Number of layers for the Model.
207+
"""
208+
replace_transformers_quantizers()
209+
model_config = {"model_name": model_name}
210+
model_config["n_layer"] = n_layer
211+
212+
model_hf, _ = load_causal_lm_model(model_config)
213+
214+
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
215+
config = model_hf.config
216+
batch_size = len(Constants.INPUT_STR)
217+
api_runner = ApiRunner(
218+
batch_size,
219+
tokenizer,
220+
config,
221+
Constants.INPUT_STR,
222+
Constants.PROMPT_LEN,
223+
Constants.CTX_LEN,
224+
)
225+
226+
is_tlm = False if num_speculative_tokens is None else True
227+
228+
qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm)
229+
pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
230+
231+
onnx_model_path = qeff_model.export()
232+
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
233+
234+
assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output."
235+
236+
if not get_available_device_id():
237+
pytest.skip("No available devices to run model on Cloud AI 100")
238+
239+
qpc_path = qeff_model.compile(
240+
prefill_seq_len=prompt_len,
241+
ctx_len=ctx_len,
242+
num_cores=14,
243+
mxfp6=False,
244+
aic_enable_depth_first=False,
245+
num_speculative_tokens=num_speculative_tokens,
246+
)
247+
248+
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
249+
cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
250+
gen_len = ort_tokens.shape[-1]
251+
252+
assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), (
253+
"Tokens don't match for ONNXRT output and Cloud AI 100 output."
254+
)
255+
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
256+
257+
# testing for CB models
258+
model_hf, _ = load_causal_lm_model(model_config)
259+
config = model_hf.config
260+
full_batch_size = 4
261+
fbs_prompts = Constants.INPUT_STR * 4
262+
263+
qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm)
264+
onnx_model_path = qeff_model.export()
265+
266+
if not get_available_device_id():
267+
pytest.skip("No available devices to run model on Cloud AI 100")
268+
269+
qpc_path = qeff_model.compile(
270+
prefill_seq_len=prompt_len,
271+
ctx_len=ctx_len,
272+
num_cores=14,
273+
mxfp6=False,
274+
aic_enable_depth_first=False,
275+
full_batch_size=full_batch_size,
276+
num_speculative_tokens=num_speculative_tokens,
277+
)
278+
279+
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
280+
281+
assert all(
282+
[
283+
all(pt_token[:24] == cloud_token[:24])
284+
for pt_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids)
285+
]
286+
), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output."
287+
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
172288

173289

174290
# FIXME: there should be a CB test here
@@ -211,14 +327,28 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
211327
"""
212328
if model_name == "microsoft/Phi-3-mini-4k-instruct":
213329
n_layer = 2 # test only 2 layer models
214-
elif model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct":
215-
n_layer = 32
216330
else:
217331
n_layer = 1
218332

219333
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
220334

221335

336+
@pytest.mark.on_qaic
337+
@pytest.mark.parametrize("model_name", swiftkv_test_models)
338+
def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
339+
"""
340+
Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
341+
``Mandatory`` Args:
342+
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
343+
"""
344+
if model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct":
345+
n_layer = 32
346+
else:
347+
n_layer = 2
348+
349+
check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
350+
351+
222352
@pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model
223353
@pytest.mark.on_qaic
224354
@pytest.mark.parametrize("model_name", spd_test_models)

0 commit comments

Comments
 (0)