Skip to content

Commit 5520757

Browse files
authored
Minor Changes to AutoModelForSpeechSeq2Seq to better align with other models (#286)
Minor fixes to `generate` and `compile` to make them more consistent with how other models are called.

Signed-off-by: Kushal Dulla <[email protected]>
1 parent d590081 commit 5520757

File tree

4 files changed

+71
-60
lines changed

4 files changed

+71
-60
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def export(
579579
)
580580

581581
self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir)
582+
return self.onnx_path
582583

583584
def compile(
584585
self,
@@ -676,6 +677,7 @@ def compile(
676677
custom_io=custom_io_lang,
677678
**compiler_options,
678679
)
680+
return self.qpc_path
679681

680682
def generate(
681683
self,
@@ -895,7 +897,7 @@ def export(
895897
inputs = self.model.get_dummy_inputs()
896898
dynamic_axes = self.model.get_onnx_dynamic_axes()
897899
output_names = self.model.get_output_names()
898-
self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
900+
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
899901

900902
def compile(
901903
self,
@@ -1727,20 +1729,26 @@ def export(self, export_dir: Optional[str] = None) -> str:
17271729
inputs = self.model.get_dummy_inputs()
17281730
dynamic_axes = self.model.get_onnx_dynamic_axes()
17291731
output_names = self.model.get_output_names()
1730-
self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
1732+
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
17311733

17321734
def compile(
17331735
self,
17341736
onnx_path: Optional[str] = None,
17351737
compile_dir: Optional[str] = None,
17361738
*,
1737-
encoder_ctx_len: int = 1500,
1738-
decoder_ctx_len: int = 150,
1739-
feature_len: int = 3000,
1739+
prefill_seq_len: Optional[int] = 1,
1740+
encoder_ctx_len: Optional[int] = None,
1741+
ctx_len: int = 150,
1742+
full_batch_size: Optional[int] = None,
1743+
kv_cache_batch_size: Optional[int] = None,
17401744
batch_size: int = 1,
17411745
num_devices: int = 1,
17421746
num_cores: int = 16, # FIXME: Make this mandatory arg
17431747
mxfp6_matmul: bool = False,
1748+
mxint8_kv_cache: bool = False,
1749+
num_speculative_tokens: Optional[int] = None,
1750+
enable_qnn: bool = False,
1751+
qnn_config: Optional[str] = None,
17441752
**compiler_options,
17451753
) -> str:
17461754
"""
@@ -1751,19 +1759,41 @@ def compile(
17511759
``Optional`` Args:
17521760
:onnx_path (str, optional): Path to pre-exported onnx model.
17531761
:compile_dir (str, optional): Path for saving the qpc generated.
1754-
:seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``.
1762+
:encoder_ctx_len (int, optional): The maximum length of context for encoder, based on the AutoProcessor output. ``Defaults to checking config, if None in config then 1500``
1763+
:ctx_len (int, optional): The maximum length of context to keep for decoding. ``Defaults to 150``.
17551764
:batch_size (int, optional): Batch size. ``Defaults to 1``.
17561765
:num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
17571766
:num_cores (int): Number of cores used to compile the model.
17581767
:mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
17591768
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
1760-
:allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
1769+
1770+
Other args are not yet implemented for AutoModelForSpeechSeq2Seq
17611771
Returns:
17621772
:str: Path of the compiled ``qpc`` package.
17631773
"""
1764-
specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len)
1774+
specializations, compiler_options = self.model.get_specializations(
1775+
batch_size,
1776+
encoder_ctx_len,
1777+
ctx_len,
1778+
**compiler_options,
1779+
)
17651780

1766-
self._compile(
1781+
if full_batch_size:
1782+
logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq")
1783+
1784+
if kv_cache_batch_size:
1785+
logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq")
1786+
1787+
if mxint8_kv_cache:
1788+
logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq")
1789+
1790+
if num_speculative_tokens:
1791+
logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
1792+
1793+
if enable_qnn or qnn_config:
1794+
logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")
1795+
1796+
return self._compile(
17671797
onnx_path,
17681798
compile_dir,
17691799
compile_only=True,
@@ -1781,7 +1811,6 @@ def generate(
17811811
inputs: torch.Tensor,
17821812
generation_len: int,
17831813
streamer: Optional[TextStreamer] = None,
1784-
enable_debug_logs: bool = False,
17851814
device_ids: List[int] = None,
17861815
) -> Union[torch.Tensor, np.ndarray]:
17871816
"""
@@ -1790,9 +1819,8 @@ def generate(
17901819
17911820
``Mandatory`` Args:
17921821
:processor: autoprocessor to process inputs and decode logits
1793-
:inputs (np.ndarray): inputs to run the execution.
1822+
:inputs (torch.Tensor): inputs to run the execution.
17941823
:generation_len (int): length upto which to generate
1795-
:sample_rate (int): sampling rate at which input audio is stored in inputs (needed for processor)
17961824
:device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
17971825
Returns:
17981826
:dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
@@ -1803,9 +1831,20 @@ def generate(
18031831
inputs = self.auto_correct_inputs(inputs)
18041832

18051833
if self.qpc_session is None:
1806-
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids, enable_debug_logs=enable_debug_logs)
1834+
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
18071835
self.batch_size = self.qpc_session.bindings[0].dims[0]
18081836

1837+
inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32)
1838+
1839+
# add start token id and initial position ids to inputs
1840+
seq_len = 1
1841+
inputs["decoder_input_ids"] = (
1842+
torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id
1843+
).numpy()
1844+
inputs["decoder_position_ids"] = (
1845+
torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy()
1846+
)
1847+
18091848
self.qpc_session.skip_buffers(
18101849
[x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")]
18111850
)

QEfficient/transformers/models/whisper/modeling_whisper.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import random
99
from typing import Optional, Tuple
1010

11-
import numpy as np
1211
import torch
1312
from torch import nn
1413
from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
@@ -812,28 +811,33 @@ def get_dummy_inputs(
812811

813812
return inputs
814813

815-
def get_specializations(
816-
self, batch_size: int, encoder_ctx_len: int, decoder_ctx_len: int, feature_len: int, **compiler_options
817-
):
814+
def get_specializations(self, batch_size: int, encoder_ctx_len, ctx_len, **compiler_options):
815+
if encoder_ctx_len is None and hasattr(self.config, "max_source_positions"):
816+
encoder_ctx_len = self.config.max_source_positions
817+
elif encoder_ctx_len is None:
818+
encoder_ctx_len = 1500
819+
logger.warning("Setting `encoder_ctx_len=1500` as it was neither passed nor found in config")
820+
feature_len = encoder_ctx_len * 2
821+
818822
encoder_specializations = {
819823
"batch_size": batch_size,
820824
"seq_len": 1,
821825
"encoder_ctx_len": encoder_ctx_len,
822-
"decoder_ctx_len": decoder_ctx_len,
826+
"decoder_ctx_len": ctx_len,
823827
"feature_len": feature_len,
824828
}
825829

826830
decoder_specializations = {
827831
"batch_size": batch_size,
828832
"seq_len": 1,
829833
"encoder_ctx_len": encoder_ctx_len,
830-
"decoder_ctx_len": decoder_ctx_len,
834+
"decoder_ctx_len": ctx_len,
831835
"feature_len": 1, # important dummy feature so that torch.where knows whether to run cross attention or not
832836
}
833837

834838
specializations = [encoder_specializations, decoder_specializations]
835839

836-
return specializations
840+
return specializations, compiler_options
837841

838842
def get_onnx_dynamic_axes(
839843
self,
@@ -874,7 +878,5 @@ def get_output_names(
874878

875879
def get_inputs_info(self):
876880
return [
877-
IOInfo(name="input_features", datatype=np.float32, shape=("batch_size", "num_mel_bins", "feature_len")),
878-
IOInfo(name="decoder_input_ids", datatype=np.int64, shape=("batch_size", "seq_len")),
879-
IOInfo(name="decoder_position_ids", datatype=np.int64, shape=("batch_size", "seq_len")),
881+
IOInfo(name="input_features", datatype=torch.float32, shape=("batch_size", "num_mel_bins", "feature_len")),
880882
]

examples/speech_to_text/run_whisper_speech_to_text.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
import numpy as np
9-
import torch
108
from datasets import load_dataset
119
from transformers import AutoProcessor
1210

@@ -29,24 +27,10 @@
2927
## STEP 3 -- export and compile model
3028
qeff_model.compile()
3129

32-
## STEP 4 -- prepare generate inputs
33-
bs = 1
34-
seq_len = 1
35-
input_features = (
36-
processor(data, sampling_rate=sample_rate, return_tensors="pt").input_features.numpy().astype(np.float32)
30+
## STEP 4 -- generate output for loaded input and processor
31+
exec_info = qeff_model.generate(
32+
inputs=processor(data, sampling_rate=sample_rate, return_tensors="pt"), generation_len=ctx_len
3733
)
38-
decoder_input_ids = (
39-
torch.ones((bs, seq_len), dtype=torch.int64) * qeff_model.model.config.decoder_start_token_id
40-
).numpy()
41-
decoder_position_ids = torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1).numpy()
42-
inputs = dict(
43-
input_features=input_features,
44-
decoder_input_ids=decoder_input_ids,
45-
decoder_position_ids=decoder_position_ids,
46-
)
47-
48-
## STEP 5 -- generate output for loaded input and processor
49-
exec_info = qeff_model.generate(inputs=inputs, generation_len=ctx_len)
5034

51-
## STEP 6 (optional) -- use processor to decode output
35+
## STEP 5 (optional) -- use processor to decode output
5236
print(processor.batch_decode(exec_info.generated_ids)[0])

tests/transformers/models/test_speech_seq2seq_models.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -334,28 +334,14 @@ def check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(
334334
pytest.skip("No available devices to run model on Cloud AI 100")
335335

336336
qeff_model.compile(
337-
encoder_ctx_len=qeff_model.model.config.max_source_positions,
338-
decoder_ctx_len=ctx_len,
337+
ctx_len=ctx_len,
339338
num_cores=16,
340339
batch_size=batch_size,
341340
)
342341

343-
bs = 1
344-
seq_len = 1
345-
input_features = (
346-
processor(data, sampling_rate=sample_rate, return_tensors="pt").input_features.numpy().astype(np.float32)
347-
)
348-
decoder_input_ids = (
349-
torch.ones((bs, seq_len), dtype=torch.int64) * qeff_model.model.config.decoder_start_token_id
350-
).numpy()
351-
decoder_position_ids = torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1).numpy()
352-
inputs = dict(
353-
input_features=input_features,
354-
decoder_input_ids=decoder_input_ids,
355-
decoder_position_ids=decoder_position_ids,
342+
exec_info = qeff_model.generate(
343+
inputs=processor(data, sampling_rate=sample_rate, return_tensors="pt"), generation_len=ctx_len
356344
)
357-
358-
exec_info = qeff_model.generate(inputs=inputs, generation_len=ctx_len)
359345
cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
360346
assert (pytorch_kv_tokens == cloud_ai_100_tokens).all(), (
361347
"Tokens don't match for pytorch output and Cloud AI 100 output."

0 commit comments

Comments
 (0)