diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 7f63b34ca..57f638de0 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -6,7 +6,17 @@
# -----------------------------------------------------------------------------
import os
-import warnings
+
+# ----------------------------------------------------------------------------- #
+# For faster downloads via hf_transfer
+# This code is placed above the import statements because it must be executed before
+# hf_transfer is imported (which happens via the imports below)
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# DO NOT ADD ANY CODE ABOVE THIS LINE
+# Please contact maintainers if you must edit this file above this line.
+# ----------------------------------------------------------------------------- #
+import warnings  # noqa: I001
+# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.base import (
@@ -18,12 +28,19 @@
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
+from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
from QEfficient.transformers.transform import transform
from QEfficient.utils import custom_format_warning
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+
+# Custom warning format for a better logging experience
+warnings.formatwarning = custom_format_warning
+
# Users can use QEfficient.export for exporting models to ONNX
export = qualcomm_efficient_converter
@@ -39,15 +56,9 @@
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
+ "QEffFluxPipeline",
]
-# For faster downloads via hf_transfer
-# This code is put above import statements as this needs to be executed before
-# hf_transfer is imported (will happen on line 15 via leading imports)
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-# Placeholder for all non-transformer models registered in QEfficient
-# custom warning for the better logging experience
-warnings.formatwarning = custom_format_warning
# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index ef7e83adf..1e596ce29 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -7,8 +7,6 @@
import gc
import inspect
-import logging
-import re
import shutil
import subprocess
import warnings
@@ -21,28 +19,24 @@
from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
- CustomOpTransform,
OnnxTransformPipeline,
- RenameFunctionOutputsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.transformers.cache_utils import InvalidIndexProvider
-from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
-from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+from QEfficient.utils.export_utils import export_wrapper
+from QEfficient.utils.logging_utils import QEFFLogger
-logger = logging.getLogger(__name__)
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class QEFFBaseModel(ABC):
@@ -66,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
+ self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None:
logger.error(error_msg)
raise RuntimeError(error_msg)
+ @property
+ def model_name(self) -> str:
+ """
+ Get the model class name without QEff/QEFF prefix.
+
+ This property extracts the underlying model's class name and removes
+ any QEff or QEFF prefix that may have been added during wrapping.
+
+ Returns:
+ str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+ """
+ mname = self.model.__class__.__name__
+ if mname.startswith("QEff") or mname.startswith("QEFF"):
+ mname = mname[4:]
+ return mname
+
@property
@abstractmethod
- def model_name(self) -> str: ...
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ This is an abstract property that must be implemented by all subclasses.
+ Typically returns: self.model.config.__dict__
+
+ Returns:
+ Dict: The configuration dictionary of the underlying model
+ """
+ pass
@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
@@ -184,11 +205,11 @@ def _export(
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
- export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
- use_onnx_subfunctions: bool = False,
+ prefill_only: Optional[bool] = False,
+ **export_kwargs,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -213,11 +234,16 @@ def _export(
instance using from_pretrained() for re-export.
"""
+ # TODO: Hack for retain_full_kv, handle this outside
+ export_kwargs.pop("retain_full_kv", None)
onnx_path = export_dir / f"{self.model_name}.onnx"
# Return early if ONNX already exists
if onnx_path.is_file():
- self.onnx_path = onnx_path
+ if prefill_only:
+ self.prefill_onnx_path = onnx_path
+ else:
+ self.onnx_path = onnx_path
return onnx_path
# check if the model is in meta state or weights are offloaded
@@ -253,19 +279,6 @@ def _export(
input_names.append(param)
try:
- # Initialize the registry with your custom ops
- export_kwargs = {} if export_kwargs is None else export_kwargs
- if use_onnx_subfunctions:
- warnings.warn(
- "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
- )
- apply_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = True
- output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
- export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
- self._onnx_transforms.append(RenameFunctionOutputsTransform)
- self._onnx_transforms.append(CustomOpTransform)
-
torch.onnx.export(
self.model,
(example_inputs,),
@@ -309,15 +322,43 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
- if use_onnx_subfunctions:
- undo_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = False
- self._onnx_transforms.remove(CustomOpTransform)
- self._onnx_transforms.remove(RenameFunctionOutputsTransform)
-
- self.onnx_path = onnx_path
+ if prefill_only:
+ self.prefill_onnx_path = onnx_path
+ else:
+ self.onnx_path = onnx_path
+ logger.info("Model export is finished and saved at: %s", onnx_path)
return onnx_path
+ def get_onnx_path(
+ self,
+ prefill_only: Optional[bool] = False,
+ enable_chunking: Optional[bool] = False,
+ specializations: Optional[List[Dict[str, int]]] = None,
+ offload_pt_weights: Optional[bool] = True,
+ use_onnx_subfunctions: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = False,
+ ):
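+        """
+        Return the path to the exported ONNX graph, exporting it first if it does not exist yet.
+
+        When ``prefill_only`` is True, a prefill-only graph is exported (using the prefill
+        sequence length taken from the first specialization) and ``self.prefill_onnx_path``
+        is returned; otherwise the regular graph is exported and ``self.onnx_path`` is returned.
+        """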
+ kwargs = {
+ "offload_pt_weights": offload_pt_weights,
+ "use_onnx_subfunctions": use_onnx_subfunctions,
+ "retain_full_kv": retain_full_kv,
+ }
+ if prefill_only:
+ if self.prefill_onnx_path is None:
+ kwargs.update(
+ {
+ "prefill_only": prefill_only,
+ "prefill_seq_len": specializations[0].get("seq_len"),
+ "enable_chunking": enable_chunking,
+ }
+ )
+ self.export(**kwargs)
+ return self.prefill_onnx_path
+ else:
+ if self.onnx_path is None:
+ self.export(**kwargs)
+ return self.onnx_path
+
@dump_qconfig
def _compile(
self,
@@ -332,6 +373,10 @@ def _compile(
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
+        prefill_only: Optional[bool] = None,
+ offload_pt_weights: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -357,11 +402,18 @@ def _compile(
For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
-
- if onnx_path is None and self.onnx_path is None:
- self.export(use_onnx_subfunctions=use_onnx_subfunctions)
-
- onnx_path = Path(onnx_path or self.onnx_path)
+ onnx_path = Path(
+ onnx_path
+ if onnx_path
+ else self.get_onnx_path(
+ prefill_only,
+ enable_chunking,
+ specializations,
+ offload_pt_weights,
+ use_onnx_subfunctions,
+ retain_full_kv,
+ )
+ )
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -423,6 +475,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
+ "prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)
@@ -462,6 +515,16 @@ def _compile(
command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
+ if use_onnx_subfunctions:
+
+ class FeatureNotAvailableError(Exception):
+ pass
+
+ exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
+ raise FeatureNotAvailableError(
+ "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+ + f"\nRun following command manually with assert compiler:\n{exec_command}"
+ )
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
@@ -482,5 +545,5 @@ def _compile(
logger.info("Hashed parameters exported successfully.")
self.qpc_path = qpc_path
-
+ logger.info("Model compilation is finished and saved at: %s", qpc_path)
return qpc_path
diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 945850c50..16697cec9 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -19,16 +19,20 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGather,
CtxGather3D,
+ CtxGatherBlockedKV,
CtxGatherFunc,
CtxGatherFunc3D,
+ CtxGatherFuncBlockedKV,
CtxScatter,
CtxScatter3D,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
+ CtxGatherBlockedKVCB,
CtxGatherCB,
CtxGatherCB3D,
+ CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterCB,
@@ -91,10 +95,12 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
"CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
- "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
- "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
+ "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
+ "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
+ "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
+ "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
}
@classmethod
diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py
index e503a057f..dd8d9c0b8 100644
--- a/QEfficient/base/pytorch_transforms.py
+++ b/QEfficient/base/pytorch_transforms.py
@@ -9,7 +9,9 @@
from torch import nn
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class PytorchTransform:
diff --git a/QEfficient/cloud/export.py b/QEfficient/cloud/export.py
index a5e0b6e19..48d10c700 100644
--- a/QEfficient/cloud/export.py
+++ b/QEfficient/cloud/export.py
@@ -12,7 +12,9 @@
from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir
from QEfficient.utils.custom_yaml import generate_custom_io
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
# Specifically for Docker images.
ROOT_DIR = os.path.dirname(os.path.abspath(""))
diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index 35ebbde32..3e67e3354 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -29,10 +29,12 @@
from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
from QEfficient.finetune.utils.device_map import get_device_map
from QEfficient.finetune.utils.helper import Task_Mode, get_world_size
-from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
from QEfficient.utils._utils import hf_download
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
# Try importing QAIC-specific module, proceed without it if unavailable
try:
diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index ef05d29ab..3a5e20ca6 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -17,7 +17,9 @@
from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir, load_hf_processor, load_hf_tokenizer
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
# TODO: Remove after adding support for VLM's compile and execute
diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py
index 5de21f876..bf83013f4 100644
--- a/QEfficient/compile/compile_helper.py
+++ b/QEfficient/compile/compile_helper.py
@@ -15,7 +15,9 @@
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.utils import constants
from QEfficient.utils._utils import load_json, load_yaml
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
def create_and_dump_specializations(
diff --git a/QEfficient/compile/qnn_compiler.py b/QEfficient/compile/qnn_compiler.py
index e2ec20364..9fcadb6d0 100644
--- a/QEfficient/compile/qnn_compiler.py
+++ b/QEfficient/compile/qnn_compiler.py
@@ -18,7 +18,9 @@
generate_qnn_specialization,
)
from QEfficient.utils.hash_utils import to_hashable
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class QNN:
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index c7dc8639a..7b15effe7 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
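+        # Replace the INT32_MAX sentinel indices with 0 so the gather stays in-bounds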
+ ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]
@staticmethod
diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py
index 8a06bc2b1..c15b60810 100644
--- a/QEfficient/customop/ctx_scatter_gather_cb.py
+++ b/QEfficient/customop/ctx_scatter_gather_cb.py
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
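+        # Replace out-of-range context indices with 0 so the gather stays within the KV buffer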
+ ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]
@staticmethod
diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md
new file mode 100644
index 000000000..40d45e984
--- /dev/null
+++ b/QEfficient/diffusers/README.md
@@ -0,0 +1,95 @@
+
+
+
+
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+
+
+
+### **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+
+
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale
+
+
+
+
+
+
+
+[HuggingFace Diffusers](https://github.com/huggingface/diffusers)
+
+
+---
+
+## Overview
+
+QEfficient Diffusers brings state-of-the-art diffusion models to Qualcomm Cloud AI 100 for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, the optimized pipeline provides seamless inference on Cloud AI 100 hardware.
+
+## Installation
+
+### Prerequisites
+
+Ensure you have Python 3.8+ and the required dependencies:
+
+```bash
+# Create Python virtual environment (Recommended Python 3.10)
+sudo apt install python3.10-venv
+python3.10 -m venv qeff_env
+source qeff_env/bin/activate
+pip install -U pip
+```
+
+### Install QEfficient
+
+```bash
+# Install from GitHub (includes diffusers support)
+pip install git+https://github.com/quic/efficient-transformers
+
+# Or build from source
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install build wheel
+python -m build --wheel --outdir dist
+pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
+```
+
+---
+
+## Supported Models
+- [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+
+---
+
+
+## Examples
+
+Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory.
+
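+A minimal end-to-end sketch (based on the `QEffFluxPipeline` API introduced in this PR; exact call parameters may evolve):
+
+```python
+from QEfficient import QEffFluxPipeline
+
+# Load the FLUX.1-schnell pipeline and wrap it with QEfficient-optimized modules
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Export to ONNX and compile QPCs for Qualcomm Cloud AI 100 (uses the default flux_config.json)
+pipeline.compile(height=512, width=512)
+
+# Generate an image (FLUX.1-schnell works well with few steps and guidance scale 0.0)
+output = pipeline(prompt="A girl laughing", height=512, width=512, num_inference_steps=4, guidance_scale=0.0)
+output.images[0].save("girl_laughing.png")
+```
+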
+---
+
+## Contributing
+
+We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.
+
+
+
+---
+
+## Acknowledgments
+
+- **HuggingFace Diffusers**: For the excellent foundation library
+- **Stability AI**: For the amazing Stable Diffusion models
+---
+
+## Support
+
+- **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
+- **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)
+
+---
+
diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py
new file mode 100644
index 000000000..933832ed8
--- /dev/null
+++ b/QEfficient/diffusers/models/normalization.py
@@ -0,0 +1,40 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Optional
+
+import torch
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
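+# Note: these QEff variants consume precomputed shift/scale modulation tensors (the AdaLN
+# projections are computed outside the exported graph and fed in as the `adaln_*` inputs of
+# QEffFluxTransformer2DModel) and return only the normalized hidden states.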
+class QEffAdaLayerNormZero(AdaLayerNormZero):
+ def forward(
+ self,
+ x: torch.Tensor,
+ shift_msa: Optional[torch.Tensor] = None,
+ scale_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_msa: Optional[torch.Tensor] = None,
+ shift_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ emb = conditioning_embedding
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..d3c84ee63
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,56 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+)
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.normalization import (
+ QEffAdaLayerNormContinuous,
+ QEffAdaLayerNormZero,
+ QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxAttention,
+ QEffFluxAttnProcessor,
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformer2DModel,
+ QEffFluxTransformerBlock,
+)
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+ _module_mapping = {
+ RMSNorm: CustomRMSNormAIC,
+ nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm
+ }
+
+
+class AttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
+ FluxTransformerBlock: QEffFluxTransformerBlock,
+ FluxTransformer2DModel: QEffFluxTransformer2DModel,
+ FluxAttention: QEffFluxAttention,
+ FluxAttnProcessor: QEffFluxAttnProcessor,
+ }
+
+
+class NormalizationTransform(ModuleMappingTransform):
+ _module_mapping = {
+ AdaLayerNormZero: QEffAdaLayerNormZero,
+ AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
+ AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
+ }
diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000..ec80543f4
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,329 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+ _get_qkv_projections,
+)
+
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
+
+
+def qeff_apply_rotary_emb(
+ x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> torch.Tensor:
+    """
+    Apply rotary position embeddings to a query or key tensor using the given frequency tensors.
+    The (cos, sin) tensors are reshaped for broadcasting over the batch and head dimensions, and the
+    rotation is applied with real-valued arithmetic over interleaved real/imaginary pairs of the last dimension.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed (cos, sin) frequency tensors, each of shape [S, D].
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
+    """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+ B, S, H, D = x.shape
+ x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+ _attention_backend = None
+ _parallel_config = None
+
+ def __call__(
+ self,
+ attn: "QEffFluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = qeff_apply_rotary_emb(query, image_rotary_emb)
+ key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class QEffFluxAttention(FluxAttention):
+ def __qeff_init__(self):
+ processor = QEffFluxAttnProcessor()
+ self.processor = processor
+
+
+class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ shift_msa, scale_msa, gate = torch.split(temb, 1)
+ residual = hidden_states
+ norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ # if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformerBlock(FluxTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb1 = tuple(torch.split(temb[:6], 1))
+ temb2 = tuple(torch.split(temb[6:], 1))
+ norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1])
+ gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:]
+
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1])
+
+ c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:]
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ # if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformer2DModel(FluxTransformer2DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ adaln_emb: torch.Tensor = None,
+ adaln_single_emb: torch.Tensor = None,
+ adaln_out: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
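+            adaln_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation tensors for the dual-stream transformer blocks,
+                one entry per block, computed outside the exported graph.
+            adaln_single_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation tensors for the single-stream transformer blocks.
+            adaln_out (`torch.Tensor`, *optional*):
+                Precomputed conditioning for the final `AdaLayerNormContinuous` output normalization.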
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_single_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
+
+ hidden_states = self.norm_out(hidden_states, adaln_out)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/configs/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
new file mode 100644
index 000000000..8a7620a46
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
@@ -0,0 +1,856 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+# TODO: Pipeline Architecture Improvements
+# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile,
+# and inference APIs across all diffusion pipelines, promoting code reusability
+# and consistent interface design.
+# 2. Implement persistent QPC session management strategy to retain/drop compiled model
+# sessions in memory across all pipeline modules.
+
+import os
+import time
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from tqdm import tqdm
+
+from QEfficient.diffusers.pipelines.pipeline_module import (
+ QEffFluxTransformerModel,
+ QEffTextEncoder,
+ QEffVAE,
+)
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ONNX_SUBFUNCTION_MODULE,
+ ModulePerf,
+ QEffPipelineOutput,
+ calculate_compressed_latent_dimension,
+ compile_modules_parallel,
+ compile_modules_sequential,
+ config_manager,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
+
+
+class QEffFluxPipeline:
+ """
+ QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware.
+
+ This pipeline provides an optimized implementation of the Flux diffusion model specifically designed
+ for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX format and compiled
+ into Qualcomm Program Container (QPC) files for efficient inference.
+
+ The pipeline supports the complete Flux workflow including:
+ - Dual text encoding with CLIP and T5 encoders
+ - Transformer-based denoising with adaptive layer normalization
+ - VAE decoding for final image generation
+ - Performance monitoring and optimization
+
+ Attributes:
+ text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings
+ text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings
+ transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising
+ vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion
+ modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations
+ model (FluxPipeline): Original HuggingFace Flux model reference
+ tokenizer: CLIP tokenizer for text preprocessing
+ scheduler: Diffusion scheduler for timestep management
+
+ Example:
+        >>> from QEfficient import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> images = pipeline(
+ ... prompt="A beautiful sunset over mountains",
+ ... height=512,
+ ... width=512,
+ ... num_inference_steps=28
+ ... )
+ >>> images.images[0].save("generated_image.png")
+ """
+
+ _hf_auto_class = FluxPipeline
+
+ def __init__(self, model, *args, **kwargs):
+ """
+ Initialize the QEfficient Flux pipeline.
+
+ This pipeline provides an optimized implementation of the Flux text-to-image model
+ for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX and compiled
+ for QAIC devices.
+
+ Args:
+ model: Pre-loaded FluxPipeline model
+ **kwargs: Additional arguments including height and width
+ """
+
+ # Wrap model components with QEfficient optimized versions
+ self.model = model
+ self.text_encoder = QEffTextEncoder(model.text_encoder)
+ self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+ self.transformer = QEffFluxTransformerModel(model.transformer)
+ self.vae_decode = QEffVAE(model.vae, "decoder")
+
+ # Store all modules in a dictionary for easy iteration during export/compile
+ self.modules = {
+ "text_encoder": self.text_encoder,
+ "text_encoder_2": self.text_encoder_2,
+ "transformer": self.transformer,
+ "vae_decoder": self.vae_decode,
+ }
+
+ # Copy tokenizers and scheduler from the original model
+ self.tokenizer = model.tokenizer
+ self.text_encoder.tokenizer = model.tokenizer
+ self.text_encoder_2.tokenizer = model.tokenizer_2
+ self.tokenizer_max_length = model.tokenizer_max_length
+ self.scheduler = model.scheduler
+
+ # Override VAE forward method to use decode directly
+ self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
+ latent_sample, return_dict
+ )
+
+ # Sync max position embeddings between text encoders
+ self.text_encoder_2.model.config.max_position_embeddings = (
+ self.text_encoder.model.config.max_position_embeddings
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ **kwargs,
+ ):
+ """
+ Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations.
+
+ This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained
+ Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU
+ and wraps all components with QEfficient-optimized versions for QAIC deployment.
+
+ Args:
+ pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier
+ (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory.
+ **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained().
+
+ Returns:
+ QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components
+ ready for export, compilation, and inference on QAIC devices.
+
+ Raises:
+ ValueError: If the model path is invalid or model cannot be loaded
+ OSError: If there are issues accessing the model files
+ RuntimeError: If model initialization fails
+
+ Example:
+ >>> # Load from HuggingFace Hub
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>>
+ >>> # Load from local path
+ >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model")
+ >>>
+ >>> # Load with custom cache directory
+ >>> pipeline = QEffFluxPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-dev",
+ ... cache_dir="/custom/cache/dir"
+ ... )
+ """
+ # Load the base Flux model in float32 on CPU
+ model = cls._hf_auto_class.from_pretrained(
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ device_map="cpu",
+ **kwargs,
+ )
+
+ return cls(
+ model=model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ **kwargs,
+ )
+
+ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ """
+ Export all pipeline modules to ONNX format for deployment preparation.
+
+ This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder,
+ Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific
+ configuration including dynamic axes, input/output specifications, and optimization settings.
+
+ The export process prepares the models for subsequent compilation to QPC format, enabling
+ efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules
+ to optimize memory usage and performance.
+
+ Args:
+ export_dir (str, optional): Target directory for saving ONNX model files. If None,
+ uses the default export directory structure based on model name and configuration.
+ The directory will be created if it doesn't exist.
+ use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction
+                optimization for supported modules. This can optimize the graph and
+ improve compilation efficiency for models like the transformer.
+
+ Returns:
+ str: Absolute path to the export directory containing all ONNX model files.
+ Each module will have its own subdirectory with the exported ONNX file.
+
+ Raises:
+ RuntimeError: If ONNX export fails for any module
+ OSError: If there are issues creating the export directory or writing files
+ ValueError: If module configurations are invalid
+
+ Note:
+ - All models are exported in float32 precision for maximum compatibility
+ - Dynamic axes are configured to support variable batch sizes and sequence lengths
+ - The export process may take several minutes depending on model size
+ - Exported ONNX files can be large (several GB for complete pipeline)
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> export_path = pipeline.export(
+ ... export_dir="/path/to/export",
+ ... use_onnx_subfunctions=True
+ ... )
+ >>> print(f"Models exported to: {export_path}")
+ """
+ for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
+ # Get ONNX export configuration for this module
+ example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
+
+ export_params = {
+ "inputs": example_inputs,
+ "output_names": output_names,
+ "dynamic_axes": dynamic_axes,
+ "export_dir": export_dir,
+ }
+
+ if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE:
+ export_params["use_onnx_subfunctions"] = True
+
+ module_obj.export(**export_params)
+
+ @staticmethod
+ def get_default_config_path() -> str:
+ """
+        Get the path to the default Flux pipeline configuration file.
+
+        Returns:
+            str: Path (relative to the repository root) to the flux_config.json file containing
+                default pipeline configuration settings for compilation and device allocation.
+ """
+ return "QEfficient/diffusers/pipelines/configs/flux_config.json"
+
+ def compile(
+ self,
+ compile_config: Optional[str] = None,
+ parallel: bool = False,
+ height: int = 512,
+ width: int = 512,
+ use_onnx_subfunctions: bool = False,
+ ) -> None:
+ """
+ Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware.
+
+ Args:
+ compile_config (str, optional): Path to a JSON configuration file containing
+ compilation settings, device mappings, and optimization parameters. If None,
+ uses the default configuration from get_default_config_path().
+ parallel (bool, default=False): Compilation mode selection:
+ - True: Compile modules in parallel using ThreadPoolExecutor for faster processing
+ - False: Compile modules sequentially for lower resource usage
+ height (int, default=512): Target image height in pixels.
+ width (int, default=512): Target image width in pixels.
+ use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX
+ subfunctions before compilation.
+
+ Raises:
+ RuntimeError: If compilation fails for any module or if QAIC compiler is not available
+ FileNotFoundError: If ONNX models haven't been exported or config file is missing
+ ValueError: If configuration parameters are invalid
+ OSError: If there are issues with file I/O during compilation
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> # Sequential compilation with default config
+ >>> pipeline.compile(height=1024, width=1024)
+ >>>
+ >>> # Parallel compilation with custom config
+ >>> pipeline.compile(
+ ... compile_config="/path/to/custom_config.json",
+ ... parallel=True,
+ ... height=512,
+ ... width=512
+ ... )
+ """
+ # Ensure all modules are exported to ONNX before compilation
+ if any(
+ path is None
+ for path in [
+ self.text_encoder.onnx_path,
+ self.text_encoder_2.onnx_path,
+ self.transformer.onnx_path,
+ self.vae_decode.onnx_path,
+ ]
+ ):
+ self.export(use_onnx_subfunctions=use_onnx_subfunctions)
+
+ # Load compilation configuration
+ config_manager(self, config_source=compile_config)
+
+ # Calculate compressed latent dimension using utility function
+ cl, latent_height, latent_width = calculate_compressed_latent_dimension(
+ height, width, self.model.vae_scale_factor
+ )
+
+ # Prepare dynamic specialization updates based on image dimensions
+ specialization_updates = {
+ "transformer": {"cl": cl},
+ "vae_decoder": {
+ "latent_height": latent_height,
+ "latent_width": latent_width,
+ },
+ }
+
+ # Use generic utility functions for compilation
+ if parallel:
+ compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+ else:
+ compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the T5 text encoder for detailed semantic understanding.
+
+ T5 provides rich sequence embeddings that capture fine-grained text details,
+ complementing CLIP's global representation in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ max_sequence_length (int): Maximum token sequence length (default: 512)
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (prompt_embeds, inference_time)
+ - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096]
+ - inference_time (float): T5 encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts with padding and truncation
+ text_inputs = self.text_encoder_2.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.text_encoder_2.tokenizer.batch_decode(
+ untruncated_ids[:, self.text_encoder_2.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ f"The following part of your input was truncated because `max_sequence_length` is set to "
+ f"{self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder_2.qpc_session is None:
+ self.text_encoder_2.qpc_session = QAICInferenceSession(
+ str(self.text_encoder_2.qpc_path), device_ids=device_ids
+ )
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_2_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model
+ ).astype(np.int32),
+ }
+ self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run T5 encoder inference and measure time
+ start_t5_time = time.perf_counter()
+ prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+ end_t5_time = time.perf_counter()
+ text_encoder_2_perf = end_t5_time - start_t5_time
+
+ # Duplicate embeddings for multiple images per prompt
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_encoder_2_perf
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the CLIP text encoder for global semantic representation.
+
+ CLIP provides pooled embeddings that capture high-level semantic meaning,
+ working alongside T5's detailed sequence embeddings in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (pooled_prompt_embeds, inference_time)
+ - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768]
+ - inference_time (float): CLIP encoder inference time in seconds
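+
+ Example (illustrative sketch only; assumes the CLIP encoder QPC has already been compiled and `pipeline` is a loaded QEffFluxPipeline):
+ >>> pooled, clip_time = pipeline._get_clip_prompt_embeds("a red bicycle")
+ >>> pooled.shape # torch.Size([1, 768])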
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ f"The following part of your input was truncated because CLIP can only handle sequences up to "
+ f"{self.tokenizer_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder.qpc_session is None:
+ self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size
+ ).astype(np.float32),
+ "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.int32),
+ }
+ self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run CLIP encoder inference and measure time
+ start_text_encoder_time = time.perf_counter()
+ aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+ end_text_encoder_time = time.perf_counter()
+ text_encoder_perf = end_text_encoder_time - start_text_encoder_time
+ # Extract pooled output (used for conditioning in Flux)
+ prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
+
+ # Duplicate embeddings for multiple images per prompt
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, text_encoder_perf
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ """
+ Encode text prompts using Flux's dual text encoder architecture.
+
+ Flux employs both CLIP and T5 encoders for comprehensive text understanding:
+ - CLIP provides pooled embeddings for global semantic conditioning
+ - T5 provides detailed sequence embeddings for fine-grained text control
+
+ Args:
+ prompt (str or List[str]): Primary prompt(s) for both encoders
+ prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. If None, uses primary prompt
+ num_images_per_prompt (int): Number of images to generate per prompt
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings
+ max_sequence_length (int): Maximum sequence length for T5 tokenization
+
+ Returns:
+ tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times)
+ - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096]
+ - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768]
+ - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3]
+ - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time]
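+
+ Example (illustrative sketch only; `pipeline` denotes an already-compiled QEffFluxPipeline):
+ >>> embeds, pooled, text_ids, perf = pipeline.encode_prompt(
+ ... prompt="a watercolor fox", max_sequence_length=512
+ ... )
+ >>> embeds.shape, pooled.shape, text_ids.shape
+ (torch.Size([1, 512, 4096]), torch.Size([1, 768]), torch.Size([512, 3]))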
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ # Use primary prompt for both encoders if secondary not provided
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # Encode with CLIP (returns pooled embeddings)
+ pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device_ids=self.text_encoder.device_ids,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ # Encode with T5 (returns sequence embeddings)
+ prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device_ids=self.text_encoder_2.device_ids,
+ )
+
+ # Create text position IDs (required by Flux transformer)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]
+
+ def __call__(
+ self,
+ height: int = 512,
+ width: int = 512,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ use_onnx_subfunctions: bool = False,
+ ):
+ """
+ Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware.
+
+ This is the main entry point for text-to-image generation. It orchestrates the complete Flux
+ diffusion pipeline optimized for Qualcomm AI Cloud devices.
+
+ Args:
+ height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512.
+ width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512.
+ prompt (str or List[str]): Primary text prompt(s) describing the desired image(s).
+ Required unless `prompt_embeds` is provided.
+ prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. If None, uses `prompt`.
+ negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid.
+ Only used when `true_cfg_scale > 1.0`.
+ negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`.
+ true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable
+ negative prompting. Default: 1.0 (disabled).
+ num_inference_steps (int, optional): Number of denoising steps. Default: 28.
+ timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`.
+ guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5.
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1.
+ generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility.
+ latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated.
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096].
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768].
+ negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings.
+ negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings.
+ output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent".
+ callback_on_step_end (Callable, optional): Callback function executed after each denoising step.
+ callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"].
+ max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512.
+ custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings.
+ parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False.
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False.
+
+ Returns:
+ QEffPipelineOutput: A dataclass containing:
+ - images: Generated image(s) in the format specified by `output_type`
+ - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE)
+
+ Raises:
+ ValueError: If input validation fails or parameters are incompatible.
+ RuntimeError: If compilation fails or QAIC devices are unavailable.
+ FileNotFoundError: If custom config file is specified but not found.
+
+ Example:
+ >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> result = pipeline(
+ ... prompt="A serene mountain landscape at sunset",
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=28,
+ ... guidance_scale=7.5
+ ... )
+ >>> result.images[0].save("mountain_sunset.png")
+ >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s")
+ """
+ device = self.model._execution_device
+
+ if height is None or width is None:
+ logger.warning("Height or width is None. Setting default values of 512 for both dimensions.")
+ height = height or 512
+ width = width or 512
+
+ self.compile(
+ compile_config=custom_config_path,
+ parallel=parallel_compile,
+ height=height,
+ width=width,
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ )
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(self)
+
+ # Step 1: Validate all inputs
+ self.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # Step 2: Determine batch size from inputs
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Step 3: Encode prompts with both text encoders
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Encode negative prompts if using true classifier-free guidance
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ _negative_text_encoder_perf,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = self.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = self.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Calculate compressed latent dimension for transformer buffer allocation
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor)
+
+ # Initialize transformer inference session
+ if self.transformer.qpc_session is None:
+ self.transformer.qpc_session = QAICInferenceSession(
+ str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
+ )
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32),
+ }
+ self.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ self.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with self.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds)
+
+ # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.transformer_blocks)):
+ block = self.transformer.model.transformer_blocks[block_idx]
+ # Process through norm1 and norm1_context
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.single_transformer_blocks)):
+ block = self.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = self.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": prompt_embeds.detach().numpy(),
+ "pooled_projections": pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.perf_counter()
+ outputs = self.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.perf_counter()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Update latents using scheduler (x_t -> x_t-1)
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch (workaround for MPS backend bug)
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Execute callback if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images (unless output_type is "latent")
+ if output_type == "latent":
+ image = latents
+ else:
+ # Unpack and denormalize latents
+ latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor)
+ latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
+
+ # Initialize VAE decoder inference session
+ if self.vae_decode.qpc_session is None:
+ self.vae_decode.qpc_session = QAICInferenceSession(
+ str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)}
+ self.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.perf_counter()
+ image = self.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.perf_counter()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = self.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
new file mode 100644
index 000000000..6d9243fdc
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -0,0 +1,481 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import (
+ AttentionTransform,
+ CustomOpsTransform,
+ NormalizationTransform,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformerBlock,
+)
+from QEfficient.transformers.models.pytorch_transforms import (
+ T5ModelTransform,
+)
+from QEfficient.utils import constants
+
+
+class QEffTextEncoder(QEFFBaseModel):
+ """
+ Wrapper for text encoder models with ONNX export and QAIC compilation capabilities.
+
+ This class handles text encoder models (CLIP, T5) with specific transformations and
+ optimizations for efficient inference on Qualcomm AI hardware. It applies custom
+ PyTorch and ONNX transformations to prepare models for deployment.
+
+ Attributes:
+ model (nn.Module): The wrapped text encoder model (deep copy of original)
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying text encoder model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the text encoder wrapper.
+
+ Args:
+ model (nn.Module): The text encoder model to wrap (CLIP or T5)
+ """
+ super().__init__(model)
+ self.model = model
+
+ def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the text encoder.
+
+ Creates example inputs, dynamic axes specifications, and output names
+ tailored to the specific text encoder type (CLIP vs T5).
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # Create example input with max sequence length
+ example_inputs = {
+ "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64),
+ }
+
+ # Define which dimensions can vary at runtime
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+ # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output
+ if self.model.__class__.__name__ == "T5EncoderModel":
+ output_names = ["last_hidden_state"]
+ else:
+ output_names = ["last_hidden_state", "pooler_output"]
+ example_inputs["output_hidden_states"] = False
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the text encoder model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffUNet(QEFFBaseModel):
+ """
+ Wrapper for UNet models with ONNX export and QAIC compilation capabilities.
+
+ This class handles UNet models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. UNet is commonly used in
+ diffusion models for image generation tasks.
+
+ Attributes:
+ model (nn.Module): The wrapped UNet model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying UNet model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the UNet wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the UNet
+ """
+ super().__init__(model.unet)
+ self.model = model.unet
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the UNet model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffVAE(QEFFBaseModel):
+ """
+ Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation.
+
+ This class handles VAE models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion
+ pipelines for encoding images to latent space and decoding latents back to images.
+
+ Attributes:
+ model (nn.Module): The wrapped VAE model (deep copy of original)
+ type (str): VAE operation type ("encoder" or "decoder")
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying VAE model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module, type: str) -> None:
+ """
+ Initialize the VAE wrapper.
+
+ Args:
+ model (nn.Module): The VAE model to wrap
+ type (str): VAE operation type ("encoder" or "decoder")
+ """
+ super().__init__(model)
+ self.model = model
+
+ # To have different hashing for encoder/decoder
+ self.model.config["type"] = type
+
+ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the VAE decoder.
+
+ Args:
+ latent_height (int): Height of latent representation (default: 32)
+ latent_width (int): Width of latent representation (default: 32)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # VAE decoder takes latent representation as input
+ example_inputs = {
+ "latent_sample": torch.randn(bs, 16, latent_height, latent_width),
+ "return_dict": False,
+ }
+
+ output_names = ["sample"]
+
+ # All dimensions except channels can be dynamic
+ dynamic_axes = {
+ "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the VAE model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffFluxTransformerModel(QEFFBaseModel):
+ """
+ Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities.
+
+ This class handles Flux Transformer2D models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion
+ architecture instead of traditional UNet, with dual transformer blocks and adaptive layer
+ normalization (AdaLN) for conditioning.
+
+ Attributes:
+ model (nn.Module): The wrapped Flux transformer model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying Flux transformer model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the Flux transformer wrapper.
+
+ Args:
+ model (nn.Module): The Flux transformer model to wrap
+ """
+ super().__init__(model)
+
+ def get_onnx_params(
+ self,
+ batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+ seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH,
+ cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM,
+ ) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the Flux transformer.
+
+ Creates example inputs for all Flux-specific inputs including hidden states,
+ text embeddings, timestep conditioning, and AdaLN embeddings.
+
+ Args:
+ batch_size (int): Batch size for example inputs (default: ONNX_EXPORT_EXAMPLE_BATCH_SIZE)
+ seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH)
+ cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ example_inputs = {
+ # Latent representation of the image
+ "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32),
+ "encoder_hidden_states": torch.randn(
+ batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32
+ ),
+ "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32),
+ "timestep": torch.tensor([1.0], dtype=torch.float32),
+ "img_ids": torch.randn(cl, 3, dtype=torch.float32),
+ "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32),
+ # AdaLN embeddings for dual transformer blocks
+ # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_emb": torch.randn(
+ self.model.config["num_layers"],
+ constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # AdaLN embeddings for single transformer blocks
+ # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_single_emb": torch.randn(
+ self.model.config["num_single_layers"],
+ constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # Output AdaLN embedding
+ # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection
+ "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32),
+ }
+
+ output_names = ["output"]
+
+ # Define dynamic dimensions for runtime flexibility
+ dynamic_axes = {
+ "hidden_states": {0: "batch_size", 1: "cl"},
+ "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+ "pooled_projections": {0: "batch_size"},
+ "timestep": {0: "steps"},
+ "img_ids": {0: "cl"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """
+ Export the Flux transformer model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX functions
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+
+ if use_onnx_subfunctions:
+ export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}
+
+ # Sort _use_default_values in config to ensure consistent hash generation during export
+ self.model.config["_use_default_values"].sort()
+
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ offload_pt_weights=False, # As weights are needed with AdaLN changes
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..eb0a07569
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,220 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
+
+
+def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> Tuple[int, int, int]:
+ """
+ Calculate the compressed latent dimension.
+ Args:
+ height (int): Target image height in pixels
+ width (int): Target image width in pixels
+ vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux)
+
+ Returns:
+ Tuple[int, int, int]: Compressed latent dimension (cl) for transformer buffer allocation,
+ followed by the latent height and latent width
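+
+ Worked example (illustrative numbers): height=512, width=512, vae_scale_factor=8
+ gives a 64x64 latent grid, so cl = (64 * 64) // 4 = 1024.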
+ """
+ latent_height = height // vae_scale_factor
+ latent_width = width // vae_scale_factor
+ # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
+ cl = (latent_height * latent_width) // 4
+ return cl, latent_height, latent_width
+
+
+def config_manager(cls, config_source: Optional[str] = None):
+ """
+ JSON-based compilation configuration manager for diffusion pipelines.
+
+ Supports loading configuration from JSON files only. The model type is detected
+ automatically and model-specific requirements are handled. The loaded configuration
+ is stored on the pipeline as ``custom_config``.
+
+ Args:
+ config_source: Path to a JSON configuration file. If None, the default config is used.
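+
+ Illustrative config layout (keys mirror how the config is consumed by the compile
+ and device-assignment helpers below; module names and values are examples only):
+
+ {
+ "modules": {
+ "transformer": {
+ "specializations": {"batch_size": 1},
+ "compilation": {"num_cores": 16},
+ "execute": {"device_ids": [0]}
+ }
+ }
+ }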
+ """
+ if config_source is None:
+ config_source = cls.get_default_config_path()
+
+ if not isinstance(config_source, str):
+ raise ValueError("config_source must be a path to JSON configuration file")
+
+ # Direct use of load_json utility - no wrapper needed
+ if not os.path.exists(config_source):
+ raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+ cls.custom_config = load_json(config_source)
+
+
+def set_module_device_ids(cls):
+ """
+ Set device IDs for each module based on the custom configuration.
+
+ Iterates through all modules in the pipeline and assigns device IDs
+ from the configuration file to each module's device_ids attribute.
+ """
+ config_modules = cls.custom_config["modules"]
+ for module_name, module_obj in cls.modules.items():
+ module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
+
+
+def compile_modules_parallel(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
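+
+ Example (illustrative sketch; assumes `pipe` is a loaded diffusion pipeline exposing
+ `modules` and `custom_config`, and that `cl` is a valid transformer specialization key):
+ >>> compile_modules_parallel(pipe.modules, pipe.custom_config, {"transformer": {"cl": 1024}})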
+ """
+
+ def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
+ """Prepare specializations and compile a single module."""
+ specializations = config["modules"][module_name]["specializations"].copy()
+ compile_kwargs = config["modules"][module_name]["compilation"]
+
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+ # Execute compilations in parallel
+ with ThreadPoolExecutor(max_workers=len(modules)) as executor:
+ futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()}
+
+ with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar:
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error(f"Compilation failed for {futures[future]}: {e}")
+ raise
+ pbar.update(1)
+
+
+def compile_modules_sequential(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules sequentially.
+
+ This function provides a generic way to compile diffusion pipeline modules
+ sequentially, which is the default behavior for backward compatibility.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+
+ """
+ for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"):
+ module_config = config["modules"]
+ specializations = module_config[module_name]["specializations"].copy()
+ compile_kwargs = module_config[module_name]["compilation"]
+
+ # Apply dynamic specialization updates if provided
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ # Compile the module to QPC format
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+
+@dataclass(frozen=True)
+class ModulePerf:
+ """
+ Data class to store performance metrics for a pipeline module.
+
+ Attributes:
+ module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder')
+ perf: Performance metric in seconds. Can be a single float for modules that run once,
+ or a list of floats for modules that run multiple times (e.g., transformer steps)
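+
+ Example (illustrative values):
+ >>> ModulePerf(module_name="vae_decoder", perf=0.042)
+ >>> ModulePerf(module_name="transformer", perf=[0.031, 0.030, 0.029])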
+ """
+
+ module_name: str
+ perf: Union[float, List[float]]
+
+
+@dataclass(frozen=True)
+class QEffPipelineOutput:
+ """
+ Data class to store the output of a QEfficient diffusion pipeline.
+
+ Attributes:
+ pipeline_module: List of ModulePerf objects containing performance metrics for each module
+ images: Generated images as either a list of PIL Images or numpy array
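+
+ Example (illustrative; `pipe` denotes an already-compiled pipeline):
+ >>> out = pipe(prompt="a lighthouse at dawn")
+ >>> out.images[0] # first generated PIL image
+ >>> [m.module_name for m in out.pipeline_module]
+ ['text_encoder', 'text_encoder_2', 'transformer', 'vae_decoder']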
+ """
+
+ pipeline_module: List[ModulePerf]
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+ def __repr__(self):
+ output_str = "=" * 60 + "\n"
+ output_str += "QEfficient Diffusers Pipeline Inference Report\n"
+ output_str += "=" * 60 + "\n\n"
+
+ # Module-wise inference times
+ output_str += "Module-wise Inference Times:\n"
+ output_str += "-" * 60 + "\n"
+
+ # Calculate E2E time while iterating
+ e2e_time = 0
+ for module_perf in self.pipeline_module:
+ module_name = module_perf.module_name
+ inference_time = module_perf.perf
+
+ # Add to E2E time
+ e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time
+
+ # Format module name for display
+ display_name = module_name.replace("_", " ").title()
+
+ # Handle transformer specially as it has a list of times
+ if isinstance(inference_time, list) and len(inference_time) > 0:
+ total_time = sum(inference_time)
+ avg_time = total_time / len(inference_time)
+ output_str += f" {display_name:25s} {total_time:.4f} s\n"
+ output_str += f" - Total steps: {len(inference_time)}\n"
+ output_str += f" - Average per step: {avg_time:.4f} s\n"
+ output_str += f" - Min step time: {min(inference_time):.4f} s\n"
+ output_str += f" - Max step time: {max(inference_time):.4f} s\n"
+ else:
+ # Single inference time value
+ output_str += f" {display_name:25s} {inference_time:.4f} s\n"
+
+ output_str += "-" * 60 + "\n\n"
+
+ # Print E2E time after all modules
+ output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n"
+ output_str += "=" * 60 + "\n"
+
+ return output_str
+
+
+# List of module names that require special handling during export
+# when use_onnx_subfunctions is enabled
+ONNX_SUBFUNCTION_MODULE = ["transformer"]
diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
index 2547d9db3..73f9d00cf 100644
--- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py
+++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -20,7 +20,9 @@
from QEfficient.utils import load_hf_tokenizer
from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
from QEfficient.utils.generate_inputs import InputHandler
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
def convert_to_cloud_bertstyle(
diff --git a/QEfficient/finetune/dataset/alpaca_dataset.py b/QEfficient/finetune/dataset/alpaca_dataset.py
index ff44860eb..f25fa9767 100644
--- a/QEfficient/finetune/dataset/alpaca_dataset.py
+++ b/QEfficient/finetune/dataset/alpaca_dataset.py
@@ -11,7 +11,9 @@
import torch
from torch.utils.data import Dataset
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
PROMPT_DICT = {
"prompt_input": (
diff --git a/QEfficient/finetune/dataset/custom_dataset.py b/QEfficient/finetune/dataset/custom_dataset.py
index ef76e83ed..2a98a3ca4 100644
--- a/QEfficient/finetune/dataset/custom_dataset.py
+++ b/QEfficient/finetune/dataset/custom_dataset.py
@@ -9,7 +9,9 @@
import logging
from pathlib import Path
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
def load_module_from_py_file(py_file: str) -> object:
diff --git a/QEfficient/finetune/dataset/grammar_dataset.py b/QEfficient/finetune/dataset/grammar_dataset.py
index 8fb3eb152..4a2e4658d 100644
--- a/QEfficient/finetune/dataset/grammar_dataset.py
+++ b/QEfficient/finetune/dataset/grammar_dataset.py
@@ -10,7 +10,9 @@
from datasets import load_dataset
from torch.utils.data import Dataset
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
class grammar(Dataset):
diff --git a/QEfficient/finetune/utils/config_utils.py b/QEfficient/finetune/utils/config_utils.py
index 0c8b3d827..89d1753d9 100644
--- a/QEfficient/finetune/utils/config_utils.py
+++ b/QEfficient/finetune/utils/config_utils.py
@@ -20,7 +20,9 @@
from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
from QEfficient.finetune.utils.helper import Peft_Method
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
def update_config(config, **kwargs):
diff --git a/QEfficient/finetune/utils/dataset_utils.py b/QEfficient/finetune/utils/dataset_utils.py
index 01c1e32aa..9e15a7322 100644
--- a/QEfficient/finetune/utils/dataset_utils.py
+++ b/QEfficient/finetune/utils/dataset_utils.py
@@ -16,7 +16,9 @@
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
from QEfficient.finetune.utils.helper import get_world_size
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
def get_preprocessed_dataset(
diff --git a/QEfficient/finetune/utils/plot_metrics.py b/QEfficient/finetune/utils/plot_metrics.py
index 1e22bc6a8..4ea307dff 100644
--- a/QEfficient/finetune/utils/plot_metrics.py
+++ b/QEfficient/finetune/utils/plot_metrics.py
@@ -11,7 +11,9 @@
import matplotlib.pyplot as plt
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
def plot_metric(data, metric_name, x_label, y_label, title, colors):
diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py
index e07b5dd04..c33906b61 100644
--- a/QEfficient/generation/embedding_handler.py
+++ b/QEfficient/generation/embedding_handler.py
@@ -23,7 +23,9 @@
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class VisionHandler:
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 7da2300d6..e56a44530 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -19,9 +19,11 @@
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import padding_check_and_fix
from QEfficient.utils.constants import Constants
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
from QEfficient.utils.sampler_utils import validate_sampler_inputs
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+
@dataclass
class PerfMetrics:
@@ -329,6 +331,7 @@ def cloud_ai_100_exec_kv(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -356,6 +359,8 @@ def cloud_ai_100_exec_kv(
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
+ :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering
+ during decoding. Only works when `include_sampler`=True.
sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
@@ -394,6 +399,7 @@ def cloud_ai_100_exec_kv(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
@@ -442,6 +448,7 @@ def __init__(
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
activate: bool = True,
) -> None:
@@ -451,6 +458,7 @@ def __init__(
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
+ self.include_guided_decoding = include_guided_decoding
self.sampling_params = sampling_params
self._qpc_path = qpc_path # Store qpc_path for later use
@@ -461,7 +469,9 @@ def __init__(
# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
- session_inputs=set(self._session.input_names), include_sampler=include_sampler
+ session_inputs=set(self._session.input_names),
+ include_sampler=include_sampler,
+ include_guided_decoding=include_guided_decoding,
)
# Fetch the variables from the QPC
@@ -628,7 +638,7 @@ def prepare_decode_inputs(self):
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
- for op in Constants.SAMPLER_OPS:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
@@ -795,7 +805,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
- for op in Constants.SAMPLER_OPS:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -1067,6 +1077,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
@@ -1082,6 +1093,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
@@ -1302,4 +1314,5 @@ def generate(
generated_ids=self._qaic_model.generated_ids,
perf_metrics=perf_metrics,
)
+ logger.info("Text Generated finised")
return latency_stats
diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py
index b37fdc74a..7ac37bbda 100644
--- a/QEfficient/generation/vlm_generation.py
+++ b/QEfficient/generation/vlm_generation.py
@@ -36,7 +36,10 @@
write_io_files,
)
from QEfficient.utils import LRUCache
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.constants import Constants
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class VisionLanguageGeneration(QEffTextGenerationBase):
@@ -93,6 +96,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -114,6 +118,7 @@ def __init__(
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
+ include_guided_decoding: Enable guided decoding in on-device sampling
sampling_params: Sampling parameters for on-device sampling
"""
# Validate required parameters
@@ -137,6 +142,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
activate=False, # vision components need to be initialized first
)
@@ -313,6 +319,13 @@ def _execute_chunked_prefill(
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+ if self.include_sampler:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
+ if decode_batch_id is not None:
+ lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+ else:
+ lang_inputs[op] = self.sampling_params[op]
+
for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -338,6 +351,11 @@ def _execute_chunked_prefill(
chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
+ if self.include_sampler:
+ chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
+ chunk_inputs[op] = lang_inputs[op]
+
outputs = self._session.run(chunk_inputs)
if "image_idx_output" in outputs:
@@ -790,6 +808,7 @@ def generate_stream_tokens(
is_tlm=self.is_tlm,
include_sampler=self.include_sampler,
return_pdfs=self.return_pdfs,
+ include_guided_decoding=self.include_guided_decoding,
sampling_params=self.sampling_params,
)
diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py
index e69aebb2b..28c19601e 100644
--- a/QEfficient/peft/auto.py
+++ b/QEfficient/peft/auto.py
@@ -6,7 +6,6 @@
# ----------------------------------------------------------------------------
import hashlib
-import logging
import warnings
from typing import List, Optional, Union
@@ -32,8 +31,9 @@
from QEfficient.utils import constants
from QEfficient.utils._utils import get_padding_shape_from_config
from QEfficient.utils.hash_utils import to_hashable
+from QEfficient.utils.logging_utils import QEFFLogger
-logger = logging.getLogger(__name__)
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
@@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with the active adapter to ONNX format.
@@ -291,10 +291,10 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
example_inputs,
output_names,
dynamic_axes,
- export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
+ do_constant_folding=False, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ **kwargs,
)
def compile(
diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py
index 64fa3f61c..07f61c1e9 100644
--- a/QEfficient/peft/lora/auto.py
+++ b/QEfficient/peft/lora/auto.py
@@ -19,7 +19,9 @@
from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform
from QEfficient.utils import constants, get_padding_shape_from_config
from QEfficient.utils.hash_utils import to_hashable
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
@@ -327,7 +329,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.
@@ -387,7 +389,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ **kwargs,
)
def generate(
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 62cc71a4c..faadaba6b 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls):
"""
if torch.onnx.is_in_onnx_export():
if cls.SUBFUNC_ENABLED:
+ # TODO: This should not return 0; remove this if-condition, as it can hurt performance.
return 0
else:
return torch.iinfo(torch.int32).max
@@ -681,6 +682,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache
+ def write_only(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ k_out, v_out = key_states, value_states
+ else:
+ position_ids = cache_kwargs.get("position_ids")
+ is_sliding_layer = cache_kwargs.get("is_sliding")
+ _, _, ctx_len, _ = self.key_cache[layer_idx].shape
+            if is_sliding_layer:
+                kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
+            else:
+                kv_position_ids = position_ids
+
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(
+ self.value_cache[layer_idx], kv_position_ids, value_states
+ )
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+ return k_out, v_out
+
def update(
self,
key_states: torch.Tensor,
@@ -747,3 +779,92 @@ def update(
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
+
+ def full_cache_update_chunked(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ position_ids = cache_kwargs.get("position_ids")
+ batch_index = cache_kwargs.get("batch_index")
+ invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
+
+ # Scatter
+ if batch_index is not None:
+ if torch.onnx.is_in_onnx_export():
+ scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+ self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+ )
+ self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+ )
+ else:
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ # Gather
+ ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
+ ctx_indices = torch.arange(ctx_len)[None, None, ...]
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+ invalid_mask = ctx_indices > gather_limit
+ ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+ if batch_index is not None:
+ k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+ v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+ else:
+ k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+ v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+ return k_out, v_out
+
+ def sliding_window_update_chunked(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ position_ids = cache_kwargs.get("position_ids")
+ batch_index = cache_kwargs.get("batch_index")
+ invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
+
+ if batch_index is not None:
+ if torch.onnx.is_in_onnx_export():
+ scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+ self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+ )
+ self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+ )
+ else:
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+ sliding_window_len = cache_kwargs.get("sliding_window")
+
+ # Gather
+ ctx_len = position_ids.shape[1] + sliding_window_len
+ ctx_indices = torch.arange(ctx_len)[None, None, ...]
+ first_pos_idx = position_ids[0][0]
+ add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0)
+ ctx_indices += add_idx
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+ invalid_mask = ctx_indices > gather_limit
+ ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+ if batch_index is not None:
+ k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+ v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+ else:
+ k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+ v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+ return k_out, v_out
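The sliding-window variant above reads back a window of ctx_len = chunk_len + sliding_window entries and shifts that window once the chunk starts past the sliding window. A standalone sketch of just the index arithmetic, with toy sizes (chunk of 4 positions, window of 8) that are illustrative only and without the custom scatter/gather ops:

    import torch

    sliding_window = 8
    position_ids = torch.arange(16, 20).reshape(1, -1)        # current chunk covers positions 16..19

    ctx_len = position_ids.shape[1] + sliding_window           # chunk length + window
    ctx_indices = torch.arange(ctx_len)[None, None, ...]
    first_pos = position_ids[0][0]
    add_idx = torch.where(first_pos >= sliding_window, first_pos - sliding_window, 0)
    ctx_indices = ctx_indices + add_idx                        # slide the read window once past the first `sliding_window` tokens

    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit                  # positions beyond the chunk's last token get masked
    print(ctx_indices)
    print(invalid_mask)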
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index 5337b44f5..47059d8dc 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -188,6 +188,9 @@
# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
+# This is for supporting model architectures that have specialized modeling classes written for prefill-only execution
+SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}
+
# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index 84552aff4..48eaf6f54 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
+import math
+import os
from typing import Callable, Optional, Union
import torch
@@ -30,8 +32,10 @@
from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-from QEfficient.utils import constants
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEffGptOssExperts(GptOssExperts):
@@ -42,8 +46,8 @@ def __qeff_init__(self):
self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim))
-class QEffGptOssMLP(GptOssMLP):
- def alt_forward(self, hidden: torch.Tensor):
+class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP):
+ def forward(self, hidden: torch.Tensor):
B, S, H = hidden.shape
T = B * S
hidden = hidden.view(T, H)
@@ -78,7 +82,62 @@ def alt_forward(self, hidden: torch.Tensor):
up = (hidden @ W_u) + b_u # [T, I]
# Apply GptOss activation with clamping
- gate = gate.clamp(min=None, max=self.experts.limit)
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ intermediate = (up + 1) * glu # [T, I]
+
+ # Down projection
+ down_out = (intermediate @ W_d) + b_d # [T, H]
+
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+
+class QEffPrefillOnlyGptOssMLP(GptOssMLP):
+ def forward(self, hidden: torch.Tensor):
+ if os.environ.get("NUM_FFN_BLOCKS", None) is not None:
+ return self.blocked_ffn_forward(hidden)
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+        # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+
+        # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ # Gate and Up projections
+ gate = (hidden @ W_g) + b_g # [T, I]
+ up = (hidden @ W_u) + b_u # [T, I]
+
+ # Apply GptOss activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
# GLU activation
@@ -88,6 +147,165 @@ def alt_forward(self, hidden: torch.Tensor):
# Down projection
down_out = (intermediate @ W_d) + b_d # [T, H]
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+ def blocked_ffn_forward(self, hidden: torch.Tensor):
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+        # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+        # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+                # Compute this block's size (the last block absorbs any remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = T - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ tgb = hidden[qi : qi + real_q_len, :]
+                # Gate and Up projections
+                gate = (tgb @ W_g) + b_g  # [block, I]
+                up = (tgb @ W_u) + b_u  # [block, I]
+
+ # Apply GptOss activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ intermediate = (up + 1) * glu # [T, I]
+
+ # Down projection
+ down_out_block = (intermediate @ W_d) + b_d # [T, H]
+
+ outs.append(down_out_block)
+
+ down_out = torch.cat(outs, dim=0)
+
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+ def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor):
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+        # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+        # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+                # Compute this block's size (the last block absorbs any remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = T - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ tgb = hidden[qi : qi + real_q_len, :]
+ # Gate and Up projections
+
+ wg_col_shape = W_g.shape[1]
+ wg_num_blocks = math.ceil(wg_col_shape / 128)
+ last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128
+
+ intermediates = []
+ for i in range(wg_num_blocks):
+ if i == wg_num_blocks - 1:
+ cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:]
+ cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:]
+ else:
+ cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128]
+ cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128]
+
+ cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit)
+ cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha)
+ cur_intermediate = (cur_up + 1) * cur_glu
+ intermediates.append(cur_intermediate)
+
+ intermediate = torch.cat(intermediates, dim=-1)
+
+ downs = []
+ for i in range(wg_num_blocks):
+ if i == wg_num_blocks - 1:
+ downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:])
+ else:
+ downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128])
+
+ down_out_block = torch.cat(downs, dim=1)
+ outs.append(down_out_block)
+
+ down_out = torch.cat(outs, dim=0)
+
# Apply routing weights and accumulate
masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
expert_out += masked_down
@@ -95,6 +313,8 @@ def alt_forward(self, hidden: torch.Tensor):
# original shape [B, S, H]
return expert_out.view(B, S, H), router_logits
+
+class QEffGptOssMLP(GptOssMLP):
# ------------------- Gather based, weights as activation approach ---------------
def forward_weights_as_activation(self, hidden_states):
bs, seq_len, _ = hidden_states.shape
@@ -142,7 +362,6 @@ def forward_weights_as_activation(self, hidden_states):
    # ------------------- Gather-based, weights-as-activation approach, with separate Gate and Up projections ---------------
def forward(self, hidden_states):
- # print("Seperate Split, Up, Gate Projections")
bs, seq_len, _ = hidden_states.shape
hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size)
@@ -172,7 +391,7 @@ def forward(self, hidden_states):
up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1)
# Apply activation with clamping
- gate = gate.clamp(min=None, max=self.experts.limit)
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
# GLU activation
@@ -404,6 +623,283 @@ def eager_attention_forward(
return attn_output, attn_weights
+def eager_attention_forward_blocked(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ BS, NH, CL, DH = query.shape
+ target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (CL // target_blocks))
+ block_count = 0
+
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+        # Compute this block's size (the last block absorbs any remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ q_block = query[:, :, qi : qi + real_q_len, :]
+ scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+ curr_attn_weights = torch.where(
+ attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+ )
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+ curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+ )
+ combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+ curr_attn_weights = curr_attn_weights[..., :-1]
+ out_block = torch.matmul(curr_attn_weights, value_states)
+ outs.append(out_block)
+ output = torch.cat(outs, dim=2)
+
+ output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+ return output, output
+
+
+def opt_eager_attention_forward_blocked(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ BS, NH, CL, DH = query.shape
+ target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (CL // target_blocks))
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+        # Compute this block's size (the last block absorbs any remainder)
+
+ if block_idx == target_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ if block_idx == 0:
+ kv_start_idx = 0
+ else:
+ kv_start_idx = qi - 128
+
+ q_block = query[:, :, qi : qi + real_q_len, :]
+ if kwargs.get("sliding_window"):
+ k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :]
+ v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :]
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len]
+ else:
+ k_block = key_states
+ v_block = value_states
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+
+ scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
+ curr_attn_weights = torch.where(
+ attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+ )
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+ curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+ )
+ combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+ curr_attn_weights = curr_attn_weights[..., :-1]
+ out_block = torch.matmul(curr_attn_weights, v_block)
+ outs.append(out_block)
+ output = torch.cat(outs, dim=2)
+
+ output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+ return output, output
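Both eager_attention_forward_blocked and the opt_ variant above split the query length into NUM_Q_BLOCKS contiguous blocks, with the last block absorbing any remainder. A small self-contained sketch of that block-position computation, using illustrative sizes in place of the real environment variable:

    CL = 522                 # query/context length (illustrative)
    target_blocks = 4        # stands in for int(os.environ["NUM_Q_BLOCKS"])

    block_positions = [j * (CL // target_blocks) for j in range(target_blocks)]
    for block_idx, qi in enumerate(block_positions):
        # The last block absorbs the remainder when CL is not an exact multiple of target_blocks.
        real_q_len = (CL - qi) if block_idx == target_blocks - 1 else block_positions[block_idx + 1] - qi
        print(block_idx, qi, real_q_len)   # starts 0/130/260/390, lengths 130/130/130/132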
+
+
+class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __qeff_init__(self):
+ self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ sliding_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "config": self.config,
+ "is_sliding": self.sliding_window is not None,
+ "sliding_window": self.sliding_window,
+ }
+ if self.sliding_window is not None:
+ key_states, value_states = past_key_value.sliding_window_update_chunked(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ else:
+ key_states, value_states = past_key_value.full_cache_update_chunked(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ if self.sliding_window is not None:
+ attention_mask = sliding_mask
+ # positive_pos_ids = torch.where(position_ids<0, 0, position_ids)
+ ctx_len = position_ids.shape[1] + self.sliding_window
+ ctx_indices = torch.arange(ctx_len)
+ first_pos_idx = position_ids[0][0]
+ add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0)
+ # start_idx = torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0)
+ # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window)
+ ctx_indices += add_idx
+ attention_mask = attention_mask[:, :, :, ctx_indices]
+ else:
+ attention_mask = attention_mask
+
+ attention_interface: Callable = eager_attention_forward
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
+class QEffPrefillOnlyGptOssAttention(GptOssAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __qeff_init__(self):
+ self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ sliding_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "config": self.config,
+ "is_sliding": self.sliding_window is not None,
+ "sliding_window": past_key_value.sliding_window_len,
+ }
+ if self.sliding_window is not None:
+ sliding_window_len = past_key_value.sliding_window_len
+ short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2])
+ read_idx = short_read_idx + torch.where(
+ position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0
+ )
+                # Export trick: shift read_idx by (position_ids.max() - sliding_window + 1) once the window is exceeded, keeping shapes static
+ k_cache = key_states[:, :, read_idx, :]
+ v_cache = value_states[:, :, read_idx, :]
+ else:
+ k_cache, v_cache = key_states, value_states
+ _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs)
+
+ if self.sliding_window is not None:
+ attention_mask = sliding_mask
+ else:
+ attention_mask = attention_mask
+
+ if os.environ.get("ENABLE_OPT_SWA", "0") == "1":
+ attention_interface: Callable = opt_eager_attention_forward_blocked
+ else:
+ attention_interface: Callable = eager_attention_forward_blocked
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
class QEffGptOssAttention(GptOssAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -429,8 +925,9 @@ def forward(
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024)
+        if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -511,7 +1008,6 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
- # alth, _ = self.mlp.alt_forward(hidden_states)
hidden_states = hidden_states.reshape(residual.shape)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
@@ -525,6 +1021,97 @@ def forward(
return outputs
+class QEffPrefillOnlyGptOssModel(GptOssModel):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
+ causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len)
+ sliding_mask = _create_causal_mask(
+ position_ids=position_ids,
+ target_length=past_key_values.max_cache_len,
+ sliding_window=self.config.sliding_window,
+ )
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ batch_index=batch_index,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ sliding_mask=sliding_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if return_legacy_cache:
+ past_key_values = past_key_values.to_legacy_cache()
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
class QEffGptOssModel(GptOssModel):
def forward(
self,
@@ -578,7 +1165,6 @@ def forward(
)
hidden_states = inputs_embeds
- # position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -708,15 +1294,15 @@ def forward(
router_logits=outputs.router_logits,
)
- def get_pkv_dynamic_axes(
- self,
- ):
+ def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False):
pkv_dynamic_axes = []
for layer_type in self.config.layer_types:
- if layer_type == "sliding_attention":
- pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"})
- elif layer_type == "full_attention":
- pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"})
+ if layer_type == "sliding_attention" and not retain_full_kv:
+ pkv_dynamic_axes.append(
+ {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
+ )
+ else:
+ pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"})
return pkv_dynamic_axes
def get_specializations(
@@ -724,10 +1310,14 @@ def get_specializations(
batch_size: int,
prefill_seq_len: int,
ctx_len: int,
+ **kwargs,
):
batch_size = batch_size if batch_size else 1
- prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN
- ctx_len = ctx_len if ctx_len else constants.CTX_LEN
+ if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len:
+ ctx_len = prefill_seq_len
+ logger.warning(
+                f"Overriding ctx_len to {prefill_seq_len}; a ctx_len different from prefill_seq_len is not yet supported for prefill_only models"
+ )
specializations = [
{
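For reference, a rough sketch of what the reworked get_pkv_dynamic_axes would produce for a config that alternates sliding and full attention; the layer_types pattern and flags below are illustrative, not taken from this patch:

    layer_types = ["sliding_attention", "full_attention"]   # illustrative per-layer pattern
    retain_full_kv, continuous_batching = False, True

    pkv_dynamic_axes = []
    for layer_type in layer_types:
        batch_axis = "full_batch_size" if continuous_batching else "batch_size"
        if layer_type == "sliding_attention" and not retain_full_kv:
            pkv_dynamic_axes.append({0: batch_axis, 2: "sliding_window"})
        else:
            pkv_dynamic_axes.append({0: batch_axis, 2: "ctx_len"})

    # [{0: 'full_batch_size', 2: 'sliding_window'}, {0: 'full_batch_size', 2: 'ctx_len'}]
    print(pkv_dynamic_axes)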
diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index 85c331aa8..0fc993406 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -13,7 +13,9 @@
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEffInternEncoderWrapper(nn.Module):
diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py
index d5f5ee920..d3516e700 100644
--- a/QEfficient/transformers/models/llava/modeling_llava.py
+++ b/QEfficient/transformers/models/llava/modeling_llava.py
@@ -15,7 +15,9 @@
)
from QEfficient.utils._utils import IOInfo
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
BS = 1
FBS = 4
diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
index 878d04a45..3d6e9797a 100755
--- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py
+++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
@@ -18,7 +18,9 @@
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
FBS = constants.ONNX_EXPORT_EXAMPLE_FBS
diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py
index 89e19c65b..d5d510f99 100644
--- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py
+++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py
@@ -21,7 +21,9 @@
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
def custom_cumsum(tensor):
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 8edc1f3f0..750313bb7 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -5,10 +5,11 @@
#
# ----------------------------------------------------------------------------
+import os
import warnings
from pathlib import Path
from time import perf_counter
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
import numpy as np
import torch
@@ -37,13 +38,20 @@
get_compilation_dims,
)
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
-from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
+from QEfficient.transformers.modeling_utils import (
+ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
+ SPECIALIZED_PREFILL_ONLY_MODEL_ARCH,
+)
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
CustomOpsTransform,
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
+ PrefillOnlyChunkedTransform,
+ PrefillOnlyTransform,
+ RevertPrefillKeepAttentionTransform,
+ RevertPrefillOnlyTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
@@ -61,7 +69,10 @@
get_padding_shape_from_config,
)
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEFFTransformersBase(QEFFBaseModel):
@@ -124,21 +135,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying HuggingFace model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
class MultimodalUtilityMixin:
"""
@@ -316,7 +312,7 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -353,7 +349,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -603,15 +599,7 @@ def __init__(self, model: nn.modules, **kwargs):
self.model = model.get_qeff_vision_encoder()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
- def export(
- self,
- inputs,
- output_names,
- dynamic_axes,
- export_dir=None,
- offload_pt_weights=True,
- use_onnx_subfunctions: bool = False,
- ):
+ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
"""
Exports the vision encoder component to ONNX format.
@@ -641,7 +629,7 @@ def export(
dynamic_axes,
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -701,21 +689,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying vision encoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -749,7 +722,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
- def __init__(self, model, qaic_config, **kwargs):
+ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.
@@ -763,7 +736,7 @@ def __init__(self, model, qaic_config, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
- super().__init__(model, **kwargs)
+ super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model = model.get_qeff_language_decoder()
self.model.qaic_config = qaic_config
self.hash_params["qeff_auto_class"] = self.__class__.__name__
@@ -771,15 +744,7 @@ def __init__(self, model, qaic_config, **kwargs):
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
- def export(
- self,
- inputs,
- output_names,
- dynamic_axes,
- export_dir=None,
- offload_pt_weights=True,
- use_onnx_subfunctions: bool = False,
- ):
+ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
"""
Exports the language decoder component to ONNX format.
@@ -809,7 +774,7 @@ def export(
dynamic_axes,
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -869,21 +834,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying language decoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -924,16 +874,16 @@ def __init__(
----------
model : nn.Module
The full HuggingFace multimodal model.
+ qaic_config : dict, optional
+ A dictionary for QAIC-specific configurations.
**kwargs :
- Additional keyword arguments. `full_batch_size` is not supported here.
-
- Raises
- ------
- NotImplementedError
- If `full_batch_size` is provided.
+ Additional keyword arguments.
"""
if kwargs.pop("full_batch_size", None):
- raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+ continuous_batching = True
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
self.model = model
self.config = model.config
@@ -945,21 +895,11 @@ def __init__(
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None
-
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
+ # ---Sampling---
+ # Note: SamplerTransform should be applied after all other transforms
+ # are done. The role of the sampler is to just add nodes at the output of the
+ # previous transform function.
+ self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
@@ -1070,6 +1010,19 @@ def export(
kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
)
output_names = self.model.get_output_names(kv_offload=True)
+ if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
+ "include_sampler", False
+ ):
+ logits_index = output_names["lang"].index("logits")
+ output_names["lang"][logits_index] = "next_tokens"
+ inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
+ example_inputs=inputs["lang"],
+ output_names=output_names["lang"],
+ dynamic_axes=dynamic_axes["lang"],
+ continuous_batching=self.continuous_batching,
+ vocab_size=self.model.language_model.config.vocab_size,
+ qaic_config=self.lang_model.model.qaic_config,
+ )
self.vision_model.export(
inputs["vision"],
@@ -1302,6 +1255,7 @@ def generate(
generation_len: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
+ **kwargs,
) -> Union[torch.Tensor, np.ndarray]:
"""
Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1362,6 +1316,7 @@ def generate(
comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
image_height=image_height,
image_width=image_width,
+ **kwargs,
)
# Call generate method
@@ -1644,10 +1599,15 @@ def __init__(
Raises
------
NotImplementedError
- If `full_batch_size` is provided.
+ If `full_batch_size` is provided or `include_sampler` is True.
"""
if kwargs.pop("full_batch_size", None):
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+ if qaic_config is not None and qaic_config.pop("include_sampler", False):
+ raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)
self.model.qaic_config = qaic_config
@@ -2131,21 +2091,6 @@ def cloud_ai_100_generate(
),
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -2279,6 +2224,8 @@ def from_pretrained(
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
+ qaic_config : dict, optional
+ A dictionary for QAIC-specific configurations.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.
@@ -2306,7 +2253,6 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
@@ -2359,11 +2305,30 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+ def prefill(
+ self,
+ enable: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = False,
+ ):
+ if enable:
+ if enable_chunking:
+ self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
+ else:
+ self.model, tf = PrefillOnlyTransform.apply(self.model)
+
+ else:
+ if retain_full_kv:
+ self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model)
+ else:
+ self.model, tf = RevertPrefillOnlyTransform.apply(self.model)
+
def __init__(
self,
model: nn.Module,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
+ max_seq_len_cached: Optional[int] = None,
**kwargs,
):
"""
@@ -2383,6 +2348,8 @@ def __init__(
- **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
For Speculative Decoding Target Language Models, this is always True.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+ - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+ during decoding. Only works when include_sampler=True.
- **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation.
**kwargs :
Additional keyword arguments passed to the base class constructor.
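As a rough illustration of how these keys fit together, a qaic_config enabling on-device sampling with guided decoding might look like the sketch below; the checkpoint name is a placeholder and no keys beyond those documented above are implied:

    from QEfficient import QEFFAutoModelForCausalLM

    # include_guided_decoding is only honoured when include_sampler=True.
    qaic_config = {
        "include_sampler": True,
        "return_pdfs": False,
        "max_top_k_ids": 512,
        "include_guided_decoding": True,
    }

    model = QEFFAutoModelForCausalLM.from_pretrained(
        "gpt2",                      # placeholder checkpoint
        continuous_batching=False,
        qaic_config=qaic_config,
    )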
@@ -2411,6 +2378,7 @@ def __init__(
# Set use_cache=True to get KV values as output during ONNX export
model.config.use_cache = True
+ setattr(model.config, "max_seq_len_cached", max_seq_len_cached)
super().__init__(model, qaic_config=qaic_config, **kwargs)
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
@@ -2423,6 +2391,7 @@ def __init__(
if qaic_config:
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
+ self.hash_params["max_seq_len_cached"] = max_seq_len_cached
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
@@ -2437,21 +2406,6 @@ def __init__(
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying Causal Language Model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
@@ -2462,6 +2416,7 @@ def from_pretrained(
pretrained_model_name_or_path,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
+ max_seq_len_cached: Optional[int] = None,
*args,
**kwargs,
):
@@ -2491,6 +2446,8 @@ def from_pretrained(
and ``return_pdfs=False`` for regular model.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
The values provided in ``top_ks`` tensor must be less than this maximum limit.
+ - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+ during decoding. Only works when include_sampler=True.
*args :
Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
@@ -2525,7 +2482,6 @@ def from_pretrained(
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
        # This is to support models that should be classified into a different auto class but are loaded by transformers via this class
-
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
model,
@@ -2540,6 +2496,7 @@ def from_pretrained(
continuous_batching=continuous_batching,
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
+ max_seq_len_cached=max_seq_len_cached,
**kwargs,
)
@@ -2555,7 +2512,56 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str:
+ def get_seq_len_and_handle_specialized_prefill_model(
+ self, prefill_seq_len: Optional[int] = None, enable_chunking=False
+ ) -> int:
+ self.hash_params["prefill_only"] = True
+ if enable_chunking:
+ self.hash_params["chunking"] = True
+ return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+
+ num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
+ if num_q_blocks is None:
+ block_size = 128
+ if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
+ raise ValueError(
+ f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
+                    "Alternatively, set the `NUM_Q_BLOCKS` environment variable. "
+                    f"Received: prefill_seq_len={prefill_seq_len}."
+ )
+
+ num_q_blocks = prefill_seq_len // block_size
+ logger.warning(
+                f"Setting NUM_Q_BLOCKS={num_q_blocks} for attention Q-blocking in the prefill_only model; set the NUM_Q_BLOCKS environment variable to override"
+ )
+ os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
+ num_q_blocks = int(num_q_blocks)
+
+ num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None)
+ num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks
+ min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks
+ if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0:
+ raise ValueError(
+                f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}; tried to set seq_len={min_seq_len} for export, "
+                "but seq_len must be divisible by both NUM_FFN_BLOCKS and NUM_Q_BLOCKS. Try changing the values."
+ )
+
+ self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks
+ self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks
+ self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0")
+ return (
+ min_seq_len
+ if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ )
+
+ def export(
+ self,
+ export_dir: Optional[str] = None,
+ prefill_only: Optional[bool] = False,
+ prefill_seq_len: Optional[int] = None,
+ **kwargs,
+ ) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -2581,6 +2587,33 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
kv_cache_shape = get_padding_shape_from_config(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
+ enable_chunking = kwargs.get("enable_chunking", False)
+ if prefill_only:
+ if not enable_chunking and self.continuous_batching:
+ raise NotImplementedError(
+                    "Prefill-only export with continuous batching (prefix caching) requires enable_chunking=True; this combination is not supported yet."
+ )
+ self.prefill(enable=True, enable_chunking=enable_chunking)
+ self.hash_params.pop("retain_full_kv", None)
+ seq_len = (
+ self.get_seq_len_and_handle_specialized_prefill_model(
+ prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
+ )
+ if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
+ else seq_len
+ )
+ kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
+ else:
+ self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+ self.hash_params.pop("prefill_only", None)
+ self.hash_params.pop("NUM_Q_BLOCKS", None)
+ self.hash_params.pop("NUM_FFN_BLOCKS", None)
+ self.hash_params.pop("ENABLE_OPT_SWA", None)
+ self.hash_params.pop("chunking", None)
+ if kwargs.get("retain_full_kv", False):
+ kv_cache_shape[2] = seq_len + self.model.config.sliding_window
+ self.hash_params["retain_full_kv"] = True
+
example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
@@ -2629,7 +2662,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
else:
# HACK: create common function for this including above if condition code
pkv_dynamic_axes = (
- self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes
+ self.model.get_pkv_dynamic_axes(
+ retain_full_kv=kwargs.get("retain_full_kv", False)
+ or (prefill_only and kwargs.get("enable_chunking", False)),
+ continuous_batching=self.continuous_batching,
+ )
+ if hasattr(self.model, "get_pkv_dynamic_axes")
+ else pkv_dynamic_axes
)
pkv_dynamic_axes = (
[pkv_dynamic_axes] * self.model.config.num_hidden_layers
@@ -2638,7 +2677,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
)
for i in range(self.num_layers):
- pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size"
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
@@ -2654,100 +2692,24 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
- example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+ example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
example_inputs=example_inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
+ continuous_batching=self.continuous_batching,
+ vocab_size=self.model.config.vocab_size,
+ qaic_config=self.model.qaic_config,
)
-
return self._export(
example_inputs,
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
offload_pt_weights=kwargs.get("offload_pt_weights", True),
+ prefill_only=prefill_only,
)
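A hedged usage sketch of the prefill-only export path added above; the checkpoint name and sequence length are placeholders, and the NUM_Q_BLOCKS / NUM_FFN_BLOCKS environment variables only come into play for the specialized prefill-only architectures (currently gpt_oss):

    import os
    from QEfficient import QEFFAutoModelForCausalLM

    # Optional override; otherwise NUM_Q_BLOCKS is derived from prefill_seq_len // 128 for specialized archs.
    os.environ.setdefault("NUM_Q_BLOCKS", "8")

    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")   # placeholder checkpoint
    onnx_path = model.export(
        export_dir="onnx_exports",
        prefill_only=True,          # applies PrefillOnlyTransform (or the chunked variant with enable_chunking=True)
        prefill_seq_len=1024,
    )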
- def get_sampling_inputs_and_outputs(
- self,
- example_inputs: Dict[str, torch.Tensor],
- output_names: List[str],
- dynamic_axes: Dict[str, Dict[int, str]],
- ):
- """
- Updates the example inputs, output names, and dynamic axes to include
- parameters relevant for on-device sampling during ONNX export.
-
- Parameters
- ----------
- example_inputs : Dict[str, torch.Tensor]
- Current dictionary of example inputs.
- output_names : List[str]
- Current list of output names.
- dynamic_axes : Dict[str, Dict[int, str]]
- Current dictionary of dynamic axes configurations.
-
- Returns
- -------
- Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
- Updated example inputs, output names, and dynamic axes including
- sampling-related parameters.
- """
- bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
- fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
-
- example_inputs["last_accepted_output_tokens"] = torch.zeros(
- (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
- )
- dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
-
- example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
- (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
- )
- dynamic_axes["past_repetition_penalty_buffer"] = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- }
- output_names.append("past_repetition_penalty_buffer_RetainedState")
-
- example_inputs["repetition_penalties"] = (
- torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
- )
- dynamic_axes["repetition_penalties"] = {0: "batch_size"}
-
- example_inputs["past_presence_penalty_buffer"] = torch.zeros(
- (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
- )
- dynamic_axes["past_presence_penalty_buffer"] = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- }
- output_names.append("past_presence_penalty_buffer_RetainedState")
-
- example_inputs["presence_penalties"] = (
- torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
- )
- dynamic_axes["presence_penalties"] = {0: "batch_size"}
-
- example_inputs["temperatures"] = (
- torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
- )
- dynamic_axes["temperatures"] = {0: "batch_size"}
-
- max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
- example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
- dynamic_axes["top_ks"] = {0: "batch_size"}
-
- example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
- dynamic_axes["top_ps"] = {0: "batch_size"}
-
- example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
- dynamic_axes["min_ps"] = {0: "batch_size"}
-
- example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
- dynamic_axes["random_numbers"] = {0: "batch_size"}
-
- return example_inputs, output_names, dynamic_axes
-
def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
@@ -2756,6 +2718,7 @@ def build_prefill_specialization(
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
+ **kwargs,
):
"""
Builds a dictionary representing a compilation specialization for the prefill phase.
@@ -2778,11 +2741,17 @@ def build_prefill_specialization(
Dict[str, Union[int, str]]
A dictionary defining the prefill specialization.
"""
+ if prefill_seq_len == 1 and self.continuous_batching:
+ exec_batch_size = full_batch_size
+ else:
+ exec_batch_size = 1 if self.continuous_batching else batch_size
+
if hasattr(self.model, "get_specializations"):
spec = self.model.get_specializations(
- batch_size=1 if self.continuous_batching else batch_size,
+ batch_size=exec_batch_size,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
+ **kwargs,
)[0]
else:
spec = {
@@ -2810,6 +2779,7 @@ def build_decode_specialization(
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
+ **kwargs,
):
"""
Builds a dictionary representing a compilation specialization for the decode phase.
@@ -2880,6 +2850,9 @@ def compile(
num_speculative_tokens: Optional[int] = None,
prefill_only: Optional[bool] = None,
use_onnx_subfunctions: bool = False,
+ offload_pt_weights: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -2960,6 +2933,20 @@ def compile(
If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models.
"""
+ if prefill_only is None or not prefill_only:
+ if self.continuous_batching and full_batch_size is None:
+ raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
+ if kv_cache_batch_size and not full_batch_size:
+ raise ValueError(
+ "KV caching requires continuous batching. Please set `full_batch_size` and "
+ "enable `continuous_batching=True` in `from_pretrained`."
+ )
+ else:
+ if self.continuous_batching:
+ if not isinstance(kv_cache_batch_size, int):
+ raise ValueError(
+                        "Please pass a valid integer for kv_cache_batch_size, since continuous_batching is enabled for a prefill-only model."
+ )
# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
@@ -2997,15 +2984,6 @@ def compile(
if self.is_tlm:
num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)
- if self.continuous_batching and full_batch_size is None:
- raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
-
- if kv_cache_batch_size and not full_batch_size:
- raise ValueError(
- "KV caching requires continuous batching. Please set `full_batch_size` and "
- "enable `continuous_batching=True` in `from_pretrained`."
- )
-
if (
self.model.qaic_config is not None
and self.model.qaic_config.get("include_sampler", False)
@@ -3014,15 +2992,23 @@ def compile(
):
raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")
+ if kv_cache_batch_size and prefill_only is not None and prefill_only:
+ logger.warning(
+                "kv_cache_batch_size is ignored when prefill_only is set to True, unless this is a GPT-OSS model"
+ )
+
# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
# --- Specializations ---
specializations = []
if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            # TODO: the decode-only case is currently handled inside this prefill path, which is misleading
if self.comp_ctx_lengths_prefill is not None:
# Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
for i in range(0, len(self.comp_ctx_lengths_prefill)):
+ if prefill_only or enable_chunking:
+                        raise NotImplementedError("Neither prefill_only nor enable_chunking is supported with CCL")
specializations.append(
self.build_prefill_specialization(
prefill_seq_len=prefill_seq_len,
@@ -3042,6 +3028,8 @@ def compile(
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
+ prefill_only=prefill_only,
+ enable_chunking=enable_chunking,
)
)
@@ -3069,6 +3057,7 @@ def compile(
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
+ prefill_only=prefill_only,
)
if decode_spec:
specializations.append(decode_spec)
@@ -3081,7 +3070,6 @@ def compile(
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
qpc_path = self._compile(
onnx_path=onnx_path,
compile_dir=compile_dir,
@@ -3096,6 +3084,10 @@ def compile(
aic_num_cores=num_cores,
mxint8_kv_cache=mxint8_kv_cache,
use_onnx_subfunctions=use_onnx_subfunctions,
+ prefill_only=prefill_only,
+ offload_pt_weights=offload_pt_weights,
+ enable_chunking=enable_chunking,
+ retain_full_kv=retain_full_kv,
**compiler_options,
)
@@ -3287,7 +3279,7 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -3315,7 +3307,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -3663,7 +3655,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
def get_model_config(self) -> dict:
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Exports the model to ``ONNX`` format using ``torch.onnx.export``.
@@ -3691,7 +3683,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 21a867eb5..b978b6193 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -197,6 +197,10 @@
Starcoder2ForCausalLM,
Starcoder2Model,
)
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
@@ -238,6 +242,7 @@
QEffGemma3Attention,
QEffGemma3CustomRMSNormAIC,
QEffGemma3DecoderLayer,
+ QEffGemma3DecoderWrapper,
QEffGemma3ForCausalLMModel,
QEffGemma3ForConditionalGeneration,
QEffGemma3TextModel,
@@ -261,6 +266,11 @@
QEffGptOssForCausalLM,
QEffGptOssMLP,
QEffGptOssModel,
+ QEffPrefillOnlyChunkedGptOssAttention,
+ QEffPrefillOnlyChunkedGptOssMLP,
+ QEffPrefillOnlyGptOssAttention,
+ QEffPrefillOnlyGptOssMLP,
+ QEffPrefillOnlyGptOssModel,
)
from QEfficient.transformers.models.gptj.modeling_gptj import (
QEffGPTJAttention,
@@ -292,6 +302,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
+ QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -303,6 +314,7 @@
QEffLlamaRotaryEmbedding,
)
from QEfficient.transformers.models.llama4.modeling_llama4 import (
+ QEffLlama4DecoderWrapper,
QEffLlama4ForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlama4Router,
@@ -315,9 +327,11 @@
QEffLlama4VisionModel,
)
from QEfficient.transformers.models.llava.modeling_llava import (
+ QEFFLlavaDecoderWrapper,
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
+ QEffLlavaNextDecoderWrapper,
QEffLlavaNextForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
@@ -395,6 +409,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
+ QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -417,6 +432,10 @@
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
+from QEfficient.transformers.models.t5.modeling_t5 import (
+ QEffT5Attention,
+ QEffT5LayerNorm,
+)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
@@ -634,6 +653,39 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
return model, transformed
+class PrefillOnlyTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
+ QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
+ }
+
+
+class PrefillOnlyChunkedTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
+ }
+
+
+class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
+ QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
+ }
+
+
+class RevertPrefillOnlyTransform(ModuleMappingTransform):
+ _module_mapping = {
+ **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
+ **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()},
+ }
+
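+# Illustrative usage (sketch, not part of the mappings above): a prefill-only GPT-OSS export could
+# apply PrefillOnlyTransform (or PrefillOnlyChunkedTransform when chunking is enabled) before export
+# and RevertPrefillOnlyTransform afterwards to restore the regular QEff modules:
+#
+#     model, _ = PrefillOnlyTransform.apply(model)
+#     ...  # export the prefill-only graph
+#     model, _ = RevertPrefillOnlyTransform.apply(model)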
+
class SpDTransform:
"""
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
@@ -707,14 +759,20 @@ class SamplerTransform:
_module_mapping = {
QEffFalconForCausalLM,
QEffGemmaForCausalLM,
+ QEffGemma3DecoderWrapper,
QEffGPT2LMHeadModel,
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
+ QEffInternDecoderWrapper,
QEffLlamaForCausalLM,
+ QEffLlama4DecoderWrapper,
+ QEFFLlavaDecoderWrapper,
+ QEffLlavaNextDecoderWrapper,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
+ QEffQwen_2_5_vl_DecoderWrapper,
}
@classmethod
@@ -808,6 +866,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
+class T5ModelTransform(ModuleMappingTransform):
+ # supported architectures
+ _module_mapping = {
+ T5Attention: QEffT5Attention,
+ T5LayerNorm: QEffT5LayerNorm,
+ }
+
+
class PoolingTransform:
"""
Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 63e046600..ececa6759 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -38,7 +38,9 @@
from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/transformers/models/t5/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py
new file mode 100644
index 000000000..f54201465
--- /dev/null
+++ b/QEfficient/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,145 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from transformers import EncoderDecoderCache
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
+
+
+class QEffT5LayerNorm(T5LayerNorm):
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
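+        # Dividing by sqrt(hidden_size) before squaring and summing yields the same mean of squares
+        # (sum((h / sqrt(d))**2) == mean(h**2)) while keeping intermediate values small.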
+ div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+ variance = div_first.pow(2).sum(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class QEffT5Attention(T5Attention):
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = key_value_states is not None
+
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+        # Check if an encoder-decoder model is being used; otherwise we'll get a `DynamicCache`
+ if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k(current_states)
+ value_states = self.v(current_states)
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ if position_bias is None:
+ key_length = key_states.shape[-2]
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
+ )
+            if past_key_value is not None:
+                # When a KV cache is present, keep only the bias for the current (last) query position
+                position_bias = position_bias[:, :, -1:, :]
+
+ if mask is not None:
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
+ position_bias = position_bias + causal_mask
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+
+ # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py
index dfadc00ef..dc2308e99 100644
--- a/QEfficient/transformers/quantizers/__init__.py
+++ b/QEfficient/transformers/quantizers/__init__.py
@@ -5,6 +5,6 @@
#
# -----------------------------------------------------------------------------
-from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
-__all__ = ["replace_transformers_quantizers"]
+__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"]
diff --git a/QEfficient/transformers/quantizers/quantizer_awq.py b/QEfficient/transformers/quantizers/quantizer_awq.py
index ef8a03521..0968bdd27 100644
--- a/QEfficient/transformers/quantizers/quantizer_awq.py
+++ b/QEfficient/transformers/quantizers/quantizer_awq.py
@@ -15,7 +15,9 @@
replace_linear_layer_with_target_layer,
replace_quantization_scales,
)
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEffAwqConfig(AwqConfig):
diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py
index e7e14166d..86d7ae139 100644
--- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py
@@ -14,7 +14,9 @@
from transformers.utils.quantization_config import CompressedTensorsConfig, QuantizationConfigMixin, QuantizationMethod
from QEfficient.transformers.quantizers.quantizer_utils import get_keys_to_not_convert
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
FP8_DTYPE = torch.float8_e4m3fn
diff --git a/QEfficient/transformers/quantizers/quantizer_gptq.py b/QEfficient/transformers/quantizers/quantizer_gptq.py
index 8a0bea1a2..073f370e1 100644
--- a/QEfficient/transformers/quantizers/quantizer_gptq.py
+++ b/QEfficient/transformers/quantizers/quantizer_gptq.py
@@ -15,7 +15,9 @@
repack_zeros,
replace_linear_layer_with_target_layer,
)
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEffGPTQConfig(GPTQConfig):
diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py
index 2ffba1bea..bfbfe473e 100644
--- a/QEfficient/transformers/quantizers/quantizer_mxfp4.py
+++ b/QEfficient/transformers/quantizers/quantizer_mxfp4.py
@@ -14,7 +14,9 @@
from transformers.utils.quantization_config import Mxfp4Config
from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("MODEL", loglevel="INFO")
class QEffMxfp4GptOssExperts(nn.Module):
diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py
index 96846e712..5c86b6355 100644
--- a/QEfficient/transformers/sampler/sampler.py
+++ b/QEfficient/transformers/sampler/sampler.py
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):
probs: torch.FloatTensor = None
next_tokens: torch.IntTensor = None
+ vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
+ image_idx: Optional[torch.IntTensor] = None # for VLMs
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -47,7 +49,6 @@ def prefill_path(
positions_mask = (position_ids[:, :1] != zero_tensor).view(-1, 1)
mul_value = CtxScatterFuncCB3D.apply(mul_value, batch_index, zero_tensor, positions_mask)
past_repetition_penalty_buffer *= mul_value
- past_presence_penalty_buffer *= mul_value
# Mask out-of-bounds or invalid position_ids or input_ids
input_ids = torch.where(position_ids == -1, torch.iinfo(torch.int32).max, input_ids)
@@ -59,6 +60,9 @@ def prefill_path(
input_ids,
torch.ones(input_ids.shape, dtype=torch.bool),
)
+
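+    # Clear the entire past presence-penalty buffer during prefill by broadcasting an all-False mask.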
+ mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
+ past_presence_penalty_buffer *= mul_value
return past_repetition_penalty_buffer, past_presence_penalty_buffer
@@ -103,6 +107,7 @@ def sampler_forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -112,6 +117,8 @@ def sampler_forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[int] = None,
+ vision_embeds: Optional[torch.FloatTensor] = None,
+ image_idx: Optional[torch.IntTensor] = None,
last_accepted_output_tokens: Optional[torch.Tensor] = None, # (batch_size, spec_length or less)
past_repetition_penalty_buffer: Optional[torch.Tensor] = None,
repetition_penalties: Optional[torch.Tensor] = None,
@@ -122,11 +129,15 @@ def sampler_forward(
top_ps: Optional[torch.Tensor] = None,
min_ps: Optional[torch.Tensor] = None,
random_numbers: Optional[torch.Tensor] = None,
+ token_bitmasks: Optional[torch.Tensor] = None,
) -> Union[Tuple, SamplerOutput]:
r"""
Perform the sampling of next tokens on the QAIC device (instead of the host)
and return the next tokens and/or probability distributions.
+ The vision_embeds and image_idx parameters are optional
+ and are used only for VLMs when supported by the original forward function.
+
Args:
last_accepted_output_tokens (`torch.Tensor`, *optional*):
Output tokens accepted by the Speculative Decoding Draft Language Model.
@@ -169,21 +180,43 @@ def sampler_forward(
random_numbers (`torch.Tensor`, *optional*):
Sampling parameter that represents the random seeds to use for random sampling.
Must be in [-1, 1].
- """
- outputs = self.old_forward(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- batch_index=batch_index,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
+ token_bitmasks (`torch.Tensor`, *optional*):
+ Boolean mask used to guide token-level filtering during decoding. Each
+ element of this tensor indicates whether the corresponding token should be
+ kept (1) or masked (0). Shape: (batch_size, vocab_size)
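+            Logits at masked positions are set to the float16 minimum before sampling.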
+ """
+ if vision_embeds is not None:
+ forward_kwargs = dict(
+ input_ids=input_ids,
+ vision_embeds=vision_embeds,
+ position_ids=position_ids,
+ image_idx=image_idx,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ )
+ if batch_index is not None:
+ forward_kwargs["batch_index"] = batch_index
+
+ logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
+ outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
+ if position_ids.dim() == 3: # For models using m-rope
+ position_ids = position_ids[0]
+ else:
+ outputs = self.old_forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ batch_index=batch_index,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
logits = outputs.get("logits", None)
assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
@@ -197,6 +230,13 @@ def sampler_forward(
batch_index = torch.arange(batch_size).view(-1, 1)
batch_index_reshaped = batch_index.view(-1)
+
+ # Guided decoding
+ if token_bitmasks is not None and (token_bitmasks != 1).any():
+ assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding"
+ # Mask logits where token_bitmasks is 0 with -inf
+ logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
+
# Prefill
past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
input_ids=input_ids,
@@ -224,17 +264,6 @@ def sampler_forward(
is_prefill, past_presence_penalty_buffer_prefill, past_presence_penalty_buffer_decode
)
- # Greedy Sampling
- greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1)
- if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False):
- return SamplerOutput(
- probs=None,
- next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits
- past_key_values=outputs.past_key_values,
- past_repetition_penalty_buffer=past_repetition_penalty_buffer,
- past_presence_penalty_buffer=past_presence_penalty_buffer,
- )
-
# Repetition Penalty
if (repetition_penalties != 1.0).any():
past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped].repeat(
@@ -253,6 +282,19 @@ def sampler_forward(
) # (batch_size * spec_length, vocab_size)
logits -= presence_penalties * past_presence_penalty_buffer_selected
+ # Greedy Sampling
+ greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1)
+ if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False):
+ return SamplerOutput(
+ probs=None,
+ next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits
+ vision_embeds=outputs.get("vision_embeds", None),
+ image_idx=outputs.get("image_idx", None),
+ past_key_values=outputs.get("past_key_values", None),
+ past_repetition_penalty_buffer=past_repetition_penalty_buffer,
+ past_presence_penalty_buffer=past_presence_penalty_buffer,
+ )
+
# TODO: Frequency Penalty
# Temperature Scaling
@@ -300,9 +342,8 @@ def sampler_forward(
) # (batch_size, spec_length, vocab_size)
# Random Sampling
- topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids)
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
- y = topk_probs_asc + gumbel_noise
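+    # Gumbel-Max trick: adding i.i.d. Gumbel(0, 1) noise to the (unnormalized) top-k logits and taking
+    # the argmax samples an index with probability softmax(topk_values_asc), so the explicit softmax
+    # over the top-k values is unnecessary.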
+ y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids)
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)
@@ -314,7 +355,9 @@ def sampler_forward(
return SamplerOutput(
probs=probs,
next_tokens=next_tokens, # Return sampled next tokens instead of logits
- past_key_values=outputs.past_key_values,
+ vision_embeds=outputs.get("vision_embeds", None),
+ image_idx=outputs.get("image_idx", None),
+ past_key_values=outputs.get("past_key_values", None),
past_repetition_penalty_buffer=past_repetition_penalty_buffer,
past_presence_penalty_buffer=past_presence_penalty_buffer,
)
diff --git a/QEfficient/transformers/transform.py b/QEfficient/transformers/transform.py
index 11d7c1dfd..ecbc16a38 100644
--- a/QEfficient/transformers/transform.py
+++ b/QEfficient/transformers/transform.py
@@ -13,7 +13,9 @@
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_utils import TransformersToQEffModulesDict
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
def replace_module_with_qeff_layers(model: nn.Module) -> None:
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index 49f0ad30b..3d6583f85 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -16,7 +16,6 @@
create_model_params,
custom_format_warning,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
get_num_layers_from_config,
get_num_layers_vlm,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index 131a7fc26..21d60fa1c 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -12,7 +12,6 @@
import subprocess
import xml.etree.ElementTree as ET
from dataclasses import dataclass
-from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import requests
@@ -27,10 +26,11 @@
PreTrainedTokenizerFast,
)
-from QEfficient.utils.cache import QEFF_HOME
from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants
-from QEfficient.utils.hash_utils import create_export_hash, json_serializable
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.hash_utils import json_serializable
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
class LRUCache:
@@ -532,61 +532,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
"""
model_params = copy.deepcopy(kwargs)
model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
- model_params["config"] = qeff_model.model.config.to_diff_dict()
model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
model_params["applied_transform_names"] = qeff_model._transform_names()
return model_params
-def export_wrapper(func):
- def wrapper(self, *args, **kwargs):
- export_dir = kwargs.get("export_dir", None)
- parent_dir = self.model_architecture or self.model_name
- export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name))
-
- # PREPROCESSING OF PARAMETERS
-
- # Get the original signature
- original_sig = inspect.signature(func)
-
- # Remove 'self' from parameters
- params = list(original_sig.parameters.values())[1:] # skip 'self'
- new_sig = inspect.Signature(params)
-
- # Bind args and kwargs to the new signature
- bound_args = new_sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
-
- # Get arguments as a dictionary
- all_args = bound_args.arguments
-
- export_hash, filtered_hash_params = create_export_hash(
- model_params=self.hash_params,
- output_names=all_args.get("output_names"),
- dynamic_axes=all_args.get("dynamic_axes"),
- export_kwargs=all_args.get("export_kwargs", None),
- onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
- use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False),
- )
-
- export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
- kwargs["export_dir"] = export_dir
- self.export_hash = export_hash
-
- # _EXPORT CALL
- onnx_path = func(self, *args, **kwargs)
-
- # POST-PROCESSING
- # Dump JSON file with hashed parameters
- hashed_params_export_path = export_dir / "hashed_export_params.json"
- create_json(hashed_params_export_path, filtered_hash_params)
- logger.info("Hashed parameters exported successfully.")
-
- return onnx_path
-
- return wrapper
-
-
def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
"""
     Executes the given command using subprocess.
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index e0b003422..613d7049a 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -144,6 +144,13 @@ def get_models_dir():
# Molmo Constants
MOLMO_IMAGE_HEIGHT = 536
MOLMO_IMAGE_WIDTH = 354
+# Flux Transformer Constants
+FLUX_ONNX_EXPORT_SEQ_LENGTH = 256
+FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096
+FLUX_ADALN_HIDDEN_DIM = 3072
+FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context
+FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3
+FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM
class Constants:
diff --git a/QEfficient/utils/device_utils.py b/QEfficient/utils/device_utils.py
index a76dfae8a..72e5b423e 100644
--- a/QEfficient/utils/device_utils.py
+++ b/QEfficient/utils/device_utils.py
@@ -10,7 +10,9 @@
import subprocess
from QEfficient.utils.constants import Constants
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
def is_networks_loaded(stdout):
diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py
new file mode 100644
index 000000000..0babacb2a
--- /dev/null
+++ b/QEfficient/utils/export_utils.py
@@ -0,0 +1,221 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import copy
+import inspect
+import re
+import warnings
+from pathlib import Path
+from typing import Dict
+
+from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform
+from QEfficient.transformers.cache_utils import InvalidIndexProvider
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
+from QEfficient.utils.cache import QEFF_HOME
+from QEfficient.utils.hash_utils import create_export_hash
+from QEfficient.utils.logging_utils import QEFFLogger
+from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+
+
+def export_wrapper(func):
+ """
+ Decorator for export methods that orchestrates the complete export lifecycle.
+
+    Responsibilities:
+    1. Set up the ONNX subfunction environment (if enabled)
+    2. Prepare the export directory structure
+    3. Generate a reproducible hash for the export configuration
+    4. Execute the wrapped export function
+    5. Save export metadata
+    6. Clean up the subfunction environment (if it was set up)
+
+ Args:
+ func: The export method to wrap (typically _export)
+
+ Returns:
+ Wrapped function with complete export lifecycle management
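+
+    Example (illustrative sketch; the decorated method name and signature are hypothetical):
+
+        @export_wrapper
+        def _export(self, example_inputs, output_names, dynamic_axes, export_dir=None, **kwargs):
+            ...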
+ """
+
+ def wrapper(self, *args, **kwargs):
+ # 1. Setup ONNX subfunctions if requested
+ if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False):
+ args, kwargs = _setup_onnx_subfunctions(self, args, kwargs)
+
+ # 2. Prepare export directory
+ export_dir = _prepare_export_directory(self, kwargs)
+
+ # 3. Generate hash and finalize export directory path
+ export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func)
+ export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
+ kwargs["export_dir"] = export_dir
+ self.export_hash = export_hash
+
+ # 4. Execute the actual export
+ onnx_path = func(self, *args, **kwargs)
+
+ # 5. Save export metadata
+ _save_export_metadata(export_dir, filtered_hash_params)
+
+ # 6. Always cleanup subfunctions if they were setup
+ if use_onnx_subfunctions:
+ _cleanup_onnx_subfunctions(self)
+
+ return onnx_path
+
+ return wrapper
+
+
+def _prepare_export_directory(qeff_model, kwargs) -> Path:
+ """
+ Prepare and return the base export directory path.
+
+ Args:
+ qeff_model: The QEff model instance
+ kwargs: Keyword arguments containing optional export_dir
+
+ Returns:
+ Path object for the base export directory
+ """
+ export_dir = kwargs.get("export_dir", None)
+ parent_dir = qeff_model.model_architecture or qeff_model.model_name
+ return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name))
+
+
+def _generate_export_hash(qeff_model, args, kwargs, func):
+ """
+ Generate export hash from model parameters and export arguments.
+
+ The hash ensures reproducibility and prevents conflicts between
+ different export configurations.
+
+ Args:
+ qeff_model: The QEff model instance
+ args: Positional arguments to the export function
+ kwargs: Keyword arguments to the export function
+ func: The export function being wrapped
+
+ Returns:
+ Tuple of (export_hash: str, filtered_hash_params: dict)
+ """
+ # Extract function signature
+ original_sig = inspect.signature(func)
+ params = list(original_sig.parameters.values())[1:] # Skip 'self'
+ new_sig = inspect.Signature(params)
+ # Bind all arguments
+ bound_args = new_sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ all_args = bound_args.arguments
+
+ # Use the model's current configuration for hashing to ensure any post-load modifications are captured
+ # TODO: Replace with get_model_config property of modeling classes and remove the if-else
+ # Determine the config dict to use, preferring .to_diff_dict() if available
+ if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.config.to_diff_dict()
+ elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.model.config.to_diff_dict()
+ else:
+ config_val = qeff_model.model.config
+
+ copy_of_hash_params = copy.deepcopy(qeff_model.hash_params)
+ copy_of_hash_params.update(
+ {
+ "config": config_val,
+ }
+ )
+ # Generate hash from relevant parameters
+ export_hash, filtered_hash_params = create_export_hash(
+ model_params=copy_of_hash_params,
+ output_names=all_args.get("output_names"),
+ dynamic_axes=all_args.get("dynamic_axes"),
+ export_kwargs=all_args.get("export_kwargs", None),
+ onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
+ )
+
+ return export_hash, filtered_hash_params
+
+
+def _setup_onnx_subfunctions(qeff_model, args, kwargs):
+ """
+ Setup ONNX subfunction export environment.
+
+ This function prepares the model and environment for exporting with
+ ONNX subfunctions enabled. It:
+ - Applies necessary torch patches
+ - Modifies output names for subfunction compatibility
+ - Adds subfunction-specific ONNX transforms
+ - Updates export kwargs with module classes
+
+ Args:
+ qeff_model: The QEff model instance
+        args: Positional export arguments (output names may be rewritten for subfunction compatibility)
+        kwargs: Export keyword arguments (modified in-place)
+
+    Returns:
+        Tuple of the (possibly updated) args and kwargs.
+    """
+ warnings.warn(
+ "The subfunction feature is experimental. Please note that using compile "
+ "consecutively with and without subfunction may produce inconsistent results."
+ )
+
+ # Apply torch patches for subfunction support
+ apply_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = True
+ # Transform output names for subfunction compatibility
+ if "output_names" in kwargs:
+ kwargs["output_names"] = [
+ re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"]
+ ]
+ else:
+ args = list(args)
+ args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]]
+ args = tuple(args)
+ # Add subfunction-specific ONNX transforms
+ qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
+ qeff_model._onnx_transforms.append(CustomOpTransform)
+
+    # TODO: Handle this in the modeling class QEFFTransformersBase and remove it from here. Refer to the diffusers implementation.
+ kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
+ return args, kwargs
+
+
+def _cleanup_onnx_subfunctions(qeff_model):
+ """
+ Cleanup ONNX subfunction export environment.
+
+ Restores the model and environment to pre-subfunction state by:
+ - Undoing torch patches
+ - Resetting InvalidIndexProvider flag
+ - Restoring original ONNX transforms list
+
+ Args:
+ qeff_model: The QEff model instance
+
+ Note:
+        This function is called by export_wrapper after the export completes
+        so that subsequent exports start from the default (non-subfunction) state.
+ """
+ # Undo torch patches
+ undo_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = False
+ qeff_model._onnx_transforms.remove(RenameFunctionOutputsTransform)
+ qeff_model._onnx_transforms.remove(CustomOpTransform)
+
+
+def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict):
+ """
+ Save export metadata to JSON file for reproducibility.
+
+ Args:
+ export_dir: Directory where the export was saved
+ filtered_hash_params: Dictionary of parameters used for hashing
+ """
+ # Import here to avoid circular dependency
+ from QEfficient.utils._utils import create_json
+
+ hashed_params_path = export_dir / "hashed_export_params.json"
+ create_json(hashed_params_path, filtered_hash_params)
+ logger.info("Hashed parameters exported successfully.")
diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py
index 948b72e6a..10e6686d0 100644
--- a/QEfficient/utils/hash_utils.py
+++ b/QEfficient/utils/hash_utils.py
@@ -14,7 +14,8 @@
def json_serializable(obj):
if isinstance(obj, set):
- return sorted(obj)
+ # Convert set to a sorted list of strings for consistent hashing
+ return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj])
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
@@ -55,8 +56,6 @@ def create_export_hash(**kwargs):
export_params = {}
export_params["output_names"] = kwargs.get("output_names")
export_params["dynamic_axes"] = kwargs.get("dynamic_axes")
- if kwargs.get("use_onnx_subfunctions"):
- export_params["use_onnx_subfunctions"] = True
export_hash_params["export_params"] = export_params
export_kwargs = kwargs.get("export_kwargs")
@@ -68,5 +67,4 @@ def create_export_hash(**kwargs):
export_hash_params.update(onnx_transform_kwargs)
if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict):
export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict()
-
return hash_dict_params(export_hash_params), export_hash_params
diff --git a/QEfficient/utils/logging_utils.py b/QEfficient/utils/logging_utils.py
index d2086d830..f2659dff6 100644
--- a/QEfficient/utils/logging_utils.py
+++ b/QEfficient/utils/logging_utils.py
@@ -5,54 +5,332 @@
#
# -----------------------------------------------------------------------------
+import json
import logging
+import os
+import queue
+import threading
+from datetime import datetime
+from logging.handlers import RotatingFileHandler
+from typing import Any, Dict, List, Optional
+from tabulate import tabulate
-class QEffFormatter(logging.Formatter):
- """
- Formatter class used to set colors for printing different logging levels of messages on console.
+
+class JSONNamespaceFormatter(logging.Formatter):
"""
+ Custom formatter to output log records in JSON format with metadata.
+
+ Methods:
+ format(record): Formats a log record into a JSON string.
- cyan: str = "\x1b[38;5;14m"
- yellow: str = "\x1b[33;20m"
- red: str = "\x1b[31;20m"
- bold_red: str = "\x1b[31;1m"
- reset: str = "\x1b[0m"
- common_format: str = "%(levelname)s - %(name)s - %(message)s" # type: ignore
- format_with_line_info = "%(levelname)s - %(name)s - %(message)s (%(filename)s:%(lineno)d)" # type: ignore
-
- FORMATS = {
- logging.DEBUG: cyan + format_with_line_info + reset,
- logging.INFO: cyan + common_format + reset,
- logging.WARNING: yellow + common_format + reset,
- logging.ERROR: red + format_with_line_info + reset,
- logging.CRITICAL: bold_red + format_with_line_info + reset,
- }
+ Parameters:
+ record (logging.LogRecord): The log record to format.
+
+ Returns:
+ str: JSON-formatted log string.
+ """
def format(self, record):
+ log_record = {
+ "date": datetime.fromtimestamp(record.created).strftime("%Y-%m-%d"),
+ "time": datetime.fromtimestamp(record.created).strftime("%H:%M:%S"),
+ "level": record.levelname,
+ "namespace": getattr(record, "namespace", "default"),
+ "file": record.filename,
+ "line": record.lineno,
+ "message": record.getMessage(),
+ }
+ return json.dumps(log_record)
+
+
+class QEFFLoggerThread(threading.Thread):
+ """
+ Background thread to handle logging asynchronously using a queue.
+
+ Attributes:
+ logger (logging.Logger): Logger instance to handle log records.
+ log_queue (queue.Queue): Queue from which log records are consumed.
+ running (bool): Flag to control thread execution.
+ """
+
+ def __init__(self, logger, log_queue):
+ """
+ Initialize the logging thread.
+
+ Parameters:
+ logger (logging.Logger): Logger instance.
+ log_queue (queue.Queue): Queue for log records.
+ """
+ super().__init__(daemon=True)
+ self.logger = logger
+ self.log_queue = log_queue
+ self.running = True
+
+ def run(self):
+ """
+ Continuously process log records from the queue and pass them to the logger.
+ """
+ while self.running:
+ try:
+ record = self.log_queue.get(timeout=1)
+ self.logger.handle(record)
+ except queue.Empty:
+ continue
+
+ def stop(self):
"""
- Overriding the base class method to Choose format based on log level.
+ Stop the logging thread gracefully.
"""
- log_fmt = self.FORMATS.get(record.levelno)
- formatter = logging.Formatter(log_fmt)
- return formatter.format(record)
+ self.running = False
-def create_logger() -> logging.Logger:
+class QEFFLogger:
"""
- Creates a logger object with Colored QEffFormatter.
+ Singleton logger class for structured logging with namespace support.
+
+ Class Attributes:
+ _instance (Optional[logging.Logger]): Singleton logger instance.
+ _logfile (Optional[str]): Path to the log file.
+ _log_queue (queue.Queue): Queue for asynchronous logging.
+ _logger_thread (Optional[QEFFLoggerThread]): Background logging thread.
"""
- logger = logging.getLogger("QEfficient")
- # create console handler and set level to debug
- ch = logging.StreamHandler()
- ch.setLevel(logging.INFO)
- # define formatter
- ch.setFormatter(QEffFormatter())
+ _instance: Optional[logging.Logger] = None
+ _logfile: Optional[str] = None
+ _log_queue: queue.Queue = queue.Queue()
+ _logger_thread: Optional[QEFFLoggerThread] = None
+
+ def __init__(self, loglevel: Optional[str] = "INFO", log_path: Optional[str] = None):
+ """
+ Initialize the logger instance with specified log level and path.
+
+ Parameters:
+ loglevel (str): Logging level (e.g., "INFO", "DEBUG").
+ log_path (str): Optional path to the log file.
+ """
+ if QEFFLogger._instance is None:
+ self.loglevel = loglevel
+ self.log_path = log_path
+ self.logger = self._initialize_logger()
+ QEFFLogger._instance = self.logger
+ QEFFLogger._logger_thread = QEFFLoggerThread(self.logger, QEFFLogger._log_queue)
+ QEFFLogger._logger_thread.start()
+
+ def _initialize_logger(self) -> logging.Logger:
+ """
+ Set up the logger with rotating file handler and JSON formatter.
+
+ Returns:
+ logging.Logger: Configured logger instance.
+ """
+ if self.log_path is None:
+ log_dir = os.path.expanduser("~/.cache/qefficient_logs")
+ os.makedirs(log_dir, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.log_path = os.path.join(log_dir, f"QEFF_{timestamp}.log")
+
+ QEFFLogger._logfile = self.log_path
+
+ numeric_level = getattr(logging, self.loglevel.upper(), None)
+ if not isinstance(numeric_level, int):
+ raise ValueError(f"Invalid log level: {self.loglevel}")
+
+ logger = logging.getLogger("QEFF_LOGGER")
+ logger.setLevel(numeric_level)
+
+ if not logger.handlers:
+ handler = RotatingFileHandler(self.log_path, maxBytes=5 * 1024 * 1024, backupCount=10)
+ handler.setFormatter(JSONNamespaceFormatter())
+ logger.addHandler(handler)
+
+ return logger
+
+ @classmethod
+ def get_logger(
+ cls, namespace: str, loglevel: Optional[str] = "INFO", log_path: Optional[str] = None
+ ) -> logging.Logger:
+ """
+ Retrieve a logger adapter with a specific namespace.
+
+ Parameters:
+ namespace (str): Logical grouping for the log.
+ loglevel (str): Logging level.
+ log_path (str): Optional path to the log file.
+
+ Returns:
+ logging.Logger: Logger adapter with namespace.
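+
+        Example (illustrative):
+            logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+            logger.info("Export started")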
+ """
+ if cls._instance is None:
+ cls(loglevel, log_path)
+ return logging.LoggerAdapter(cls._instance, {"namespace": namespace})
+
+ @classmethod
+ def log(cls, level: str, namespace: str, msg: str, fn: str = "", lno: int = 0, func: str = ""):
+ """
+ Log a message with specified level and metadata.
+
+ Parameters:
+ level (str): Logging level (e.g., "INFO", "ERROR").
+ namespace (str): Logical grouping for the log.
+ msg (str): Log message.
+ fn (str): Filename where the log is generated.
+ lno (int): Line number in the file.
+ func (str): Function name.
+ """
+ if cls._instance is None:
+ raise RuntimeError("Logger has not been initialized. Call get_logger() first.")
+
+ level_num = getattr(logging, level.upper(), None)
+ if not isinstance(level_num, int):
+ raise ValueError(f"Invalid log level: {level}")
+
+ record = cls._instance.makeRecord(
+ name="QEFF_LOGGER",
+ level=level_num,
+ fn=fn,
+ lno=lno,
+ msg=msg,
+ args=(),
+ exc_info=None,
+ func=func,
+ extra={"namespace": namespace},
+ )
+ cls._log_queue.put(record)
+
+ @classmethod
+ def set_loglevel(cls, loglevel: Optional[str] = "INFO"):
+ """
+ Update the log level of the logger.
+
+ Parameters:
+ loglevel (str): New log level to set.
+ """
+ if cls._instance is None:
+ raise RuntimeError("Logger has not been initialized yet. Call get_logger() first.")
+
+ numeric_level = getattr(logging, loglevel.upper(), None)
+ if not isinstance(numeric_level, int):
+ raise ValueError(f"Invalid log level: {loglevel}")
+
+ cls._instance.setLevel(numeric_level)
+
+ @classmethod
+ def close_logger(cls):
+ """
+ Gracefully shut down the logger and its thread.
+ """
+ if cls._logger_thread:
+ cls._logger_thread.stop()
+ cls._logger_thread.join()
+ cls._logger_thread = None
+
+ if cls._instance:
+ handlers = cls._instance.handlers[:]
+ for handler in handlers:
+ handler.close()
+ cls._instance.removeHandler(handler)
+ cls._instance = None
+ cls._logfile = None
+
+ @classmethod
+ def _parse_dt(cls, date_str: str, time_str: str) -> datetime:
+ """Parse 'YYYY-MM-DD' and 'HH:MM:SS' into a datetime."""
+ return datetime.strptime(f"{date_str} {time_str}", "%Y-%m-%d %H:%M:%S")
+
+ @classmethod
+ def print_table(cls) -> None:
+ """
+        Parse the line-delimited JSON log in cls._logfile and print a timing table with t1 as the baseline (0.0 s):
+ - Model Loading : t2 - t1
+ - Model Exporting : t3 - t2
+ - Model Compilation : t4 - t3
+ - Text Generation : t5 - t4
+ - Total Time : t5 - t1
+
+        Milestones (matched against known log messages):
+            t1: first log line timestamp (baseline)
+            t2: "PyTorch export successful"
+            t3: "Transformed ONNX saved"
+            t4: "Model compilation is finished and saved"
+            t5: "Text Generated finised"
+        If t5 is missing, fall back to "specialization_file_path" as a readiness marker.
+ """
+ path = cls._logfile
+ if not path:
+ raise FileNotFoundError("Log file path is not set (cls._logfile is None).")
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"Log file does not exist: {path}")
+
+ t_start: Optional[datetime] = None
+ t_export_done: Optional[datetime] = None
+ t_onnx_saved: Optional[datetime] = None
+ t_compile_done: Optional[datetime] = None
+ t_text_done: Optional[datetime] = None
+ t_text_ready: Optional[datetime] = None
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ rec: Dict[str, Any] = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+
+ date_str = rec.get("date")
+ time_str = rec.get("time")
+ msg = rec.get("message", "")
+ if not date_str or not time_str:
+ continue
+
+ ts = cls._parse_dt(date_str, time_str)
+
+ if t_start is None:
+ t_start = ts
+
+ if ("PyTorch export successful" in msg) and (t_export_done is None):
+ t_export_done = ts
+
+ if ("Transformed ONNX saved" in msg) and (t_onnx_saved is None):
+ t_onnx_saved = ts
+
+ if ("Model compilation is finished and saved" in msg) and (t_compile_done is None):
+ t_compile_done = ts
+
+ if ("Text Generated finised" in msg) and (t_text_done is None):
+ t_text_done = ts
+
+ if ("specialization_file_path" in msg) and (t_text_ready is None):
+ t_text_ready = ts
+
+ if t_start is None:
+ raise ValueError("Could not determine start time (no valid log lines with date/time).")
+
+ if t_text_done is None:
+ t_text_done = t_text_ready
+
+ t_export_done = t_export_done or t_start
+ t_onnx_saved = t_onnx_saved or t_export_done
+ t_compile_done = t_compile_done or t_onnx_saved
+ t_text_done = t_text_done or t_compile_done
+
+ def to_offset_seconds(t: datetime) -> float:
+ return (t - t_start).total_seconds()
- logger.addHandler(ch)
- return logger
+ o1 = 0.0
+ o2 = to_offset_seconds(t_export_done)
+ o3 = to_offset_seconds(t_onnx_saved)
+ o4 = to_offset_seconds(t_compile_done)
+ o5 = to_offset_seconds(t_text_done)
+ timing_data: List[List[Any]] = [
+ ["Model Loading", max(0.0, o2 - o1)],
+ ["Model Exporting", max(0.0, o3 - o2)],
+ ["Model Compilation", max(0.0, o4 - o3)],
+ ["Text Generation", max(0.0, o5 - o4)],
+ ["Total Time", max(0.0, o5 - o1)],
+ ]
-# Define the logger object that can be used for logging purposes throughout the module.
-logger = create_logger()
+ print(tabulate(timing_data, headers=["Step", "Time (s)"], tablefmt="github", floatfmt=".3f"))
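+
+
+# Illustrative usage (a sketch based on the classmethods above; values are placeholders):
+#     logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+#     QEFFLogger.log("INFO", "INFRA", "Model compilation is finished and saved")
+#     QEFFLogger.set_loglevel("DEBUG")
+#     QEFFLogger.print_table()   # prints the step-wise timing table parsed from the JSON log file
+#     QEFFLogger.close_logger()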
diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py
index 6fb1b326f..847266ae8 100644
--- a/QEfficient/utils/sampler_utils.py
+++ b/QEfficient/utils/sampler_utils.py
@@ -5,13 +5,20 @@
#
# -----------------------------------------------------------------------------
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
+import torch
+
+from QEfficient.utils import constants
from QEfficient.utils.constants import Constants
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
-def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool:
+def validate_sampler_inputs(
+ session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None
+) -> bool:
"""
Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling.
@@ -28,7 +35,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[
ValueError if partial support is detected or if user intent conflicts with QPC capabilities.
"""
- sampler_inputs = Constants.SAMPLER_INPUTS
+ sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set())
count = len(sampler_inputs & session_inputs)
session_includes_sampler = True
@@ -56,3 +63,92 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[
)
return session_includes_sampler
+
+
+def get_sampling_inputs_and_outputs(
+ example_inputs: Dict[str, torch.Tensor],
+ output_names: List[str],
+ dynamic_axes: Dict[str, Dict[int, str]],
+ continuous_batching: bool,
+ vocab_size: int,
+ qaic_config: Dict,
+):
+ """
+ Updates the example inputs, output names, and dynamic axes to include
+ parameters relevant for on-device sampling during ONNX export.
+
+ Parameters
+ ----------
+ example_inputs : Dict[str, torch.Tensor]
+ Current dictionary of example inputs.
+ output_names : List[str]
+ Current list of output names.
+ dynamic_axes : Dict[str, Dict[int, str]]
+ Current dictionary of dynamic axes configurations.
+ continuous_batching : bool
+ Whether this model will be used for continuous batching in the future.
+ vocab_size: int
+ Vocabulary size for this model.
+ qaic_config : Dict
+ QAIC config dictionary.
+
+ Returns
+ -------
+ Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+ Updated example inputs, output names, and dynamic axes including
+ sampling-related parameters.
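+
+    Example
+    -------
+    A minimal illustrative call (a sketch; values and shapes are placeholders, not taken from a real export)::
+
+        example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
+            example_inputs,
+            output_names,
+            dynamic_axes,
+            continuous_batching=False,
+            vocab_size=32000,
+            qaic_config={"max_top_k_ids": 512, "include_guided_decoding": False},
+        )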
+ """
+ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+ fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+ seq_len: int = example_inputs["input_ids"].shape[-1]
+
+ example_inputs["last_accepted_output_tokens"] = torch.zeros((bs, seq_len), dtype=torch.int64)
+ dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+ example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+ (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+ )
+ dynamic_axes["past_repetition_penalty_buffer"] = {
+ 0: "full_batch_size" if continuous_batching else "batch_size",
+ }
+ output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+ example_inputs["repetition_penalties"] = (
+ torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+ )
+ dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+ example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+ (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+ )
+ dynamic_axes["past_presence_penalty_buffer"] = {
+ 0: "full_batch_size" if continuous_batching else "batch_size",
+ }
+ output_names.append("past_presence_penalty_buffer_RetainedState")
+
+ example_inputs["presence_penalties"] = (
+ torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+ )
+ dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+ example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+ dynamic_axes["temperatures"] = {0: "batch_size"}
+
+ max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+ example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+ dynamic_axes["top_ks"] = {0: "batch_size"}
+
+ example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+ dynamic_axes["top_ps"] = {0: "batch_size"}
+
+ example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+ dynamic_axes["min_ps"] = {0: "batch_size"}
+
+ example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
+ dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+ if qaic_config.get("include_guided_decoding", False):
+ example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool)
+ dynamic_axes["token_bitmasks"] = {0: "batch_size"}
+
+ return example_inputs, output_names, dynamic_axes
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..9e58da61d
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md
new file mode 100644
index 000000000..2a3c1605f
--- /dev/null
+++ b/examples/diffusers/flux/README.md
@@ -0,0 +1,243 @@
+# FLUX.1-schnell Image Generation Examples
+
+This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs.
+
+## Overview
+
+FLUX.1-schnell is a fast, distilled version of the FLUX.1 text-to-image model optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation.
+
+## Files
+
+- **`flux_1_schnell.py`** - Basic example showing simple image generation
+- **`flux_1_shnell_custom.py`** - Advanced example with customization options
+- **`flux_config.json`** - Configuration file for pipeline modules
+
+## Quick Start
+
+### Basic Usage
+
+The simplest way to generate images with FLUX.1-schnell:
+
+```python
+from QEfficient import QEffFluxPipeline
+import torch
+
+# Initialize pipeline
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate image
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Save image
+output.images[0].save("girl_laughing.png")
+```
+
+Run the basic example:
+```bash
+python flux_1_schnell.py
+```
+
+## Advanced Customization
+
+The `flux_1_shnell_custom.py` example demonstrates several advanced features:
+
+### 1. Custom Model Components
+
+You can provide custom text encoders, transformers, and tokenizers:
+
+```python
+pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ text_encoder=custom_text_encoder,
+ transformer=custom_transformer,
+ tokenizer=custom_tokenizer,
+)
+```
+
+### 2. Custom Scheduler
+
+Replace the default scheduler with your own:
+
+```python
+pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+```
+
+### 3. Reduce Model Layers for Faster Inference
+
+Trade quality for speed by reducing transformer blocks:
+
+```python
+original_blocks = pipeline.transformer.model.transformer_blocks
+org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+pipeline.transformer.model.config['num_layers'] = 1
+pipeline.transformer.model.config['num_single_layers'] = 1
+```
+
+### 4. Pre-compile with Custom Configuration
+
+Compile the model separately before generation:
+
+```python
+pipeline.compile(
+ compile_config="examples/diffusers/flux/flux_config.json",
+ height=512,
+ width=512,
+ use_onnx_subfunctions=False
+)
+```
+
+### 5. Runtime Configuration
+
+Use custom configuration during generation:
+
+```python
+output = pipeline(
+ prompt="A girl laughing",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+```
+
+Run the advanced example:
+```bash
+python flux_1_shnell_custom.py
+```
+
+## Configuration File
+
+The `flux_config.json` file controls compilation and execution settings for each pipeline module:
+
+### Module Structure
+
+The configuration includes four main modules:
+
+1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence)
+2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence)
+3. **transformer** - Core diffusion transformer model
+4. **vae_decoder** - Decodes latents to images
+
+### Configuration Parameters
+
+Each module has three sections:
+
+#### Specializations
+- `batch_size`: Batch size for inference
+- `seq_len`: Sequence length for text encoders
+- `steps`: Number of inference steps (transformer only)
+- `channels`: Number of channels (VAE decoder only)
+
+#### Compilation
+- `onnx_path`: Path to pre-exported ONNX model (null for auto-export)
+- `compile_dir`: Directory for compiled artifacts (null for auto-generation)
+- `mdp_ts_num_devices`: Number of devices for model data parallelism
+- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication
+- `convert_to_fp16`: Convert model to FP16 precision
+- `aic_num_cores`: Number of AI cores to use
+- `mos`: Multi-output streaming (transformer only)
+- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only)
+- `aic-enable-depth-first`: Enable depth-first compilation (VAE only)
+
+#### Execute
+- `device_ids`: List of device IDs to use (null for auto-selection)
+
+### Example Configuration Snippet
+
+```json
+{
+ "transformer": {
+ "specializations": {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation": {
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+}
+```
+
+## Key Parameters
+
+### Generation Parameters
+
+- **`prompt`** (str): Text description of the image to generate
+- **`height`** (int): Output image height in pixels (default: 1024)
+- **`width`** (int): Output image width in pixels (default: 1024)
+- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell)
+- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell)
+- **`max_sequence_length`** (int): Maximum text sequence length (256 recommended)
+- **`generator`** (torch.Generator): Random seed for reproducibility
+- **`parallel_compile`** (bool): Enable parallel compilation of modules
+- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental)
+
+### Performance Tuning
+
+- **Faster inference**: Reduce `num_inference_steps` or model layers (see the sketch after this list)
+- **Better quality**: Increase `num_inference_steps` or use full model
+- **Memory optimization**: Adjust `mdp_ts_num_devices` in config
+- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16`
+
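+For example, a speed-oriented run might combine these knobs, reusing the `pipeline` object from the Basic Usage example above (a minimal sketch; the values are illustrative):
+
+```python
+output = pipeline(
+    prompt="A laughing girl",
+    height=512,                # smaller resolution for faster generation
+    width=512,
+    guidance_scale=0.0,
+    num_inference_steps=4,     # schnell is tuned for 4 steps
+    max_sequence_length=256,
+    parallel_compile=True,
+    use_onnx_subfunctions=False,
+)
+```
+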
+## Output
+
+The pipeline returns an output object containing:
+- `images`: List of generated PIL Image objects
+- Performance metrics (timing information)
+
+Example output:
+```python
+print(output) # Displays performance information
+image = output.images[0] # Access the generated image
+image.save("output.png") # Save to disk
+```
+
+## Hardware Requirements
+
+- Qualcomm Cloud AI 100 accelerator
+- Sufficient memory for model compilation and execution
+- Multiple devices recommended for optimal transformer performance (see `mdp_ts_num_devices`)
+
+## Notes
+
+- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0`
+- The transformer module benefits most from multi-device parallelism
+- ONNX subfunctions (`use_onnx_subfunctions=True`) are experimental; they may improve compile time but are not recommended for production use
+- Custom configurations allow fine-tuning for specific hardware setups
+
+## Troubleshooting
+
+- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices`
+- **Slow compilation**: Enable `parallel_compile=True`
+- **Quality issues**: Ensure using recommended parameters (4 steps, guidance_scale=0.0)
+- **Device errors**: Check `device_ids` in config or set to `null` for auto-selection
+
+## References
+
+- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+- [QEfficient Documentation](../../../README.md)
+- [Diffusers Pipeline Guide](../../README.md)
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..46f26bb6b
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,45 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1-schnell Image Generation Example
+
+This example demonstrates how to use the QEffFluxPipeline to generate images
+using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a
+fast, distilled version of the FLUX.1 text-to-image model optimized for
+speed with minimal quality loss.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# Initialize the FLUX.1-schnell pipeline from pretrained weights
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate an image from a text prompt
+# Setting use_onnx_subfunctions=True enables experimental modular ONNX export (kept False here)
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Extract the generated image from the output
+image = output.images[0]
+
+# Save the generated image to disk
+image.save("girl_laughing.png")
+
+# Print the output object (contains perf info)
+print(output)
diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py
new file mode 100644
index 000000000..201ebe659
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_shnell_custom.py
@@ -0,0 +1,113 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1 Schnell Custom Configuration Example
+
+This example demonstrates how to customize the FLUX.1 model with various options:
+1. Custom image dimensions (height/width)
+2. Custom transformer model and text encoder
+3. Custom scheduler configuration
+4. Reduced model layers for faster inference
+5. Custom compilation settings
+6. Custom runtime configuration via JSON config file
+
+Use this example to learn how to fine-tune FLUX.1 for your specific needs.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+
+# Option 1: Basic initialization with default parameters
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+# Option 2: Advanced initialization with custom modules
+# Uncomment and modify to use your own custom components:
+#
+# pipeline = QEffFluxPipeline.from_pretrained(
+# "black-forest-labs/FLUX.1-schnell",
+# text_encoder=custom_text_encoder, # Your custom CLIP text encoder
+# transformer=custom_transformer, # Your custom transformer model
+# tokenizer=custom_tokenizer, # Your custom tokenizer
+# )
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up image generation.
+#
+# Trade-off: Faster inference but potentially lower image quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only the first transformer block:
+#
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# pipeline.transformer.model.config['num_layers'] = 1
+# pipeline.transformer.model.config['num_single_layers'] = 1
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip image generation and only prepare the model
+#
+# NOTE-1: If compile_config is not specified, the default configuration from
+# QEfficient/diffusers/pipelines/flux/flux_config.json will be used
+#
+# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations (experimental, so not recommended for production).
+#         This feature improves export performance by breaking the model down into smaller,
+#         more manageable ONNX functions, which can lead to improved compile times.
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(
+# compile_config="examples/diffusers/flux/flux_config.json",
+# height=512,
+# width=512,
+# use_onnx_subfunctions=False
+# )
+
+# ============================================================================
+# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate an image using the configured pipeline.
+#
+# Note: Use of custom_config_path provides flexibility to set device_ids for each
+# module, so you can skip the separate pipeline.compile() step.
+
+output = pipeline(
+ prompt="A laughing girl",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+image = output.images[0]
+# Save the generated image to disk
+image.save("laughing_girl.png")
+print(output)
diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/examples/diffusers/flux/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/examples/disagg_serving/README.md b/examples/disagg_serving/README.md
new file mode 100644
index 000000000..fcf665357
--- /dev/null
+++ b/examples/disagg_serving/README.md
@@ -0,0 +1,31 @@
+# Use disaggregated serving for GPT-OSS models for best performance
+ - GPT-OSS has a total_experts/experts_per_tok ratio of 128/4 for the 120B model and 32/4 for the 20B model
+ - The prefill-only model uses a read-all-experts-exactly-once strategy
+ - The decode-only model treats expert weights as activations, i.e. it reads only the chosen experts
+
+# Prefill-only model
+## Blocking: default behaviour when `prefill_only=True` in the compile API
+  - `NUM_Q_BLOCKS=` sets the number of Q blocks in attention
+  - `NUM_FFN_BLOCKS=` sets the number of blocks in the FFN
+  - `ENABLE_OPT_SWA="0"` or `"1"` disables/enables optimized SWA; when enabled, only the valid KVs for a given block are used in attention, reducing MACs
+  - prefix_caching is not supported in this mode (a compile sketch follows below)
+
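+A minimal sketch of the prefill-only flow described above (illustrative values; assumes the block-count knobs are read from the environment at export/compile time):
+
+```python
+import os
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+# Illustrative values for the knobs described above
+os.environ["NUM_Q_BLOCKS"] = "4"
+os.environ["NUM_FFN_BLOCKS"] = "4"
+os.environ["ENABLE_OPT_SWA"] = "1"
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=256,
+    ctx_len=256,
+    num_cores=16,
+    prefill_only=True,
+    use_onnx_subfunctions=True,  # advised for prefill-only, see NOTE below
+)
+```
+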
+## Chunking: pass `enable_chunking=True` and `prefill_only=True` in the compile API
+  - Optimized SWA, i.e. reading only the valid KV as per the diagonal attention mask, is enabled by default for this version
+  - This model can be used for prefix_caching by passing `kv_cache_batch_size=` in the compile API (see the sketch below)
+
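+A minimal compile sketch for chunked prefill with prefix caching (illustrative values):
+
+```python
+from QEfficient import QEFFAutoModelForCausalLM
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=128,
+    ctx_len=384,
+    num_cores=16,
+    prefill_only=True,
+    enable_chunking=True,
+    kv_cache_batch_size=4,  # illustrative; enables prefix caching
+    use_onnx_subfunctions=True,
+)
+```
+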
+# Decode-only model
+## Retain sliding-window length of KV for sliding window layers: default behaviour when `prefill_seq_len=1` in the compile API
+  - This reduces the amount of DDR used by the model
+  - CB is supported for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+## Full KV for sliding window layers: pass `retain_full_kv=True` along with `prefill_seq_len=1` in the compile API
+  - This uses more DDR because ctx_len KV is retained even for sliding window layers, but only sliding-window-length KV is read in attention
+  - CB is supported for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+  - This is intended for the multi-turn chat use case, where we run prefill -> decode and then run prefill again over the combined prefill and decode cache, so the full KV must be retained for sliding window layers (see the sketch below)
+
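+A minimal sketch of a decode-only compile with continuous batching and full KV retention (illustrative values):
+
+```python
+from QEfficient import QEFFAutoModelForCausalLM
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
+    "openai/gpt-oss-20b", continuous_batching=True
+)
+decode_qpc_path = qeff_model.compile(
+    prefill_seq_len=1,       # decode-only
+    ctx_len=384,
+    num_cores=16,
+    full_batch_size=4,       # required when continuous batching is enabled
+    kv_cache_batch_size=4,   # optional
+    retain_full_kv=True,     # keep ctx_len KV even for sliding window layers
+)
+```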
+
+NOTE:
+* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it there
+* The 120B model needs an NPI; two versions (with and without subfunctions) are provided here and passed via `node_precision_info=`
+* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise compilation times are too high. With this flag the model is expected to export and then fail during compile because it needs the assert SDK, so the user must run that compilation manually by pasting the command printed in the error
+
diff --git a/examples/disagg_serving/gpt_oss_disagg_mode.py b/examples/disagg_serving/gpt_oss_disagg_mode.py
new file mode 100644
index 000000000..fd0d5b045
--- /dev/null
+++ b/examples/disagg_serving/gpt_oss_disagg_mode.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+all_outputs = []
+# Run prefill
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 256
+CTX_LEN = 256
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+
+# Initialize variables specific to request
+# Calculate the max generation length.
+max_gen_len = CTX_LEN - position_ids.max()
+generation_len = max_gen_len
+
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+config = qeff_model.model.config
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+past_key_values = []
+for i in range(config.num_hidden_layers):
+ cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+inputs["past_key_values"] = past_key_values
+
+
+decode_qpc_path = qeff_model.compile(
+ prefill_seq_len=1,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ offload_pt_weights=False,
+)
+prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ use_onnx_subfunctions=True,
+)
+
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+
+logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+prefill_session.set_buffers({"logits": logits_out_placeholder})
+inputs.pop("past_key_values")
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+st = time.time()
+qpc_out = prefill_session.run(inputs)
+print(f"time for prefill_run={time.time() - st} sec\n")
+
+decode_session = QAICInferenceSession(decode_qpc_path)
+decode_session.set_buffers({"logits": logits_out_placeholder})
+
+decode_inputs = {
+ "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
+ "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
+}
+print("pos_id for decodee", decode_inputs["position_ids"])
+
+all_outputs.append(decode_inputs["input_ids"][0][0])
+for i in range(config.num_hidden_layers):
+ if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
+ k = qpc_out[f"past_key.{i}_RetainedState"]
+ v = qpc_out[f"past_value.{i}_RetainedState"]
+ mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
+ decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
+ decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2)
+ else:
+ decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+ decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+
+st = time.time()
+decode_out = decode_session.run(decode_inputs)
+print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
+decode_session.skip_buffers(
+ [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")]
+)
+pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+st = time.time()
+for i in range(generation_len - 2):
+ loop_decode_inputs = {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+ }
+ all_outputs.append(loop_decode_inputs["input_ids"][0][0])
+ decode_out = decode_session.run(loop_decode_inputs)
+ pos_id += 1
+
+
+print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}")
+print(all_outputs)
+print(tokenizer.decode(all_outputs))
diff --git a/examples/disagg_serving/subfunction_120b_npi.yaml b/examples/disagg_serving/subfunction_120b_npi.yaml
new file mode 100644
index 000000000..762703d58
--- /dev/null
+++ b/examples/disagg_serving/subfunction_120b_npi.yaml
@@ -0,0 +1,27 @@
+FP32NodeInstanceNames:
+ - CustomRMSNorm_58
+ - onnx::Shape_1033777
+ - CustomRMSNorm_349
+ - hidden.127
+ - CustomRMSNorm_27448
+ - onnx::Shape_1066066
+ - CustomRMSNorm_27709
+ - hidden.131
+ - CustomRMSNorm_54808
+ - onnx::Shape_878
+ - CustomRMSNorm_55105
+ - hidden
+ - hidden_states.259
+ - Add_348
+ - Add_347
+ - onnx::Add_1034099
+ - hidden_states.267
+ - Add_27708
+ - onnx::Add_1066358
+ - Add_27707
+ - hidden_states.3
+ - Add_55104
+ - onnx::Add_1209
+ - Add_55103
+ - /model/norm/CustomRMSNorm
+ - /model/norm/CustomRMSNorm_output_0
\ No newline at end of file
diff --git a/examples/disagg_serving/without_subfunc_npi_120b.yaml b/examples/disagg_serving/without_subfunc_npi_120b.yaml
new file mode 100644
index 000000000..ec6cf034f
--- /dev/null
+++ b/examples/disagg_serving/without_subfunc_npi_120b.yaml
@@ -0,0 +1,148 @@
+FP32NodeInstanceNames:
+ - /model/layers.0/Add_1_output_0
+ - /model/layers.0/Add_output_0
+ - /model/layers.0/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.1/Add_1_output_0
+ - /model/layers.1/Add_output_0
+ - /model/layers.1/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.10/Add_1_output_0
+ - /model/layers.10/Add_output_0
+ - /model/layers.10/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.11/Add_1_output_0
+ - /model/layers.11/Add_output_0
+ - /model/layers.11/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.12/Add_1_output_0
+ - /model/layers.12/Add_output_0
+ - /model/layers.12/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.13/Add_1_output_0
+ - /model/layers.13/Add_output_0
+ - /model/layers.13/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.14/Add_1_output_0
+ - /model/layers.14/Add_output_0
+ - /model/layers.14/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.15/Add_1_output_0
+ - /model/layers.15/Add_output_0
+ - /model/layers.15/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.16/Add_1_output_0
+ - /model/layers.16/Add_output_0
+ - /model/layers.16/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.17/Add_1_output_0
+ - /model/layers.17/Add_output_0
+ - /model/layers.17/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.18/Add_1_output_0
+ - /model/layers.18/Add_output_0
+ - /model/layers.18/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.19/Add_1_output_0
+ - /model/layers.19/Add_output_0
+ - /model/layers.19/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.2/Add_1_output_0
+ - /model/layers.2/Add_output_0
+ - /model/layers.2/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.20/Add_1_output_0
+ - /model/layers.20/Add_output_0
+ - /model/layers.20/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.21/Add_1_output_0
+ - /model/layers.21/Add_output_0
+ - /model/layers.21/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.22/Add_1_output_0
+ - /model/layers.22/Add_output_0
+ - /model/layers.22/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.23/Add_1_output_0
+ - /model/layers.23/Add_output_0
+ - /model/layers.23/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.24/Add_1_output_0
+ - /model/layers.24/Add_output_0
+ - /model/layers.24/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.25/Add_1_output_0
+ - /model/layers.25/Add_output_0
+ - /model/layers.25/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.26/Add_1_output_0
+ - /model/layers.26/Add_output_0
+ - /model/layers.26/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.27/Add_1_output_0
+ - /model/layers.27/Add_output_0
+ - /model/layers.27/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.28/Add_1_output_0
+ - /model/layers.28/Add_output_0
+ - /model/layers.28/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.29/Add_1_output_0
+ - /model/layers.29/Add_output_0
+ - /model/layers.29/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.3/Add_1_output_0
+ - /model/layers.3/Add_output_0
+ - /model/layers.3/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.30/Add_1_output_0
+ - /model/layers.30/Add_output_0
+ - /model/layers.30/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.31/Add_1_output_0
+ - /model/layers.31/Add_output_0
+ - /model/layers.31/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.32/Add_1_output_0
+ - /model/layers.32/Add_output_0
+ - /model/layers.32/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.33/Add_1_output_0
+ - /model/layers.33/Add_output_0
+ - /model/layers.33/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.34/Add_1_output_0
+ - /model/layers.34/Add_output_0
+ - /model/layers.34/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.35/Add_1_output_0
+ - /model/layers.35/Add_output_0
+ - /model/norm/Add_output_0
+ - /model/layers.35/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.4/Add_1_output_0
+ - /model/layers.4/Add_output_0
+ - /model/layers.4/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.5/Add_1_output_0
+ - /model/layers.5/Add_output_0
+ - /model/layers.5/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.6/Add_1_output_0
+ - /model/layers.6/Add_output_0
+ - /model/layers.6/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.7/Add_1_output_0
+ - /model/layers.7/Add_output_0
+ - /model/layers.7/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.8/Add_1_output_0
+ - /model/layers.8/Add_output_0
+ - /model/layers.8/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.9/Add_1_output_0
+ - /model/layers.9/Add_output_0
+ - /model/layers.9/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/norm/CustomRMSNorm_output_0
+
\ No newline at end of file
diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py
new file mode 100644
index 000000000..363e2806c
--- /dev/null
+++ b/examples/gpt_oss_disagg_mode_with_chunking.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+# Run prefill
+config = AutoConfig.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 128
+CTX_LEN = 128 * 3
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+
+decode_qpc_path = qeff_model.compile(
+ prefill_seq_len=1,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
+ retain_full_kv=True,
+)
+
+
+# The following compile call errors out by default; run the printed command manually, provide the generated QPC path as prefill_qpc_path, and comment out lines 55-68
+# prefill_qpc_path = "provide path here"
+prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ enable_chunking=True,
+ use_onnx_subfunctions=True,
+)
+
+
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+generation_len = CTX_LEN - position_ids.max()
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+inputs.pop("past_key_values", None)
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+
+
+decode_session = QAICInferenceSession(decode_qpc_path)
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+
+all_outputs = []
+for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ ins = time.time()
+ qpc_out = prefill_session.run(chunk_inputs)
+ print(f"time for this run={time.time() - ins}")
+    # Use a separate loop variable so the chunk index `i` is not shadowed
+    for layer in range(config.num_hidden_layers):
+        inputs[f"past_key.{layer}"] = qpc_out[f"past_key.{layer}_RetainedState"]
+        inputs[f"past_value.{layer}"] = qpc_out[f"past_value.{layer}_RetainedState"]
+
+all_outputs.append(np.argmax(qpc_out["logits"]))
+decode_inputs = {
+ "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
+ "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
+}
+for i in range(config.num_hidden_layers):
+ decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+ decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+
+st = time.time()
+decode_out = decode_session.run(decode_inputs)
+print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
+all_outputs.append(np.argmax(decode_out["logits"]))
+pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+loop_decode_inputs = {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+}
+
+for i in range(config.num_hidden_layers):
+ loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+ loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+
+st = time.time()
+for i in range(generation_len - 2):
+ decode_out = decode_session.run(loop_decode_inputs)
+ all_outputs.append(np.argmax(decode_out["logits"]))
+ pos_id += 1
+    for layer in range(config.num_hidden_layers):
+        loop_decode_inputs[f"past_key.{layer}"] = decode_out[f"past_key.{layer}_RetainedState"]
+        loop_decode_inputs[f"past_value.{layer}"] = decode_out[f"past_value.{layer}_RetainedState"]
+
+ loop_decode_inputs.update(
+ {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+ }
+ )
+ft = time.time()
+
+print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
+print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")
diff --git a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py
index 6cc72b715..da9c5b43b 100644
--- a/examples/performance/on_device_sampling.py
+++ b/examples/performance/on_device_sampling.py
@@ -21,6 +21,7 @@ def main(args, **kwargs):
include_sampler = None
return_pdfs = None
max_top_k_ids = None
+ include_guided_decoding = None
sampling_params = None
bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size
if args.override_qaic_config is not None:
@@ -28,6 +29,8 @@ def main(args, **kwargs):
if include_sampler is not None:
return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true"
max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
+ np.random.seed(int(args.random_number))
+ include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true"
sampling_params = {
"repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
"presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -36,7 +39,9 @@ def main(args, **kwargs):
"top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
"top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
"min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
- "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1),
+ "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype(
+ np.float32
+ ),
}
qaic_config = {
k: v
@@ -44,13 +49,12 @@ def main(args, **kwargs):
"include_sampler": include_sampler,
"return_pdfs": return_pdfs,
"max_top_k_ids": max_top_k_ids,
+ "include_guided_decoding": include_guided_decoding,
}.items()
if v is not None
}
print("qaic_config:")
pprint(qaic_config)
- print("sampling_params:")
- pprint(sampling_params)
# Load model with On Device Sampler enabled
qeff_model = AutoModelForCausalLM.from_pretrained(
@@ -60,6 +64,19 @@ def main(args, **kwargs):
)
print(f"{args.model_name} optimized for AI 100 \n", qeff_model)
+ if include_guided_decoding:
+ # Ideally this should come from a logits processor like xgrammar, but for the sake of the
+ # example, we generate a random bitmask
+ sampling_params.update(
+ {
+ "token_bitmasks": np.tile(
+ np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1)
+ )
+ }
+ )
+ print("sampling_params:")
+ pprint(sampling_params)
+
# Compile the model for inference
generated_qpc_path = qeff_model.compile(
prefill_seq_len=args.prompt_len,
@@ -88,6 +105,7 @@ def main(args, **kwargs):
generation_len=args.generation_len,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
@@ -106,14 +124,14 @@ def main(args, **kwargs):
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
- --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
- --top-k 54720 \
+ --top-k 54 \
--top-p 0.89 \
--min-p 0.6 \
- --random-number 0.26
+ --random-number 26
2. For non-continuous batching:
python3.10 examples/on_device_sampling.py \
@@ -126,14 +144,34 @@ def main(args, **kwargs):
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
- --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
+ --repetition-penalty 1.9 \
+ --presence-penalty 0.8 \
+ --temperature 0.67 \
+ --top-k 54 \
+ --top-p 0.89 \
+ --min-p 0.6 \
+ --random-number 26
+
+ 3. With guided decoding:
+ python3.10 examples/on_device_sampling.py \
+ --model-name 'meta-llama/Llama-3.1-8B' \
+ --prompt-len 128 \
+ --ctx-len 256 \
+ --generation-len 20 \
+ --full-batch-size 2 \
+ --device-group [0,1,2,3] \
+ --num-cores 16 \
+ --mxint8-kv-cache \
+ --mxfp6-matmul \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
- --top-k 54720 \
+ --top-k 54 \
--top-p 0.89 \
--min-p 0.6 \
- --random-number 0.26
+ --random-number 26
"""
parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling")
diff --git a/pyproject.toml b/pyproject.toml
index 8e179ab4a..fe0c42ec2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,9 +20,10 @@ classifiers = [
requires-python = ">=3.8,<3.11"
dependencies = [
"transformers==4.55.0",
+ "diffusers== 0.35.1",
"huggingface-hub==0.34.0",
"hf_transfer==0.1.9",
- "peft==0.13.2",
+ "peft==0.17.0",
"datasets==2.20.0",
"fsspec==2023.6.0",
"multidict==6.0.4",
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index 134770638..e9925dee2 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -22,6 +22,7 @@ pipeline {
. preflight_qeff/bin/activate &&
pip install --upgrade pip setuptools &&
pip install .[test] &&
+ pip install .[diffusers] &&
pip install junitparser pytest-xdist &&
pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing
pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs
@@ -41,7 +42,7 @@ pipeline {
mkdir -p $PWD/Non_cli_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic &&
- pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml &&
+ pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml &&
junitparser merge tests/tests_log1.xml tests/tests_log.xml &&
deactivate"
'''
@@ -69,7 +70,7 @@ pipeline {
}
stage('QAIC MultiModal Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
@@ -86,7 +87,7 @@ pipeline {
}
stage('Inference Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
#source /qnn_sdk/bin/envsetup.sh &&
@@ -162,12 +163,13 @@ pipeline {
// }
stage('Finetune CLI Tests') {
steps {
- timeout(time: 5, unit: 'MINUTES') {
+ timeout(time: 20, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
. preflight_qeff/bin/activate &&
pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl &&
+ pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cpu &&
mkdir -p $PWD/cli_qaic_finetuning &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/cli_qaic_finetuning &&
diff --git a/scripts/finetune/run_ft_model.py b/scripts/finetune/run_ft_model.py
index f5b64e717..97ffb80cb 100644
--- a/scripts/finetune/run_ft_model.py
+++ b/scripts/finetune/run_ft_model.py
@@ -14,7 +14,9 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("FT", loglevel="INFO")
# Suppress all warnings
warnings.filterwarnings("ignore")
diff --git a/scripts/perplexity_computation/calculate_perplexity.py b/scripts/perplexity_computation/calculate_perplexity.py
index e2988a0ae..04ce624ec 100644
--- a/scripts/perplexity_computation/calculate_perplexity.py
+++ b/scripts/perplexity_computation/calculate_perplexity.py
@@ -18,8 +18,9 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.logging_utils import QEFFLogger
-logger = logging.getLogger(__name__)
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
# 1. Data Loading
diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py
index d1b7a4653..f63b18f1a 100644
--- a/tests/base/test_export_memory_offload.py
+++ b/tests/base/test_export_memory_offload.py
@@ -27,7 +27,7 @@
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/conftest.py b/tests/conftest.py
index ba0f341fe..b5037b7d0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,9 +13,11 @@
from transformers import AutoConfig
from QEfficient.utils.constants import QEFF_MODELS_DIR
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
from QEfficient.utils.test_utils import ModelConfig
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
+
def get_custom_model_config_dict(configs):
"""
diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py
new file mode 100644
index 000000000..305116c03
--- /dev/null
+++ b/tests/diffusers/diffusers_utils.py
@@ -0,0 +1,175 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+Common utilities for diffusion pipeline testing.
+Provides essential functions for MAD validation, image validation,
+hash verification, and other testing utilities.
+"""
+
+import os
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+
+class DiffusersTestUtils:
+ """Essential utilities for diffusion pipeline testing"""
+
+ @staticmethod
+ def validate_image_generation(
+ image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0
+ ) -> Dict[str, Any]:
+ """
+ Validate generated image properties.
+ Args:
+ image: Generated PIL Image
+ expected_size: Expected (width, height) tuple
+ min_variance: Minimum pixel variance to ensure image is not blank
+
+ Returns:
+ Dict containing validation results
+ Raises:
+ AssertionError: If image validation fails
+ """
+ # Basic image validation
+ assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}"
+ assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}"
+ assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}"
+
+ # Variance check (ensure image is not blank)
+ img_array = np.array(image)
+ image_variance = float(img_array.std())
+ assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})"
+
+ return {
+ "size": image.size,
+ "mode": image.mode,
+ "variance": image_variance,
+ "mean_pixel_value": float(img_array.mean()),
+ "min_pixel": int(img_array.min()),
+ "max_pixel": int(img_array.max()),
+ "valid": True,
+ }
+
+ @staticmethod
+ def check_file_exists(file_path: str, file_type: str = "file") -> bool:
+ """
+ Check if file exists and log result.
+ Args:
+ file_path: Path to check
+ file_type: Description of file type for logging
+ Returns:
+ bool: True if file exists
+ """
+ exists = os.path.exists(file_path)
+ status = "✅" if exists else "❌"
+ print(f"{status} {file_type}: {file_path}")
+ return exists
+
+ @staticmethod
+ def print_test_header(title: str, config: Dict[str, Any]) -> None:
+ """
+ Print formatted test header with configuration details.
+
+ Args:
+ title: Test title
+ config: Test configuration dictionary
+ """
+ print(f"\n{'=' * 80}")
+ print(f"{title}")
+ print(f"{'=' * 80}")
+
+ if "model_setup" in config:
+ setup = config["model_setup"]
+ for k, v in setup.items():
+ print(f"{k} : {v}")
+
+ if "functional_testing" in config:
+ func = config["functional_testing"]
+ print(f"Test Prompt: {func.get('test_prompt', 'N/A')}")
+ print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}")
+ print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}")
+
+ print(f"{'=' * 80}")
+
+
+class MADValidator:
+ """Specialized class for MAD validation - always enabled, always reports, always fails on exceed"""
+
+ def __init__(self, tolerances: Optional[Dict[str, float]] = None):
+ """
+ Initialize MAD validator.
+ MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded.
+
+ Args:
+ tolerances: Dictionary of module_name -> tolerance mappings
+ """
+ self.tolerances = tolerances or {}
+ self.results = {}
+
+ def calculate_mad(
+ self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray]
+ ) -> float:
+ """
+ Calculate Max Absolute Deviation between two tensors.
+
+ Args:
+ tensor1: First tensor (PyTorch or NumPy)
+ tensor2: Second tensor (PyTorch or NumPy)
+
+ Returns:
+ float: Maximum absolute difference between tensors
+ """
+ if isinstance(tensor1, torch.Tensor):
+ tensor1 = tensor1.detach().numpy()
+ if isinstance(tensor2, torch.Tensor):
+ tensor2 = tensor2.detach().numpy()
+
+ return float(np.max(np.abs(tensor1 - tensor2)))
+
+ def validate_module_mad(
+ self,
+ pytorch_output: Union[torch.Tensor, np.ndarray],
+ qaic_output: Union[torch.Tensor, np.ndarray],
+ module_name: str,
+ step_info: str = "",
+ ) -> bool:
+ """
+ Validate MAD for a specific module.
+ Always validates, always reports, always fails if tolerance exceeded.
+
+ Args:
+ pytorch_output: PyTorch reference output
+ qaic_output: QAIC inference output
+ module_name: Name of the module
+ step_info: Additional step information for logging
+
+ Returns:
+ bool: True if validation passed
+
+ Raises:
+ AssertionError: If MAD exceeds tolerance
+ """
+ mad_value = self.calculate_mad(pytorch_output, qaic_output)
+
+ # Always report MAD value
+ step_str = f" {step_info}" if step_info else ""
+ print(f"π {module_name.upper()} MAD{step_str}: {mad_value:.8f}")
+
+ # Always validate - fail if tolerance exceeded
+ tolerance = self.tolerances.get(module_name, 1e-2)
+ if mad_value > tolerance:
+ raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}")
+
+ # Store result
+ if module_name not in self.results:
+ self.results[module_name] = []
+ self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance})
+ return True
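
Note: a minimal, self-contained sketch of how the two helpers above are intended to be used together. The tensor shapes, tolerance value, and image size are illustrative only (not taken from the test config), and it assumes the repository root is on sys.path so that tests.diffusers.diffusers_utils is importable.

import numpy as np
import torch
from PIL import Image

from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator

# MAD check: a PyTorch reference output vs. a slightly perturbed "device" output.
validator = MADValidator(tolerances={"transformer": 1e-2})
reference = torch.randn(1, 16, 64)
candidate = reference.numpy() + np.float32(1e-3)  # deviation well within tolerance
validator.validate_module_mad(reference, candidate, "transformer", step_info="step 0")

# Image check: a random RGB image is non-blank and has the expected size.
img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
stats = DiffusersTestUtils.validate_image_generation(img, expected_size=(256, 256))
print(stats["variance"], stats["mean_pixel_value"])
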
diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json
new file mode 100644
index 000000000..7d0c17d55
--- /dev/null
+++ b/tests/diffusers/flux_test_config.json
@@ -0,0 +1,123 @@
+{
+ "model_setup": {
+ "height": 256,
+ "width": 256,
+ "num_transformer_layers": 2,
+ "num_single_layers": 2,
+ "use_onnx_subfunctions": false
+ },
+ "mad_validation": {
+ "tolerances": {
+ "clip_text_encoder": 0.1,
+ "t5_text_encoder": 5.5,
+ "transformer": 2.0,
+ "vae_decoder": 1.0
+ }
+ },
+ "pipeline_params": {
+ "test_prompt": "A cat holding a sign that says hello world",
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "validate_gen_img": true,
+ "min_image_variance": 1.0,
+ "custom_config_path": null
+ },
+ "validation_checks": {
+ "image_generation": true,
+ "onnx_export": true,
+ "compilation": true
+ },
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "aic-enable-depth-first": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+
+}
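
Note: for orientation, a short sketch of how this config file is consumed. The test below loads it with load_json and also passes its path to pipeline.compile(compile_config=...); the snippet assumes it is run from the repository root with QEfficient installed.

from QEfficient.utils._utils import load_json

cfg = load_json("tests/diffusers/flux_test_config.json")

# Per-module MAD tolerances handed to MADValidator.
tolerances = cfg["mad_validation"]["tolerances"]              # {"clip_text_encoder": 0.1, ...}

# Per-module compile options, e.g. for the transformer.
transformer_compile = cfg["modules"]["transformer"]["compilation"]
print(tolerances["transformer"], transformer_compile["aic_num_cores"])
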
diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py
new file mode 100644
index 000000000..6f4396a20
--- /dev/null
+++ b/tests/diffusers/test_flux.py
@@ -0,0 +1,448 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import pytest
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+from QEfficient import QEffFluxPipeline
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ModulePerf,
+ QEffPipelineOutput,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils._utils import load_json
+from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator
+
+# Test configuration for 256x256 resolution with 2 transformer layers (TODO: update MAD tolerances)
+CONFIG_PATH = "tests/diffusers/flux_test_config.json"
+INITIAL_TEST_CONFIG = load_json(CONFIG_PATH)
+
+
+def flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height: int = 256,
+ width: int = 256,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ mad_tolerances: Optional[Dict[str, float]] = None,
+):
+ """
+ Pipeline call function that replicates the exact flow of QEffFluxPipeline.__call__() in pipeline_flux.py
+ while adding comprehensive MAD validation at each step.
+
+ This function follows the EXACT same structure as QEffFluxPipeline.__call__()
+ but adds MAD validation hooks throughout the process.
+ """
+ # Initialize MAD validator
+ mad_validator = MADValidator(tolerances=mad_tolerances)
+
+ device = "cpu"
+
+ # Step 1: Load configuration, compile models
+ pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(pipeline)
+
+ # Validate all inputs
+ pipeline.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Set pipeline attributes
+ pipeline._guidance_scale = guidance_scale
+ pipeline._interrupt = False
+ batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"]
+
+ # Step 3: Encode prompts with both text encoders
+ # Use pipeline's encode_prompt method
+ (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ # Deactivate text encoder qpc sessions
+ pipeline.text_encoder.qpc_session.deactivate()
+ pipeline.text_encoder_2.qpc_session.deactivate()
+
+ # MAD Validation for Text Encoders
+ print("π Performing MAD validation for text encoders...")
+ mad_validator.validate_module_mad(
+ clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder"
+ )
+ mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder")
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
+ pipeline._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = pipeline.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = pipeline.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ t5_qaic_prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Initialize transformer inference session
+ if pipeline.transformer.qpc_session is None:
+ pipeline.transformer.qpc_session = QAICInferenceSession(
+ str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids
+ )
+
+ # Calculate compressed latent dimension (cl) for transformer buffer allocation
+ from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension
+
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor)
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, pipeline.transformer.model.config.in_channels).astype(np.float32),
+ }
+ pipeline.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ pipeline.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if pipeline._interrupt:
+ continue
+
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds)
+
+ # Compute AdaLN embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.transformer_blocks)):
+ block = pipeline.transformer.model.transformer_blocks[block_idx]
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)):
+ block = pipeline.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = pipeline.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(),
+ "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # MAD Validation for Transformer - PyTorch reference inference
+ noise_pred_torch = pytorch_pipeline.transformer(
+ hidden_states=latents,
+ encoder_hidden_states=t5_torch_prompt_embeds,
+ pooled_projections=clip_torch_pooled_prompt_embeds,
+ timestep=torch.tensor(timestep),
+ img_ids=latent_image_ids,
+ txt_ids=text_ids,
+ return_dict=False,
+ )[0]
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.time()
+ outputs = pipeline.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.time()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Transformer MAD validation
+ mad_validator.validate_module_mad(
+ noise_pred_torch.detach().cpu().numpy(),
+ outputs["output"],
+ "transformer",
+ f"step {i} (t={t.item():.1f})",
+ )
+
+ # Update latents using scheduler
+ latents_dtype = latents.dtype
+ latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images
+ if output_type == "latent":
+ image = latents
+ vae_decode_perf = 0.0 # No VAE decoding for latent output
+ else:
+ # Unpack and denormalize latents
+ latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor)
+
+ # Denormalize latents
+ latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor
+ # Initialize VAE decoder inference session
+ if pipeline.vae_decode.qpc_session is None:
+ pipeline.vae_decode.qpc_session = QAICInferenceSession(
+ str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)}
+ pipeline.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # MAD Validation for VAE
+ # PyTorch reference inference
+ image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0]
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.time()
+ image = pipeline.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.time()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # VAE MAD validation
+ mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder")
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
+
+
+@pytest.fixture(scope="session")
+def flux_pipeline():
+ """Setup compiled Flux pipeline for testing"""
+ config = INITIAL_TEST_CONFIG["model_setup"]
+
+ pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ use_onnx_subfunctions=config["use_onnx_subfunctions"],
+ )
+
+ # Reduce to 2 layers for testing
+ original_blocks = pipeline.transformer.model.transformer_blocks
+ org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+
+ pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"]
+ pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"]
+ pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
+
+ ### Pytorch pipeline
+ pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks
+ org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks
+ pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
+ return pipeline, pytorch_pipeline
+
+
+@pytest.mark.diffusion_models
+@pytest.mark.on_qaic
+def test_flux_pipeline(flux_pipeline):
+ """
+ Comprehensive Flux pipeline test that follows the exact same flow as pipeline_flux.py:
+ - 256x256 resolution - 2 transformer layers
+ - MAD validation
+ - Functional image generation test
+ - Export/compilation checks
+ - Returns QEffPipelineOutput with performance metrics
+ """
+ pipeline, pytorch_pipeline = flux_pipeline
+ config = INITIAL_TEST_CONFIG
+
+ # Print test header
+ DiffusersTestUtils.print_test_header(
+ f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers",
+ config,
+ )
+
+ # Test parameters
+ test_prompt = config["pipeline_params"]["test_prompt"]
+ num_inference_steps = config["pipeline_params"]["num_inference_steps"]
+ guidance_scale = config["pipeline_params"]["guidance_scale"]
+ max_sequence_length = config["pipeline_params"]["max_sequence_length"]
+
+ # Generate with MAD validation
+ generator = torch.manual_seed(42)
+ start_time = time.time()
+
+ try:
+ # Run the pipeline with integrated MAD validation (follows exact pipeline flow)
+ result = flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height=config["model_setup"]["height"],
+ width=config["model_setup"]["width"],
+ prompt=test_prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ max_sequence_length=max_sequence_length,
+ custom_config_path=CONFIG_PATH,
+ generator=generator,
+ mad_tolerances=config["mad_validation"]["tolerances"],
+ parallel_compile=True,
+ return_dict=True,
+ )
+
+ execution_time = time.time() - start_time
+
+ # Validate image generation
+ if config["pipeline_params"]["validate_gen_img"]:
+ assert result is not None, "Pipeline returned None"
+ assert hasattr(result, "images"), "Result missing 'images' attribute"
+ assert len(result.images) > 0, "No images generated"
+
+ generated_image = result.images[0]
+ expected_size = (config["model_setup"]["width"], config["model_setup"]["height"])  # PIL Image.size is (width, height)
+ # Validate image properties using utilities
+ image_validation = DiffusersTestUtils.validate_image_generation(
+ generated_image, expected_size, config["pipeline_params"]["min_image_variance"]
+ )
+
+ print("\nβ
IMAGE VALIDATION PASSED")
+ print(f" - Size: {image_validation['size']}")
+ print(f" - Mode: {image_validation['mode']}")
+ print(f" - Variance: {image_validation['variance']:.2f}")
+ print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}")
+ file_path = "test_flux_256x256_2layers.png"
+ # Save test image
+ generated_image.save(file_path)
+
+ if os.path.exists(file_path):
+ print(f"Image saved successfully at: {file_path}")
+ else:
+ print("Image was not saved.")
+
+ if config["validation_checks"]["onnx_export"]:
+ # Check if ONNX files exist (basic check)
+ print("\nπ ONNX Export Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX")
+
+ if config["validation_checks"]["compilation"]:
+ # Check if QPC files exist (basic check)
+ print("\nπ Compilation Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "qpc_path") and module_obj.qpc_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC")
+
+ # Print test summary using utilities
+ print(f"\nTotal execution time: {execution_time:.4f}s")
+ except Exception as e:
+ print(f"\nTEST FAILED: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ # This allows running the test file directly for debugging
+ pytest.main([__file__, "-v", "-s", "-m", "diffusion_models"])
+# pytest tests/diffusers/test_flux.py -m diffusion_models -v -s --tb=short
diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py
index 00a4216b7..46b33c60b 100644
--- a/tests/peft/lora/test_lora_model.py
+++ b/tests/peft/lora/test_lora_model.py
@@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
# export
start = perf_counter()
- qeff_model.export(export_dir=tmp_path)
+ onnx_path = qeff_model.export(export_dir=tmp_path)
end = perf_counter()
export_time_0 = end - start
model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash)
@@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
assert export_time_1 < export_time_0
# test compile
- qeff_model.compile(prefill_seq_len=32, ctx_len=64)
+ qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64)
assert Path(qeff_model.qpc_path).is_dir()
assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py
index cc94467db..c3bb2f140 100644
--- a/tests/peft/test_peft_model.py
+++ b/tests/peft/test_peft_model.py
@@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path):
_, lora_model = create_peft_model(base_config, adapter_config)
qeff_model = QEffAutoPeftModelForCausalLM(lora_model)
- qeff_model.export(tmp_path)
+ onnx_path = qeff_model.export(tmp_path)
start = perf_counter()
- qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
+ qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_0 = end - start
@@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con
)
start = perf_counter()
- qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
+ qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_1 = end - start
assert compile_time_1 < 0.01 * compile_time_0
diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py
new file mode 100644
index 000000000..6358940df
--- /dev/null
+++ b/tests/transformers/models/test_disagg_mode.py
@@ -0,0 +1,192 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers
+
+model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32
+
+prompt2 = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+prompt1 = "Once upon a time"
+
+prompts = [prompt1, prompt2]
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_id", [model_id])
+@pytest.mark.parametrize("prompt", prompts)
+def test_disagg_mode_prefill(model_id, prompt):
+ # Run prefill
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ PREFILL_SEQ_LEN = 256
+ CTX_LEN = 256
+ inputs = tokenizer(prompt, return_tensors="np", padding=True)
+ padded_len = inputs["input_ids"].shape[1]
+ num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+ padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+
+ replace_transformers_quantizers()
+ model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ config = model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()}
+ cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN)
+ ins = tokenizer(prompt, return_tensors="pt")
+ out = model(**ins, past_key_values=cache)
+
+ undo_transformers_quantizers()
+
+ qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ qeff_model.prefill(True)
+ config = qeff_model.model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+ past_key_values = []
+ for i in range(config.num_hidden_layers):
+ cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+ inputs["past_key_values"] = past_key_values
+
+ qeff_out = qeff_model.model(**inputs)
+
+ # Check our pytorch implementation
+ assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4
+
+ prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ )
+
+ prefill_session = QAICInferenceSession(prefill_qpc_path)
+ logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+ prefill_session.set_buffers({"logits": logits_out_placeholder})
+ inputs.pop("past_key_values")
+ inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+ st = time.time()
+ qpc_out = prefill_session.run(inputs)
+ print(f"time for prefill_run={time.time() - st} sec\n")
+ del prefill_session
+ # Check that the QAIC output is close to the QEfficient PyTorch output
+ assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2
+
+
+@pytest.mark.skip(reason="no way of currently testing this without the assert sdk")
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_id", [model_id])
+@pytest.mark.parametrize("prompt", prompts)
+def test_disagg_mode_prefill_chunked(model_id, prompt):
+ # Run prefill
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ PREFILL_SEQ_LEN = 128
+ CTX_LEN = 128 * 3
+ inputs = tokenizer(prompt, return_tensors="np", padding=True)
+ padded_len = inputs["input_ids"].shape[1]
+ num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+ padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+
+ replace_transformers_quantizers()
+ model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ config = model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()}
+ cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN)
+ ins = tokenizer(prompt, return_tensors="pt")
+ out = model(**ins, past_key_values=cache)
+
+ undo_transformers_quantizers()
+
+ qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ qeff_model.prefill(True, enable_chunking=True)
+ config = qeff_model.model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+ past_key_values = []
+ for i in range(config.num_hidden_layers):
+ cache_len = CTX_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+ inputs["past_key_values"] = past_key_values
+
+ for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+
+ qeff_out = qeff_model.model(**chunk_inputs)
+ inputs["past_key_values"] = qeff_out["past_key_values"]
+
+ # Check the QEfficient PyTorch implementation against the HF reference output
+ assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4
+
+ prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ enable_chunking=True,
+ )
+ prefill_session = QAICInferenceSession(prefill_qpc_path)
+ prefill_session.skip_buffers(
+ [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
+ )
+ logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+ prefill_session.set_buffers({"logits": logits_out_placeholder})
+ inputs.pop("past_key_values")
+ inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+ st = time.time()
+ for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ qpc_out = prefill_session.run(chunk_inputs)
+ print(f"time for prefill_run={time.time() - st} sec\n")
+ del prefill_session
+ # Check that the QAIC output is close to the QEfficient PyTorch output
+ assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2
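
Note: both tests above pad the prompt to a multiple of PREFILL_SEQ_LEN using -(padded_len // -PREFILL_SEQ_LEN), which is integer ceiling division. A tiny standalone check of the idiom (the lengths are arbitrary):

PREFILL_SEQ_LEN = 256
for length in (1, 255, 256, 257, 700):
    num_chunks = -(length // -PREFILL_SEQ_LEN)  # ceil(length / PREFILL_SEQ_LEN) without floats
    assert num_chunks == (length + PREFILL_SEQ_LEN - 1) // PREFILL_SEQ_LEN
    print(length, "->", num_chunks * PREFILL_SEQ_LEN)  # padded length, a multiple of PREFILL_SEQ_LEN
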
diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py
index 9335e1d91..26cb6fda9 100644
--- a/tests/transformers/sampler/test_sampler.py
+++ b/tests/transformers/sampler/test_sampler.py
@@ -5,15 +5,18 @@
#
# -----------------------------------------------------------------------------
-from typing import List
+from typing import List, Optional, Tuple, Union
import numpy as np
import pytest
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
-from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import load_hf_tokenizer
from QEfficient.utils.constants import Constants
+from QEfficient.utils.test_utils import InternProcessor
+from tests.transformers.models.image_text_to_text.test_continuous_batching import set_num_layers
sampler_transform_configs = [
pytest.param(
@@ -24,6 +27,20 @@
20, # generation_len
2, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
greedy_sampling_configs = [
@@ -35,6 +52,20 @@
20, # generation_len
4, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
random_sampling_configs = [
@@ -46,23 +77,98 @@
20, # generation_len
4, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 4,
+ ["Can you describe the image in detail."] * 4,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 4, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
+ ),
+]
+guided_decoding_configs = [
+ pytest.param(
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model
+ Constants.INPUT_STR * 4, # prompts
+ 32, # prefill_seq_len
+ 64, # ctx_len
+ 20, # generation_len
+ 4, # full_batch_size
+ 1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
+def prepare_model_setup(
+ model: str, is_vlm: bool, num_hidden_layers: int, prompts: Union[List, Tuple], spec_length: Optional[int]
+):
+ additional_configs = {}
+ additional_params = {}
+ if is_vlm:
+ config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+ config = set_num_layers(config, n_layer=num_hidden_layers)
+ additional_configs["config"] = config
+ additional_configs["kv_offload"] = True
+ assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided."
+ additional_params["images"] = prompts[0]
+ prompts = prompts[1]
+
+ if "InternVL" in model:
+ additional_configs["trust_remote_code"] = True
+ model_hf = AutoModelForCausalLM.from_pretrained(
+ model,
+ config=config,
+ trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
+ additional_params["processor"] = InternProcessor(model_hf, tokenizer)
+ qeff_class = QEFFAutoModelForCausalLM
+ else:
+ additional_params["processor"] = AutoProcessor.from_pretrained(model)
+ qeff_class = QEFFAutoModelForImageTextToText
+ else:
+ if num_hidden_layers != -1:
+ additional_configs["num_hidden_layers"] = num_hidden_layers
+ spec_length = (spec_length or 1) - 1
+ qeff_class = QEFFAutoModelForCausalLM
+ return additional_configs, additional_params, prompts, spec_length, qeff_class
+
+
@pytest.mark.on_qaic
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
sampler_transform_configs,
)
def test_sampler_transform(
model: str,
- prompts: List[str],
+ prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the
@@ -70,48 +176,78 @@ def test_sampler_transform(
next tokens and/or probability distributions.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = 2
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=2,
qaic_config={
"include_sampler": True,
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_w_sampler_w_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 512,
+ "include_guided_decoding": True,
+ },
+ **additional_configs,
+ )
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=2,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
+ )
+ model_w_sampler_qpc_path = model_w_sampler.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
)
- model_w_sampler_qpc_path: str = model_w_sampler.compile(
+ model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding.compile(
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
- model_wo_sampler_qpc_path: str = model_wo_sampler.compile(
+ model_wo_sampler_qpc_path = model_wo_sampler.compile(
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
+ if is_vlm:
+ model_w_sampler_qpc_path = model_w_sampler_qpc_path[1]
+ model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1]
+ model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1]
# Init qaic session
model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path)
+ model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path)
model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path)
# Skip inputs/outputs buffers
@@ -119,6 +255,12 @@ def test_sampler_transform(
model_w_sampler_session.skip_buffers(
set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")])
)
+ model_w_sampler_w_guided_decoding_session.skip_buffers(
+ set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")])
+ )
+ model_w_sampler_w_guided_decoding_session.skip_buffers(
+ set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")])
+ )
model_wo_sampler_session.skip_buffers(
set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")])
)
@@ -132,47 +274,58 @@ def test_sampler_transform(
assert input_name in model_w_sampler_session.input_names, (
f"Sampler input {input_name} not found in QPC compiled with On Device Sampler"
)
+ assert input_name in model_w_sampler_w_guided_decoding_session.input_names, (
+ f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding"
+ )
assert input_name not in model_wo_sampler_session.input_names, (
f"Sampler input {input_name} found in QPC compiled without On Device Sampler"
)
+ assert "token_bitmasks" in model_w_sampler_w_guided_decoding_session.input_names, (
+ "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding"
+ )
@pytest.mark.on_qaic
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
greedy_sampling_configs,
)
def test_greedy_sampling(
model: str,
- prompts: List[str],
+ prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
- Test greedy sampling with QPC compiled with and without On Device Sampling.
+ Test greedy sampling with QPCs compiled with and without On Device Sampling.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = 4
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=4,
qaic_config={
"include_sampler": True,
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=4,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
)
model_w_sampler.compile(
prefill_seq_len=prefill_seq_len,
@@ -180,7 +333,7 @@ def test_greedy_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -190,7 +343,7 @@ def test_greedy_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -211,8 +364,9 @@ def test_greedy_sampling(
"top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
"top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32),
},
+ **additional_params,
)
model_wo_sampler_exec_info = model_wo_sampler.generate(
tokenizer=tokenizer,
@@ -221,6 +375,7 @@ def test_greedy_sampling(
include_sampler=False,
return_pdfs=False,
sampling_params=None,
+ **additional_params,
)
# Compare generated texts and ids
@@ -233,25 +388,29 @@ def test_greedy_sampling(
@pytest.mark.on_qaic
-@pytest.mark.skip
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
random_sampling_configs,
)
def test_random_sampling(
model: str,
- prompts: List[str],
+ prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
- Test random sampling with QPC compiled with and without On Device Sampling.
+ Test random sampling with QPCs compiled with and without On Device Sampling.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = -1
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
qaic_config={
@@ -259,14 +418,16 @@ def test_random_sampling(
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
)
model_w_sampler.compile(
prefill_seq_len=prefill_seq_len,
@@ -274,7 +435,7 @@ def test_random_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -284,13 +445,14 @@ def test_random_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
# Generate texts from prompts
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
+ np.random.seed(0)
model_w_sampler_exec_info = model_w_sampler.generate(
tokenizer=tokenizer,
prompts=prompts,
@@ -301,12 +463,15 @@ def test_random_sampling(
"repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
# "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
+ "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
"top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype(
+ np.float32
+ ),
},
+ **additional_params,
)
model_wo_sampler_exec_info = model_wo_sampler.generate(
tokenizer=tokenizer,
@@ -315,63 +480,120 @@ def test_random_sampling(
include_sampler=False,
return_pdfs=False,
sampling_params=None,
+ **additional_params,
)
# Compare generated texts
- golden_texts = {
- "w_sampler": "Raymond and my favorite color, alongside reds or purples (I canβt have them both",
- "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
- }
- golden_ids = {
- "w_sampler": [
- [
- 21380,
- 322,
- 590,
- 25448,
- 2927,
- 29892,
- 19963,
- 2654,
- 29879,
- 470,
- 3708,
- 2701,
- 313,
- 29902,
- 508,
- 30010,
- 29873,
- 505,
- 963,
- 1716,
- ]
- ],
- "wo_sampler": [
- [
- 2259,
- 7075,
- 322,
- 306,
- 626,
- 263,
- 7047,
- 22055,
- 29889,
- 306,
- 505,
- 1063,
- 1985,
- 297,
- 278,
- 13661,
- 363,
- 278,
- 4940,
- 29871,
- ]
- ],
- }
+ if model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+ golden_texts = {
+ "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over",
+ "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
+ }
+ golden_ids = {
+ "w_sampler": [
+ [
+ 319,
+ 3615,
+ 322,
+ 306,
+ 626,
+ 263,
+ 3005,
+ 295,
+ 749,
+ 9227,
+ 1058,
+ 12355,
+ 267,
+ 304,
+ 26987,
+ 278,
+ 3186,
+ 29889,
+ 2973,
+ 975,
+ ]
+ ],
+ "wo_sampler": [
+ [
+ 2259,
+ 7075,
+ 322,
+ 306,
+ 626,
+ 263,
+ 7047,
+ 22055,
+ 29889,
+ 306,
+ 505,
+ 1063,
+ 1985,
+ 297,
+ 278,
+ 13661,
+ 363,
+ 278,
+ 4940,
+ 29871,
+ ]
+ ],
+ }
+ elif model == "OpenGVLab/InternVL2_5-1B":
+ golden_texts = {
+ "w_sampler": "The description of this picture would be as follows:\n\nAn adorable black puppy is sitting on a wooden surface",
+ "wo_sampler": "The image features a black puppy sitting on a wooden surface. The puppy has a shiny, glossy coat",
+ }
+ golden_ids = {
+ "w_sampler": [
+ [
+ 785,
+ 4008,
+ 315,
+ 419,
+ 6802,
+ 1035,
+ 387,
+ 438,
+ 11017,
+ 1447,
+ 2082,
+ 40608,
+ 3691,
+ 41189,
+ 374,
+ 11699,
+ 389,
+ 264,
+ 22360,
+ 7329,
+ ]
+ ],
+ "wo_sampler": [
+ [
+ 785,
+ 2168,
+ 4419,
+ 264,
+ 3691,
+ 41189,
+ 11699,
+ 389,
+ 264,
+ 22360,
+ 7329,
+ 13,
+ 576,
+ 41189,
+ 702,
+ 264,
+ 41199,
+ 11,
+ 73056,
+ 22875,
+ ]
+ ],
+ }
for i in range(full_batch_size):
assert (
tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"]
@@ -385,3 +607,118 @@ def test_random_sampling(
assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), (
"Without sampler generated ids do not match"
)
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize(
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
+ guided_decoding_configs,
+)
+def test_guided_decoding(
+ model: str,
+ prompts: Union[List[str], Tuple[List[str], List[str]]],
+ prefill_seq_len: int,
+ ctx_len: int,
+ generation_len: int,
+ full_batch_size: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
+):
+ """
+ Test QPCs compiled with and without guided decoding.
+ """
+ # Export and compile QEfficient models
+ num_hidden_layers = 2
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler_w_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 1024,
+ "include_guided_decoding": True,
+ },
+ **additional_configs,
+ )
+ model_w_sampler_wo_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 1024,
+ },
+ **additional_configs,
+ )
+ model_w_sampler_w_guided_decoding.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
+ )
+ model_w_sampler_wo_guided_decoding.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
+ )
+
+ # Generate texts from prompts
+ tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
+ np.random.seed(0)
+ sampling_params = {
+ "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32),
+ }
+ if is_vlm:
+ vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size
+ else:
+ vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size
+ model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate(
+ tokenizer=tokenizer,
+ prompts=prompts,
+ generation_len=generation_len,
+ include_sampler=True,
+ return_pdfs=False,
+ include_guided_decoding=True,
+ sampling_params={
+ **sampling_params,
+ **{
+ "token_bitmasks": np.tile(
+ np.random.choice([True, False], size=(vocab_size,)),
+ (full_batch_size, 1),
+ )
+ },
+ },
+ **additional_params,
+ )
+ model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate(
+ tokenizer=tokenizer,
+ prompts=prompts,
+ generation_len=generation_len,
+ include_sampler=True,
+ return_pdfs=False,
+ sampling_params=sampling_params,
+ **additional_params,
+ )
+ assert (
+ model_w_sampler_w_guided_decoding_exec_info.generated_ids
+ != model_w_sampler_wo_guided_decoding_exec_info.generated_ids
+ ).any(), "Sampler outputs with and without guided decoding should not match"
diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py
index 0810ac6ba..72477d56a 100644
--- a/tests/transformers/test_causal_lm.py
+++ b/tests/transformers/test_causal_lm.py
@@ -14,10 +14,11 @@
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import constants, get_padding_shape_from_config
from QEfficient.utils.hash_utils import hash_dict_params
-configs = [
+test_configs = [
# name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
("gpt2", 256, 2, 4, 128, 512, 127, {}),
("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
@@ -36,30 +37,43 @@
("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}),
]
-configs = [
- AutoConfig.for_model(
- model_name,
- max_position_embeddings=max_position_embeddings,
- num_hidden_layers=num_hidden_layers,
- num_attention_heads=num_attention_heads,
- hidden_size=hidden_size,
- intermediate_size=intermediate_size,
- vocab_size=vocab_size,
- **additional_params,
- )
- for (
- model_name,
- max_position_embeddings,
- num_hidden_layers,
- num_attention_heads,
- hidden_size,
- intermediate_size,
- vocab_size,
- additional_params,
- ) in configs
+test_prefill_only_specialized_models_configs = [
+ ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}),
]
+
+
+def get_auto_config_from_test_config(configs):
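+ """Expand the test config tuples above into transformers AutoConfig objects."""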
+ auto_configs = [
+ AutoConfig.for_model(
+ model_name,
+ max_position_embeddings=max_position_embeddings,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ vocab_size=vocab_size,
+ **additional_params,
+ )
+ for (
+ model_name,
+ max_position_embeddings,
+ num_hidden_layers,
+ num_attention_heads,
+ hidden_size,
+ intermediate_size,
+ vocab_size,
+ additional_params,
+ ) in configs
+ ]
+ return auto_configs
+
+
+configs = get_auto_config_from_test_config(test_configs)
config_ids = [x.model_type for x in configs]
+prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs)
+prefill_only_config_ids = [x.model_type for x in prefill_only_configs]
+
model_kwargs = {"attn_implementation": "eager"}
@@ -144,20 +158,21 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path):
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_causal_lm_hash_creation(config, cb, tmp_path):
+def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path):
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForCausalLM(model, cb)
- qeff_model.export(tmp_path)
+ qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc)
hash_params = {}
hash_params["config"] = qeff_model.model.config.to_diff_dict()
hash_params["peft_config"] = None
hash_params["applied_transform_names"] = qeff_model._transform_names()
hash_params["qeff_auto_class"] = qeff_model.__class__.__name__
+ hash_params["max_seq_len_cached"] = None
hash_params["qaic_config"] = None
# Create parameters separately for hash creation
-
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
@@ -190,12 +205,12 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):
)
output_names = []
output_names.append("logits")
-
+ onnx_out_name_suffix = "InternalRetainedState" if subfunc else "RetainedState"
for i in range(qeff_model.num_layers):
pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size"
for kv in ["key", "value"]:
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
- output_names.append(f"past_{kv}.{i}_RetainedState")
+ output_names.append(f"past_{kv}.{i}_{onnx_out_name_suffix}")
if qeff_model.continuous_batching:
dynamic_axes["batch_index"] = {0: "batch_size"}
@@ -204,14 +219,35 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):
export_params["output_names"] = output_names
export_params["dynamic_axes"] = dynamic_axes
hash_params["export_params"] = export_params
+ if subfunc:
+ hash_params["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
+
manual_hash = hash_dict_params(hash_params)
assert manual_hash == qeff_model.export_hash
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids)
+def test_prefill_only_specialized_models(config, cb, tmp_path):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ qeff_model = QEFFAutoModelForCausalLM(model, cb)
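+ # Prefill-only export is expected to reject continuous batching and to require an explicit prefill_seq_len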
+ if cb:
+ with pytest.raises(NotImplementedError):
+ qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False)
+ else:
+ with pytest.raises(ValueError):
+ qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False)
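+ # A valid prefill-only export must hash differently from a regular export of the same model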
+ qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False)
+ first_export_hash = qeff_model.export_hash
+ qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False)
+ second_export_hash = qeff_model.export_hash
+ assert first_export_hash != second_export_hash
+
+
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py
index 59281b73b..bc53cb539 100644
--- a/tests/transformers/test_speech_seq2seq.py
+++ b/tests/transformers/test_speech_seq2seq.py
@@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path):
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py
index 36cfc0ce5..47e49cf2c 100644
--- a/tests/transformers/test_subfunction.py
+++ b/tests/transformers/test_subfunction.py
@@ -4,7 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
+from collections import Counter
+import onnx
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -44,24 +46,74 @@
config_ids = [x.model_type for x in configs]
+def has_gpt2block_function(onnx_path):
+ """Check if ONNX model contains QEffGPT2Block function definition."""
+ model = onnx.load(onnx_path, load_external_data=False)
+ function_names = [f.name for f in model.functions]
+ gpt2block_functions = [name for name in function_names if "QEffGPT2Block" in name]
+ return len(gpt2block_functions) > 0, gpt2block_functions
+
+
+def get_gpt2block_call_count(onnx_path):
+ """Get count of QEffGPT2Block function calls in the ONNX model graph."""
+ model = onnx.load(onnx_path, load_external_data=False)
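+ # Calls to an ONNX function appear as graph nodes whose op_type is the function name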
+ calls = Counter([n.op_type for n in model.graph.node])
+ gpt2block_calls = {k: v for k, v in calls.items() if "QEffGPT2Block" in k}
+ return gpt2block_calls
+
+
+@pytest.mark.on_qaic
@pytest.mark.parametrize("config", configs, ids=config_ids)
def test_subfunction_vs_nonsubfunction(config, tmp_path):
tokenizer = AutoTokenizer.from_pretrained(config.model_type)
model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False)
- # model_0_0 = QEFFAutoModelForCausalLM.from_pretrained(config.model_type)
+ # Export with subfunctions enabled
with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False)
- hash_0_0 = model_0_0.export_hash
+ # Export without subfunctions
without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False)
- hash_0_1 = model_0_0.export_hash
- assert hash_0_0 != hash_0_1
+ # Verify that the model with subfunctions has QEffGPT2Block function definition
+ has_gpt2block, gpt2block_names = has_gpt2block_function(with_sub_func_onnx)
+ assert has_gpt2block, (
+ "Model exported with use_onnx_subfunctions=True should contain QEffGPT2Block function definition"
+ )
+ print(f"\nGpt2Block functions found: {gpt2block_names}")
+
+ # Verify that the model without subfunctions has no QEffGPT2Block function definition
+ has_gpt2block_without, _ = has_gpt2block_function(without_sub_func_onnx)
+ assert not has_gpt2block_without, (
+ "Model exported with use_onnx_subfunctions=False should not contain QEffGPT2Block function definition"
+ )
+
+ # Get QEffGPT2Block call counts
+ gpt2block_calls_with_sub = get_gpt2block_call_count(with_sub_func_onnx)
+ gpt2block_calls_without_sub = get_gpt2block_call_count(without_sub_func_onnx)
+ print(f"\nGpt2Block call counts with subfunctions: {gpt2block_calls_with_sub}")
+ print(f"QEffGPT2Block call counts without subfunctions: {gpt2block_calls_without_sub}")
+
+ # Verify that QEffGPT2Block function calls exist in the subfunction model
+ assert len(gpt2block_calls_with_sub) > 0, (
+ "Expected to find QEffGPT2Block function calls in graph when use_onnx_subfunctions=True"
+ )
+
+ # Verify that QEffGPT2Block function calls do NOT exist in the non-subfunction model
+ assert len(gpt2block_calls_without_sub) == 0, (
+ "Expected NO QEffGPT2Block function calls in graph when use_onnx_subfunctions=False"
+ )
+
+ # Compile and test generation to ensure functional equivalence
compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+
model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params)
generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params)
generation_01 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
- assert generation_00.generated_texts == generation_01.generated_texts
+
+ # Verify that both models produce the same output
+ assert generation_00.generated_texts == generation_01.generated_texts, (
+ "Models with and without subfunctions should produce identical outputs"
+ )
diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py
index eb05b3f95..279d49bdf 100644
--- a/tests/transformers/test_transformer_pytorch_transforms.py
+++ b/tests/transformers/test_transformer_pytorch_transforms.py
@@ -20,7 +20,9 @@
from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
from QEfficient.transformers.spd.turbo import ResBlock
from QEfficient.utils._utils import get_padding_shape_from_config
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.logging_utils import QEFFLogger
+
+logger = QEFFLogger.get_logger("INFRA", loglevel="INFO")
KVCacheTransformTestConfigs = [
("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8),
diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py
index fefa73973..b7a5495c6 100644
--- a/tests/utils/test_hash_utils.py
+++ b/tests/utils/test_hash_utils.py
@@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value):
def test_json_serializable():
# Test with a set
- assert json_serializable({1, 2, 3}) == [1, 2, 3]
+ assert json_serializable({1, 2, 3}) == ["1", "2", "3"]
# Test with an unsupported type
with pytest.raises(TypeError):
json_serializable({1, 2, 3, {4, 5}})
diff --git a/tests/utils/test_logger.py b/tests/utils/test_logger.py
new file mode 100644
index 000000000..e6bd51c5c
--- /dev/null
+++ b/tests/utils/test_logger.py
@@ -0,0 +1,48 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import threading
+import time
+
+from QEfficient.utils.logging_utils import QEFFLogger
+
+# -------------------------------
+# Define namespace once
+# -------------------------------
+NAMESPACE = "model"
+
+# -------------------------------
+# Initialize logger
+# -------------------------------
+logger = QEFFLogger.get_logger(NAMESPACE, "DEBUG")
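+# A single logger instance is shared by all worker threads below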
+
+
+# -------------------------------
+# Worker function for threads
+# -------------------------------
+def log_worker(thread_id):
+ for i in range(5):
+ logger.info(f"Thread-{thread_id} logging message {i}")
+ time.sleep(0.1)
+
+
+# -------------------------------
+# Pytest entry point so the module is collected as a test
+# -------------------------------
+def test_concurrent_logging():
+ # Create and start threads
+ threads = []
+ for t_id in range(3):
+ t = threading.Thread(target=log_worker, args=(t_id,))
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ # Graceful shutdown
+ QEFFLogger.close_logger()