diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 7f63b34ca..3c9f68efd 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -6,7 +6,17 @@
# -----------------------------------------------------------------------------
import os
-import warnings
+
+# ----------------------------------------------------------------------------- #
+# For faster downloads via hf_transfer
+# This code is placed above the import statements because it must be executed
+# before hf_transfer is imported (which happens transitively via the imports below).
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# DO NOT ADD ANY CODE ABOVE THIS LINE
+# Please contact maintainers if you must edit this file above this line.
+# ----------------------------------------------------------------------------- #
+import warnings  # noqa: I001
+
+# Placeholder for all non-transformer models registered in QEfficient
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.base import (
@@ -18,6 +28,8 @@
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
+from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
+from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -25,6 +37,10 @@
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
+# Custom warning format for a better logging experience
+warnings.formatwarning = custom_format_warning
+
+
# Users can use QEfficient.export for exporting models to ONNX
export = qualcomm_efficient_converter
__all__ = [
@@ -39,15 +55,10 @@
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
+ "QEffFluxPipeline",
+ "QEffWanPipeline",
]
-# For faster downloads via hf_transfer
-# This code is put above import statements as this needs to be executed before
-# hf_transfer is imported (will happen on line 15 via leading imports)
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-# Placeholder for all non-transformer models registered in QEfficient
-# custom warning for the better logging experience
-warnings.formatwarning = custom_format_warning
# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index ef7e83adf..2c98a83f3 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -8,7 +8,6 @@
import gc
import inspect
import logging
-import re
import shutil
import subprocess
import warnings
@@ -21,26 +20,21 @@
from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
- CustomOpTransform,
OnnxTransformPipeline,
- RenameFunctionOutputsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.transformers.cache_utils import InvalidIndexProvider
-from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
-from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+from QEfficient.utils.export_utils import export_wrapper
logger = logging.getLogger(__name__)
@@ -66,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
+ self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -125,9 +120,35 @@ def _model_offloaded_check(self) -> None:
logger.error(error_msg)
raise RuntimeError(error_msg)
+ @property
+ def model_name(self) -> str:
+ """
+ Get the model class name without QEff/QEFF prefix.
+
+ This property extracts the underlying model's class name and removes
+ any QEff or QEFF prefix that may have been added during wrapping.
+
+ Returns:
+ str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+ """
+ mname = self.model.__class__.__name__
+ if mname.startswith("QEff") or mname.startswith("QEFF"):
+ mname = mname[4:]
+ return mname
+
@property
@abstractmethod
- def model_name(self) -> str: ...
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ This is an abstract property that must be implemented by all subclasses.
+ Typically returns: self.model.config.__dict__
+
+ Returns:
+ Dict: The configuration dictionary of the underlying model
+ """
+ pass
@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
@@ -184,11 +205,11 @@ def _export(
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
- export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
- use_onnx_subfunctions: bool = False,
+ prefill_only: Optional[bool] = False,
+ **export_kwargs,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -213,11 +234,16 @@ def _export(
instance using from_pretrained() for re-export.
"""
+ # TODO: Hack for retain_full_kv, handle this outside
+ export_kwargs.pop("retain_full_kv", None)
onnx_path = export_dir / f"{self.model_name}.onnx"
# Return early if ONNX already exists
if onnx_path.is_file():
- self.onnx_path = onnx_path
+ if prefill_only:
+ self.prefill_onnx_path = onnx_path
+ else:
+ self.onnx_path = onnx_path
return onnx_path
# check if the model is in meta state or weights are offloaded
@@ -253,19 +279,6 @@ def _export(
input_names.append(param)
try:
- # Initialize the registry with your custom ops
- export_kwargs = {} if export_kwargs is None else export_kwargs
- if use_onnx_subfunctions:
- warnings.warn(
- "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
- )
- apply_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = True
- output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
- export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
- self._onnx_transforms.append(RenameFunctionOutputsTransform)
- self._onnx_transforms.append(CustomOpTransform)
-
torch.onnx.export(
self.model,
(example_inputs,),
@@ -309,15 +322,42 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
- if use_onnx_subfunctions:
- undo_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = False
- self._onnx_transforms.remove(CustomOpTransform)
- self._onnx_transforms.remove(RenameFunctionOutputsTransform)
-
- self.onnx_path = onnx_path
+ if prefill_only:
+ self.prefill_onnx_path = onnx_path
+ else:
+ self.onnx_path = onnx_path
return onnx_path
+ def get_onnx_path(
+ self,
+ prefill_only: Optional[bool] = False,
+ enable_chunking: Optional[bool] = False,
+ specializations: Optional[List[Dict[str, int]]] = None,
+ offload_pt_weights: Optional[bool] = True,
+ use_onnx_subfunctions: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = False,
+ ):
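+        """
+        Return the path to an exported ONNX graph, exporting the model first if it has not been exported yet.
+
+        Args:
+            prefill_only (bool, optional): If True, return (and if needed export) the prefill-only ONNX graph.
+            enable_chunking (bool, optional): Forwarded to ``export`` for prefill-only export.
+            specializations (List[Dict[str, int]], optional): Compile-time specializations; the first entry's
+                ``seq_len`` is used as the prefill sequence length for prefill-only export.
+            offload_pt_weights (bool, optional): Whether to offload PyTorch weights after export.
+            use_onnx_subfunctions (bool, optional): Whether to export the graph with ONNX subfunctions.
+            retain_full_kv (bool, optional): Forwarded to ``export``.
+
+        Returns:
+            Path to the ONNX file corresponding to the requested graph.
+        """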
+ kwargs = {
+ "offload_pt_weights": offload_pt_weights,
+ "use_onnx_subfunctions": use_onnx_subfunctions,
+ "retain_full_kv": retain_full_kv,
+ }
+ if prefill_only:
+ if self.prefill_onnx_path is None:
+ kwargs.update(
+ {
+ "prefill_only": prefill_only,
+ "prefill_seq_len": specializations[0].get("seq_len"),
+ "enable_chunking": enable_chunking,
+ }
+ )
+ self.export(**kwargs)
+ return self.prefill_onnx_path
+ else:
+ if self.onnx_path is None:
+ self.export(**kwargs)
+ return self.onnx_path
+
@dump_qconfig
def _compile(
self,
@@ -332,6 +372,10 @@ def _compile(
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
+        prefill_only: Optional[bool] = None,
+ offload_pt_weights: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -357,11 +401,18 @@ def _compile(
For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
-
- if onnx_path is None and self.onnx_path is None:
- self.export(use_onnx_subfunctions=use_onnx_subfunctions)
-
- onnx_path = Path(onnx_path or self.onnx_path)
+ onnx_path = Path(
+ onnx_path
+ if onnx_path
+ else self.get_onnx_path(
+ prefill_only,
+ enable_chunking,
+ specializations,
+ offload_pt_weights,
+ use_onnx_subfunctions,
+ retain_full_kv,
+ )
+ )
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -423,6 +474,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
+ "prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)
@@ -462,6 +514,16 @@ def _compile(
command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
+ if use_onnx_subfunctions:
+
+ class FeatureNotAvailableError(Exception):
+ pass
+
+            exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
+            raise FeatureNotAvailableError(
+                "ONNX graph was exported with subfunctions; an assert build of the Apps SDK is required to compile this model."
+                + f"\nRun the following command manually with the assert compiler:\n{exec_command}"
+            )
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
@@ -482,5 +544,4 @@ def _compile(
logger.info("Hashed parameters exported successfully.")
self.qpc_path = qpc_path
-
return qpc_path
diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 945850c50..16697cec9 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -19,16 +19,20 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGather,
CtxGather3D,
+ CtxGatherBlockedKV,
CtxGatherFunc,
CtxGatherFunc3D,
+ CtxGatherFuncBlockedKV,
CtxScatter,
CtxScatter3D,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
+ CtxGatherBlockedKVCB,
CtxGatherCB,
CtxGatherCB3D,
+ CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterCB,
@@ -91,10 +95,12 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
"CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
- "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
- "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
+ "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
+ "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
+ "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
+ "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
}
@classmethod
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index c7dc8639a..7b15effe7 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+ ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]
@staticmethod
diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py
index 8a06bc2b1..c15b60810 100644
--- a/QEfficient/customop/ctx_scatter_gather_cb.py
+++ b/QEfficient/customop/ctx_scatter_gather_cb.py
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+ ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]
@staticmethod
diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md
new file mode 100644
index 000000000..4777d48fb
--- /dev/null
+++ b/QEfficient/diffusers/README.md
@@ -0,0 +1,95 @@
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+### **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale
+[HuggingFace Diffusers](https://github.com/huggingface/diffusers)
+
+
+---
+
+## Overview
+
+QEfficient Diffusers brings state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipelines provide seamless inference on Cloud AI 100.
+
+## Installation
+
+### Prerequisites
+
+Ensure you have Python 3.8+ and the required dependencies:
+
+```bash
+# Create Python virtual environment (Recommended Python 3.10)
+sudo apt install python3.10-venv
+python3.10 -m venv qeff_env
+source qeff_env/bin/activate
+pip install -U pip
+```
+
+### Install QEfficient
+
+```bash
+# Install from GitHub (includes diffusers support)
+pip install git+https://github.com/quic/efficient-transformers
+
+# Or build from source
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install build wheel
+python -m build --wheel --outdir dist
+pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
+```
+
+---
+
+## Supported Models
+
+- [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+- [`lightx2v/Wan2.2-Lightning`](https://huggingface.co/lightx2v/Wan2.2-Lightning)
+
+---
+
+
+## Examples
+
+Check out the comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory. A minimal usage sketch follows.
+
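+The snippet below is an illustrative sketch only: the call mirrors the HuggingFace Diffusers `FluxPipeline` interface, and the exact QEfficient arguments (e.g. compile options) are assumptions here, so refer to the scripts in `examples/diffusers/` for the supported usage.
+
+```python
+# Illustrative sketch -- argument names mirror the HuggingFace Diffusers API and are
+# assumptions; see examples/diffusers/ for the canonical QEfficient usage.
+from QEfficient import QEffFluxPipeline
+
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+image = pipeline("A girl laughing", num_inference_steps=4, guidance_scale=0.0).images[0]
+image.save("girl_laughing.png")
+```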
+---
+
+## Contributing
+
+We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.
+
+
+
+---
+
+## Acknowledgments
+
+- **HuggingFace Diffusers**: For the excellent foundation library
+---
+
+## Support
+
+- **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
+- **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)
+
+---
+
diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/modeling_utils.py b/QEfficient/diffusers/models/modeling_utils.py
new file mode 100644
index 000000000..59727be2d
--- /dev/null
+++ b/QEfficient/diffusers/models/modeling_utils.py
@@ -0,0 +1,456 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import math
+import os
+from typing import Optional
+
+import torch
+
+
+def get_attention_blocking_config():
+ """
+ Get attention blocking configuration from environment variables.
+
+ Returns:
+ tuple: (blocking_mode, head_block_size, num_kv_blocks, num_q_blocks)
+ - blocking_mode (str): The blocking strategy ('kv', 'q', 'qkv', 'default')
+ - head_block_size (int or None): Number of attention heads per block
+ - num_kv_blocks (int or None): Number of key-value blocks
+ - num_q_blocks (int or None): Number of query blocks
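+
+    Example (illustrative; assumes the remaining variables are unset so they fall back to None):
+        >>> import os
+        >>> os.environ["ATTENTION_BLOCKING_MODE"] = "kv"
+        >>> os.environ["num_kv_blocks"] = "8"
+        >>> get_attention_blocking_config()
+        ('kv', None, 8, None)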
+ """
+ mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower()
+ head_block_size = int(os.environ.get("head_block_size", 0)) or None
+ num_kv_blocks = int(os.environ.get("num_kv_blocks", 0)) or None
+ num_q_blocks = int(os.environ.get("num_q_blocks", 0)) or None
+
+ # Validate blocking mode
+ valid_modes = ["kv", "qkv", "q", "default"]
+ if mode not in valid_modes:
+ raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {mode}. Must be one of {valid_modes}")
+
+ return mode, head_block_size, num_kv_blocks, num_q_blocks
+
+
+def apply_head_blocking(
+ q: torch.FloatTensor,
+ k: torch.FloatTensor,
+ v: torch.FloatTensor,
+ head_block_size: int,
+ attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """
+ Forward pass with head-only blocking (default mode).
+
+ This method processes attention heads in blocks while computing full attention
+ matrices for each head block. It's less memory-efficient than other blocking
+ modes but simpler and faster for moderate sequence lengths.
+
+ Args:
+ q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
+ k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
+ v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Number of attention heads processed per block (all heads if falsy)
+        attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
+
+ Returns:
+ torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
+ """
+ BS, NH, CL, DH = q.shape
+ scale_factor = 1.0 / math.sqrt(DH)
+
+ # Get head blocking configuration
+ head_block_size = head_block_size or NH
+ num_head_blocks = math.ceil(NH / head_block_size)
+
+ # Optimization: Handle small sequences with standard attention
+ BS, NH, K_CL, DH = k.shape
+ if K_CL <= 512:
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
+ if attention_mask is not None:
+ scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
+ probs = torch.softmax(scores, dim=-1)
+ out = torch.matmul(probs, v)
+ return out
+
+ outputs = []
+
+ # Process each head block independently
+ for head_block_idx in range(num_head_blocks):
+ h_start = head_block_idx * head_block_size
+ h_end = min(h_start + head_block_size, NH)
+
+ # Extract head blocks
+ q_g = q[:, h_start:h_end, :, :]
+ k_g = k[:, h_start:h_end, :, :]
+ v_g = v[:, h_start:h_end, :, :]
+
+ # Compute full attention matrix for this head block
+ qkblock = torch.matmul(q_g, k_g.transpose(-2, -1)) * scale_factor
+
+ # Standard softmax computation
+ probs = torch.softmax(qkblock, dim=-1)
+
+ # Compute attention output
+ output_blocks = torch.matmul(probs, v_g)
+ outputs.append(output_blocks)
+
+ # Concatenate all head blocks along head dimension
+ out = torch.cat(outputs, dim=1) # (BS, NH, CL, DH)
+ return out
+
+
+def apply_kv_blocking(
+ q: torch.FloatTensor,
+ k: torch.FloatTensor,
+ v: torch.FloatTensor,
+ head_block_size: int,
+ num_kv_blocks: int,
+ attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """
+ Forward pass with Key-Value blocking and head blocking.
+
+ This method processes key-value pairs in blocks while keeping queries intact.
+ It uses online softmax to maintain numerical stability and reduce memory usage
+ compared to computing full attention matrices.
+
+ Args:
+ q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
+ k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
+ v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Number of attention heads processed per block (all heads if falsy)
+        num_kv_blocks (int): Number of key-value blocks (defaults to one block per position if falsy)
+        attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
+
+ Returns:
+ torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
+ """
+ BS, NH, CL, DH = q.shape
+ scale_factor = 1.0 / math.sqrt(DH)
+
+ # Get blocking configuration
+ head_block_size = head_block_size or NH
+ num_kv_blocks = num_kv_blocks or CL
+ num_head_blocks = math.ceil(NH / head_block_size)
+ block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)]
+
+ # Handle small sequences with standard attention
+ BS, NH, K_CL, DH = k.shape
+ if K_CL <= 512:
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
+ if attention_mask is not None:
+ scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
+ probs = torch.softmax(scores, dim=-1)
+ out = torch.matmul(probs, v)
+ return out
+
+ head_outputs = []
+
+ # Process each head block
+ for head_block_idx in range(num_head_blocks):
+ h_start = head_block_idx * head_block_size
+ h_end = min(h_start + head_block_size, NH)
+ num_h = h_end - h_start
+
+ q_g = q[:, h_start:h_end, :, :]
+ k_g = k[:, h_start:h_end, :, :]
+ v_g = v[:, h_start:h_end, :, :]
+
+ # Initialize online softmax statistics
+ running_exp_sum = torch.zeros((BS, num_h, CL), device=q.device, dtype=q.dtype)
+ running_max = torch.full((BS, num_h, CL), float("-inf"), device=q.device, dtype=q.dtype)
+ output_blocks = torch.zeros_like(q_g)
+
+ # Process K,V in blocks using online softmax
+ for kv_block_idx in range(num_kv_blocks):
+ ki = block_positions[kv_block_idx]
+
+ # Calculate KV block size
+ if kv_block_idx == num_kv_blocks - 1:
+ real_kv_len = CL - ki
+ else:
+ real_kv_len = block_positions[kv_block_idx + 1] - ki
+
+ k_block = k_g[:, :, ki : ki + real_kv_len, :]
+ v_block = v_g[:, :, ki : ki + real_kv_len, :]
+
+ # Compute attention scores for current KV block
+ qkblock = torch.matmul(q_g, k_block.transpose(-2, -1)) * scale_factor
+
+ # Online softmax: Update running maximum
+ prev_max = running_max.clone()
+ running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0])
+
+ # Calculate numerical stability adjustments
+ delta_max = prev_max - running_max
+ curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1))
+
+ # Update running sum of exponentials
+ prev_exp_sum = running_exp_sum.clone()
+ curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp)
+ running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum
+
+ # Compute normalized attention weights
+ inv_running_exp_sum = 1.0 / running_exp_sum
+ softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1)
+
+ # Update output with rescaling
+ prev_out = output_blocks.clone()
+ rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max)
+ output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block)
+
+ head_outputs.append(output_blocks)
+
+ out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH)
+ return out
+
+
+def apply_q_blocking(
+ q: torch.FloatTensor,
+ k: torch.FloatTensor,
+ v: torch.FloatTensor,
+ head_block_size: int,
+ num_q_blocks: int,
+ attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """
+ Forward pass with Query blocking and head blocking.
+
+ This method processes query tokens in blocks while keeping key-value pairs intact.
+ It's useful when the sequence length is large but memory constraints are primarily
+ due to the query dimension.
+
+ Args:
+ q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
+ k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
+ v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Number of attention heads processed per block (all heads if falsy)
+        num_q_blocks (int): Number of query blocks (defaults to one block per position if falsy)
+        attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
+
+ Returns:
+ torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
+ """
+ BS, NH, CL, DH = q.shape
+ scale_factor = 1.0 / math.sqrt(DH)
+
+ # Get blocking configuration
+ head_block_size = head_block_size or NH
+ num_q_blocks = num_q_blocks or CL
+ num_head_blocks = math.ceil(NH / head_block_size)
+ q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
+
+ # Handle small sequences with standard attention
+ BS, NH, K_CL, DH = k.shape
+ if K_CL <= 512:
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
+ if attention_mask is not None:
+ scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
+ probs = torch.softmax(scores, dim=-1)
+ out = torch.matmul(probs, v)
+ return out
+
+ head_outputs = []
+
+ # Process each head block
+ for head_block_idx in range(num_head_blocks):
+ h_start = head_block_idx * head_block_size
+ h_end = min(h_start + head_block_size, NH)
+
+ q_g = q[:, h_start:h_end, :, :]
+ k_g = k[:, h_start:h_end, :, :]
+ v_g = v[:, h_start:h_end, :, :]
+
+ q_output_list = []
+
+ # Process queries in blocks
+ for q_block_idx in range(num_q_blocks):
+ qi = q_block_positions[q_block_idx]
+
+ # Calculate Q block size
+ if q_block_idx == num_q_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = q_block_positions[q_block_idx + 1] - qi
+
+ q_block = q_g[:, :, qi : qi + real_q_len, :]
+
+ # Compute attention for this query block against all keys
+ scores = torch.matmul(q_block, k_g.transpose(-2, -1)) * scale_factor
+ probs = torch.softmax(scores, dim=-1)
+ out_block = torch.matmul(probs, v_g)
+
+ q_output_list.append(out_block)
+
+ # Concatenate query blocks
+ head_output = torch.cat(q_output_list, dim=2)
+ head_outputs.append(head_output)
+
+ out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH)
+ return out
+
+
+def apply_qkv_blocking(
+ q: torch.FloatTensor,
+ k: torch.FloatTensor,
+ v: torch.FloatTensor,
+ head_block_size: int,
+ num_kv_blocks: int,
+ num_q_blocks: int,
+ attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """
+ Forward pass with combined Query, Key, Value blocking and head blocking.
+
+ This method implements the most memory-efficient attention computation by blocking
+ along all three dimensions: heads, queries, and key-values.
+
+ Args:
+ q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
+ k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
+ v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Number of attention heads processed per block (all heads if falsy)
+        num_kv_blocks (int): Number of key-value blocks (defaults to one block per position if falsy)
+        num_q_blocks (int): Number of query blocks (defaults to one block per position if falsy)
+        attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
+
+ Returns:
+ torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
+ """
+ BS, NH, CL, DH = q.shape
+ scale_factor = 1.0 / math.sqrt(DH)
+
+ # Get blocking configuration from environment variables
+ head_block_size = head_block_size or NH
+ num_kv_blocks = num_kv_blocks or CL
+ num_q_blocks = num_q_blocks or CL
+ num_head_blocks = math.ceil(NH / head_block_size)
+
+ # Calculate block positions for even distribution
+ kv_block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)]
+ q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
+
+ # Optimization: Use standard attention for small sequences
+ BS, NH, K_CL, DH = k.shape
+ if K_CL <= 512:
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
+ if attention_mask is not None:
+ scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
+ probs = torch.softmax(scores, dim=-1)
+ out = torch.matmul(probs, v)
+ return out
+
+ head_outputs = []
+
+ # Process attention heads in blocks to reduce memory usage
+ for head_block_idx in range(num_head_blocks):
+ h_start = head_block_idx * head_block_size
+ h_end = min(h_start + head_block_size, NH)
+ num_h = h_end - h_start
+
+ # Extract current head block
+ q_g = q[:, h_start:h_end, :, :]
+ k_g = k[:, h_start:h_end, :, :]
+ v_g = v[:, h_start:h_end, :, :]
+ q_output_list = []
+
+ # Process queries in blocks within each head block
+ for q_block_idx in range(num_q_blocks):
+ qi = q_block_positions[q_block_idx]
+
+ # Calculate actual Q block size (handle remainder for last block)
+ if q_block_idx == num_q_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = q_block_positions[q_block_idx + 1] - qi
+
+ q_block = q_g[:, :, qi : qi + real_q_len, :]
+
+ # Initialize online softmax statistics for this Q block
+ running_exp_sum = torch.zeros((BS, num_h, real_q_len), device=q.device, dtype=q.dtype)
+ running_max = torch.full((BS, num_h, real_q_len), float("-inf"), device=q.device, dtype=q.dtype)
+ output_blocks = torch.zeros((BS, num_h, real_q_len, DH), device=q.device, dtype=q.dtype)
+
+ # Process K,V in blocks for this Q block (online softmax)
+ for kv_block_idx in range(num_kv_blocks):
+ ki = kv_block_positions[kv_block_idx]
+
+ # Calculate actual KV block size
+ if kv_block_idx == num_kv_blocks - 1:
+ real_kv_len = CL - ki
+ else:
+ real_kv_len = kv_block_positions[kv_block_idx + 1] - ki
+
+ k_block = k_g[:, :, ki : ki + real_kv_len, :]
+ v_block = v_g[:, :, ki : ki + real_kv_len, :]
+
+ # Compute attention scores for current Q-K block
+ qkblock = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale_factor
+
+ # Online softmax: Update running maximum
+ prev_max = running_max.clone()
+ if qkblock.shape[-1] == 0:
+ running_max = prev_max
+ else:
+ running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0])
+
+ # Calculate adjustment factor for numerical stability
+ delta_max = prev_max - running_max
+ curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1))
+
+ # Online softmax: Update running sum of exponentials
+ prev_exp_sum = running_exp_sum.clone()
+ curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp)
+ running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum
+
+ # Compute normalized attention weights for this block
+ inv_running_exp_sum = 1.0 / running_exp_sum
+ softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1)
+
+ # Online softmax: Update output with rescaling of previous blocks
+ prev_out = output_blocks.clone()
+ rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max)
+ output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block)
+
+ q_output_list.append(output_blocks)
+
+ # Concatenate all Q blocks for this head block
+ head_output = torch.cat(q_output_list, dim=2)
+ head_outputs.append(head_output)
+
+ # Concatenate all head blocks
+ out = torch.cat(head_outputs, dim=1)
+ return out
+
+
+def compute_blocked_attention(
+ q: torch.FloatTensor,
+ k: torch.FloatTensor,
+ v: torch.FloatTensor,
+ head_block_size: int,
+ num_kv_blocks: int,
+ num_q_blocks: int,
+ blocking_mode: str = "default",
+ attention_mask: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """
+ Main dispatcher function for different attention blocking strategies.
+
+ Args:
+ q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH)
+ k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH)
+ v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH)
+        head_block_size (int): Number of attention heads processed per block
+        num_kv_blocks (int): Number of key-value blocks
+        num_q_blocks (int): Number of query blocks
+ blocking_mode (str): Blocking strategy ('kv', 'q', 'qkv', 'default')
+ attention_mask (Optional[torch.FloatTensor]): Attention mask tensor
+
+ Returns:
+ torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
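+
+    Example (illustrative shapes only):
+        >>> q = k = v = torch.randn(1, 8, 1024, 64)  # (BS, NH, CL, DH)
+        >>> out = compute_blocked_attention(q, k, v, head_block_size=4, num_kv_blocks=8, num_q_blocks=8, blocking_mode="qkv")
+        >>> out.shape
+        torch.Size([1, 8, 1024, 64])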
+ """
+ if blocking_mode == "kv":
+ return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask)
+ elif blocking_mode == "q":
+ return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask)
+ elif blocking_mode == "qkv":
+ return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask)
+ else: # default
+ return apply_head_blocking(q, k, v, head_block_size, attention_mask)
diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py
new file mode 100644
index 000000000..933832ed8
--- /dev/null
+++ b/QEfficient/diffusers/models/normalization.py
@@ -0,0 +1,40 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Optional
+
+import torch
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+class QEffAdaLayerNormZero(AdaLayerNormZero):
+ def forward(
+ self,
+ x: torch.Tensor,
+ shift_msa: Optional[torch.Tensor] = None,
+ scale_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_msa: Optional[torch.Tensor] = None,
+ shift_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ emb = conditioning_embedding
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..4fb5c3f12
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,65 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+)
+from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor, WanTransformer3DModel
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.normalization import (
+ QEffAdaLayerNormContinuous,
+ QEffAdaLayerNormZero,
+ QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxAttention,
+ QEffFluxAttnProcessor,
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformer2DModel,
+ QEffFluxTransformerBlock,
+)
+from QEfficient.diffusers.models.transformers.transformer_wan import (
+ QEffWanAttention,
+ QEffWanAttnProcessor,
+ QEffWanTransformer3DModel,
+)
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+ _module_mapping = {
+ RMSNorm: CustomRMSNormAIC,
+ nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm
+ }
+
+
+class AttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
+ FluxTransformerBlock: QEffFluxTransformerBlock,
+ FluxTransformer2DModel: QEffFluxTransformer2DModel,
+ FluxAttention: QEffFluxAttention,
+ FluxAttnProcessor: QEffFluxAttnProcessor,
+ WanAttnProcessor: QEffWanAttnProcessor,
+ WanAttention: QEffWanAttention,
+ WanTransformer3DModel: QEffWanTransformer3DModel,
+ }
+
+
+class NormalizationTransform(ModuleMappingTransform):
+ _module_mapping = {
+ AdaLayerNormZero: QEffAdaLayerNormZero,
+ AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
+ AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
+ }
diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000..40b7e3e7e
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,339 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+ _get_qkv_projections,
+)
+
+from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention, get_attention_blocking_config
+from QEfficient.utils.logging_utils import logger
+
+
+def qeff_apply_rotary_emb(
+ x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to an input tensor using the given frequency tensors. The query or key tensor 'x' is
+    reshaped as complex numbers, the frequency tensors are reshaped for broadcasting compatibility, and the rotated
+    result is returned as a real tensor.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor of shape [B, S, H, D] to apply rotary embeddings to.
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed cosine and sine frequency tensors, each of shape [S, D].
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
+ """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+ B, S, H, D = x.shape
+ x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+ _attention_backend = None
+ _parallel_config = None
+
+ def __call__(
+ self,
+ attn: "QEffFluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = qeff_apply_rotary_emb(query, image_rotary_emb)
+ key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
+ # Get blocking configuration
+ blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
+ # Apply blocking using pipeline_utils
+ hidden_states = compute_blocked_attention(
+ query.transpose(1, 2),
+ key.transpose(1, 2),
+ value.transpose(1, 2),
+ blocking_mode=blocking_mode,
+ head_block_size=head_block_size,
+ num_kv_blocks=num_kv_blocks,
+ num_q_blocks=num_q_blocks,
+ attention_mask=attention_mask,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class QEffFluxAttention(FluxAttention):
+ def __qeff_init__(self):
+ processor = QEffFluxAttnProcessor()
+ self.processor = processor
+
+
+class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ shift_msa, scale_msa, gate = torch.split(temb, 1)
+ residual = hidden_states
+ norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ # if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformerBlock(FluxTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb1 = tuple(torch.split(temb[:6], 1))
+ temb2 = tuple(torch.split(temb[6:], 1))
+ norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1])
+ gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:]
+
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1])
+
+ c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:]
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ # if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformer2DModel(FluxTransformer2DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ adaln_emb: torch.Tensor = None,
+ adaln_single_emb: torch.Tensor = None,
+ adaln_out: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            adaln_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation embeddings, indexed per dual-stream transformer block.
+            adaln_single_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation embeddings, indexed per single-stream transformer block.
+            adaln_out (`torch.Tensor`, *optional*):
+                Precomputed conditioning embedding for the final output normalization layer.
+            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_single_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
+
+ hidden_states = self.norm_out(hidden_states, adaln_out)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py
new file mode 100644
index 000000000..31d3be2ce
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_wan.py
@@ -0,0 +1,291 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+"""
+QEfficient WAN Transformer Implementation
+
+This module provides optimized implementations of WAN transformers
+with various attention blocking strategies for memory efficiency and performance optimization.
+The implementation includes multiple blocking modes: head-only, KV-blocking, Q-blocking,
+and combined QKV-blocking.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_wan import (
+ WanAttention,
+ WanAttnProcessor,
+ WanTransformer3DModel,
+ _get_qkv_projections,
+)
+from diffusers.utils import set_weights_and_activate_adapters
+
+from QEfficient.diffusers.models.modeling_utils import (
+ compute_blocked_attention,
+ get_attention_blocking_config,
+)
+
+
+class QEffWanAttnProcessor(WanAttnProcessor):
+ """
+ QEfficient WAN Attention Processor with Memory-Efficient Blocking Strategies.
+
+ This processor implements multiple attention blocking modes to reduce memory usage
+ and enable processing of longer sequences. It supports:
+ - Head blocking: Process attention heads in chunks
+ - KV blocking: Process key-value pairs in blocks
+ - Q blocking: Process query tokens in blocks
+ - QKV blocking: Combined query, key, and value blocking
+
+ Environment Variables:
+ ATTENTION_BLOCKING_MODE: Controls blocking strategy ('kv', 'q', 'qkv', 'default')
+ head_block_size: Number of attention heads to process per block
+ num_kv_blocks: Number of blocks for key-value processing
+ num_q_blocks: Number of blocks for query processing
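+
+    Example (illustrative configuration, set before model execution; unset variables fall back to their defaults):
+        >>> import os
+        >>> os.environ["ATTENTION_BLOCKING_MODE"] = "qkv"
+        >>> os.environ["head_block_size"] = "8"
+        >>> os.environ["num_kv_blocks"] = "4"
+        >>> os.environ["num_q_blocks"] = "4"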
+ """
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Main attention processing pipeline with support for multiple blocking strategies.
+
+ This method orchestrates the complete attention computation including:
+ 1. QKV projection and normalization
+ 2. Rotary position embedding application
+ 3. Attention computation with selected blocking strategy
+ 4. Output projection
+
+ Args:
+ attn (WanAttention): The attention module instance
+ hidden_states (torch.Tensor): Input hidden states
+ encoder_hidden_states (Optional[torch.Tensor]): Cross-attention encoder states
+ attention_mask (Optional[torch.Tensor]): Attention mask
+ rotary_emb (Optional[Tuple[torch.Tensor, torch.Tensor]]): Rotary embeddings (cos, sin)
+
+ Returns:
+ torch.Tensor: Processed hidden states after attention
+ """
+ # Project inputs to query, key, value
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ # Apply layer normalization to queries and keys
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ # Reshape for multi-head attention: (batch, seq, dim) -> (batch, seq, heads, head_dim)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # Apply rotary position embeddings if provided
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ """Apply rotary position embeddings to the input tensor."""
+ # Split into real and imaginary parts for complex rotation
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2].type_as(hidden_states)
+ sin = freqs_sin[..., 1::2].type_as(hidden_states)
+
+ # Apply rotation: (x1 + ix2) * (cos + isin) = (x1*cos - x2*sin) + i(x1*sin + x2*cos)
+ real = x1 * cos - x2 * sin
+ img = x1 * sin + x2 * cos
+ x_rot = torch.stack([real, img], dim=-1)
+ return x_rot.flatten(-2).type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # Get blocking configuration
+ blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
+ # Apply blocking using pipeline_utils
+ hidden_states = compute_blocked_attention(
+ query.transpose(1, 2),
+ key.transpose(1, 2),
+ value.transpose(1, 2),
+ head_block_size,
+ num_kv_blocks,
+ num_q_blocks,
+ blocking_mode=blocking_mode,
+ attention_mask=attention_mask,
+ )
+
+ # Reshape back to original format
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+ # Apply output projection layers
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class QEffWanAttention(WanAttention):
+ """
+ QEfficient WAN Attention module with optimized processor.
+
+ This class extends the base WanAttention with QEfficient optimizations,
+ automatically setting up the QEffWanAttnProcessor for memory-efficient
+ attention computation.
+ """
+
+ def __qeff_init__(self):
+ """Initialize the QEfficient attention processor."""
+ processor = QEffWanAttnProcessor()
+ self.processor = processor
+
+
+class QEffWanTransformer3DModel(WanTransformer3DModel):
+ """
+ QEfficient 3D WAN Transformer Model with adapter support.
+
+ This model extends the base WanTransformer3DModel with QEfficient optimizations.
+ """
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
+ ):
+ """
+ Set the currently active adapters for use in the diffusion network.
+
+ This method manages PEFT adapters, allowing for efficient fine-tuning
+ and model customization without modifying the base model parameters.
+
+ Args:
+ adapter_names (Union[List[str], str]): Names of adapters to activate
+ weights (Optional[Union[float, Dict, List[float], List[Dict], List[None]]]):
+ Weights for each adapter. Can be:
+ - Single float: Applied to all adapters
+ - List of floats: One weight per adapter
+ - Dict: Detailed weight configuration
+ - None: Uses default weight of 1.0
+
+ Raises:
+ ValueError: If adapter names and weights lists have different lengths
+
+ Note:
+ - Adapters enable parameter-efficient fine-tuning
+ - Multiple adapters can be active simultaneously with different weights
+ - Weights control the influence of each adapter on the model output
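+
+        Example (illustrative; hypothetical adapter names):
+            >>> transformer.set_adapters(["style_lora", "detail_lora"], weights=[0.8, 0.5])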
+ """
+ # Normalize adapter names to list format
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+ # Expand weights into a list, one entry per adapter
+ # Examples for 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
+ if not isinstance(weights, list):
+ weights = [weights] * len(adapter_names)
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+ )
+
+ # Set None values to default of 1.0
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ # Expand weights using model-specific scaling function
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.config._class_name]
+ weights = scale_expansion_fn(self, weights)
+ set_weights_and_activate_adapters(self, adapter_names, weights)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ temb: torch.Tensor,
+ timestep_proj: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Forward pass of the 3D WAN Transformer.
+
+ This method implements the complete forward pass including:
+ 1. Patch embedding of input
+ 2. Rotary embedding preparation
+ 3. Cross-attention with encoder states
+ 4. Transformer block processing
+ 5. Output normalization and projection
+
+ Args:
+ hidden_states (torch.Tensor): Input tensor to transform
+ encoder_hidden_states (torch.Tensor): Cross-attention encoder states
+ rotary_emb (torch.Tensor): Rotary position embeddings
+ temb (torch.Tensor): Time embedding for diffusion process
+ timestep_proj (torch.Tensor): Projected timestep embeddings
+ encoder_hidden_states_image (Optional[torch.Tensor]): Image encoder states for I2V
+ return_dict (bool): Whether to return a dictionary or tuple
+ attention_kwargs (Optional[Dict[str, Any]]): Additional attention arguments
+
+ Returns:
+ Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ Transformed hidden states, either as tensor or in a dictionary
+ """
+ # Prepare rotary embeddings by splitting along batch dimension
+ rotary_emb = torch.split(rotary_emb, 1, dim=0)
+
+ # Apply patch embedding and reshape for transformer processing
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C)
+
+ # Concatenate image and text encoder states if image conditioning is present
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # Standard forward pass
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # Output normalization, projection & unpatchify
+ if temb.ndim == 3:
+ # Handle 3D time embeddings: batch_size, seq_len, inner_dim (WAN 2.2 T2V)
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # Handle 2D time embeddings: batch_size, inner_dim
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Ensure tensors are on the same device as hidden_states
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ # Apply adaptive layer normalization with time conditioning
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+ # Final output projection
+ hidden_states = self.proj_out(hidden_states)
+
+ # Store output for return (compiler optimization)
+ output = hidden_states
+
+ # Return in requested format
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/configs/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/QEfficient/diffusers/pipelines/configs/wan_config.json b/QEfficient/diffusers/pipelines/configs/wan_config.json
new file mode 100644
index 000000000..3f5edce07
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/configs/wan_config.json
@@ -0,0 +1,36 @@
+{
+ "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)",
+ "modules": {
+ "transformer": {
+ "specializations": [
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 1
+ },
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 2
+ }
+ ],
+ "compilation": {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 16,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts_mos": 1
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
new file mode 100644
index 000000000..511746469
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
@@ -0,0 +1,854 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+# TODO: Pipeline Architecture Improvements
+# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile,
+# and inference APIs across all diffusion pipelines, promoting code reusability
+# and consistent interface design.
+# 2. Implement persistent QPC session management strategy to retain/drop compiled model
+# sessions in memory across all pipeline modules.
+
+import os
+import time
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from tqdm import tqdm
+
+from QEfficient.diffusers.pipelines.pipeline_module import (
+ QEffFluxTransformerModel,
+ QEffTextEncoder,
+ QEffVAE,
+)
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ONNX_SUBFUNCTION_MODULE,
+ ModulePerf,
+ QEffPipelineOutput,
+ calculate_compressed_latent_dimension,
+ compile_modules_parallel,
+ compile_modules_sequential,
+ config_manager,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffFluxPipeline:
+ """
+ QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware.
+
+ This pipeline provides an optimized implementation of the Flux diffusion model specifically designed
+ for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX format and compiled
+ into Qualcomm Program Container (QPC) files for efficient inference.
+
+ The pipeline supports the complete Flux workflow including:
+ - Dual text encoding with CLIP and T5 encoders
+ - Transformer-based denoising with adaptive layer normalization
+ - VAE decoding for final image generation
+ - Performance monitoring and optimization
+
+ Attributes:
+ text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings
+ text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings
+ transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising
+ vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion
+ modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations
+ model (FluxPipeline): Original HuggingFace Flux model reference
+ tokenizer: CLIP tokenizer for text preprocessing
+ scheduler: Diffusion scheduler for timestep management
+
+ Example:
+ >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> images = pipeline(
+ ... prompt="A beautiful sunset over mountains",
+ ... height=512,
+ ... width=512,
+ ... num_inference_steps=28
+ ... )
+ >>> images.images[0].save("generated_image.png")
+ """
+
+ _hf_auto_class = FluxPipeline
+
+ def __init__(self, model, *args, **kwargs):
+ """
+ Initialize the QEfficient Flux pipeline.
+
+ This pipeline provides an optimized implementation of the Flux text-to-image model
+ for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX and compiled
+ for QAIC devices.
+
+ Args:
+ model: Pre-loaded FluxPipeline model
+            *args: Additional positional arguments (currently unused)
+            **kwargs: Additional keyword arguments forwarded by from_pretrained (currently unused)
+ """
+
+ # Wrap model components with QEfficient optimized versions
+ self.model = model
+ self.text_encoder = QEffTextEncoder(model.text_encoder)
+ self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+ self.transformer = QEffFluxTransformerModel(model.transformer)
+ self.vae_decode = QEffVAE(model.vae, "decoder")
+
+ # Store all modules in a dictionary for easy iteration during export/compile
+ self.modules = {
+ "text_encoder": self.text_encoder,
+ "text_encoder_2": self.text_encoder_2,
+ "transformer": self.transformer,
+ "vae_decoder": self.vae_decode,
+ }
+
+ # Copy tokenizers and scheduler from the original model
+ self.tokenizer = model.tokenizer
+ self.text_encoder.tokenizer = model.tokenizer
+ self.text_encoder_2.tokenizer = model.tokenizer_2
+ self.tokenizer_max_length = model.tokenizer_max_length
+ self.scheduler = model.scheduler
+
+ # Override VAE forward method to use decode directly
+ self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
+ latent_sample, return_dict
+ )
+
+ # Sync max position embeddings between text encoders
+ self.text_encoder_2.model.config.max_position_embeddings = (
+ self.text_encoder.model.config.max_position_embeddings
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ **kwargs,
+ ):
+ """
+ Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations.
+
+ This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained
+ Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU
+ and wraps all components with QEfficient-optimized versions for QAIC deployment.
+
+ Args:
+ pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier
+ (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory.
+ **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained().
+
+ Returns:
+ QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components
+ ready for export, compilation, and inference on QAIC devices.
+
+ Raises:
+ ValueError: If the model path is invalid or model cannot be loaded
+ OSError: If there are issues accessing the model files
+ RuntimeError: If model initialization fails
+
+ Example:
+ >>> # Load from HuggingFace Hub
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>>
+ >>> # Load from local path
+ >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model")
+ >>>
+ >>> # Load with custom cache directory
+ >>> pipeline = QEffFluxPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-dev",
+ ... cache_dir="/custom/cache/dir"
+ ... )
+ """
+ # Load the base Flux model in float32 on CPU
+ model = cls._hf_auto_class.from_pretrained(
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ device_map="cpu",
+ **kwargs,
+ )
+
+ return cls(
+ model=model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ **kwargs,
+ )
+
+ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ """
+ Export all pipeline modules to ONNX format for deployment preparation.
+
+ This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder,
+ Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific
+ configuration including dynamic axes, input/output specifications, and optimization settings.
+
+ The export process prepares the models for subsequent compilation to QPC format, enabling
+ efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules
+ to optimize memory usage and performance.
+
+ Args:
+ export_dir (str, optional): Target directory for saving ONNX model files. If None,
+ uses the default export directory structure based on model name and configuration.
+ The directory will be created if it doesn't exist.
+ use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction
+                optimization for supported modules. This can optimize the graph and
+ improve compilation efficiency for models like the transformer.
+
+ Returns:
+ str: Absolute path to the export directory containing all ONNX model files.
+ Each module will have its own subdirectory with the exported ONNX file.
+
+ Raises:
+ RuntimeError: If ONNX export fails for any module
+ OSError: If there are issues creating the export directory or writing files
+ ValueError: If module configurations are invalid
+
+ Note:
+ - All models are exported in float32 precision for maximum compatibility
+ - Dynamic axes are configured to support variable batch sizes and sequence lengths
+ - The export process may take several minutes depending on model size
+ - Exported ONNX files can be large (several GB for complete pipeline)
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> export_path = pipeline.export(
+ ... export_dir="/path/to/export",
+ ... use_onnx_subfunctions=True
+ ... )
+ >>> print(f"Models exported to: {export_path}")
+ """
+ for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
+ # Get ONNX export configuration for this module
+ example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
+
+ export_params = {
+ "inputs": example_inputs,
+ "output_names": output_names,
+ "dynamic_axes": dynamic_axes,
+ "export_dir": export_dir,
+ }
+
+ if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE:
+ export_params["use_onnx_subfunctions"] = True
+
+ module_obj.export(**export_params)
+
+ @staticmethod
+ def get_default_config_path() -> str:
+ """
+ Get the absolute path to the default Flux pipeline configuration file.
+
+ Returns:
+ str: Absolute path to the flux_config.json file containing default pipeline
+ configuration settings for compilation and device allocation.
+ """
+ return "QEfficient/diffusers/pipelines/configs/flux_config.json"
+
+ def compile(
+ self,
+ compile_config: Optional[str] = None,
+ parallel: bool = False,
+ height: int = 512,
+ width: int = 512,
+ use_onnx_subfunctions: bool = False,
+ ) -> None:
+ """
+ Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware.
+
+ Args:
+ compile_config (str, optional): Path to a JSON configuration file containing
+ compilation settings, device mappings, and optimization parameters. If None,
+ uses the default configuration from get_default_config_path().
+ parallel (bool, default=False): Compilation mode selection:
+ - True: Compile modules in parallel using ThreadPoolExecutor for faster processing
+ - False: Compile modules sequentially for lower resource usage
+ height (int, default=512): Target image height in pixels.
+ width (int, default=512): Target image width in pixels.
+ use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX
+ subfunctions before compilation.
+
+ Raises:
+ RuntimeError: If compilation fails for any module or if QAIC compiler is not available
+ FileNotFoundError: If ONNX models haven't been exported or config file is missing
+ ValueError: If configuration parameters are invalid
+ OSError: If there are issues with file I/O during compilation
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> # Sequential compilation with default config
+ >>> pipeline.compile(height=1024, width=1024)
+ >>>
+ >>> # Parallel compilation with custom config
+ >>> pipeline.compile(
+ ... compile_config="/path/to/custom_config.json",
+ ... parallel=True,
+ ... height=512,
+ ... width=512
+ ... )
+ """
+ # Ensure all modules are exported to ONNX before compilation
+ if any(
+ path is None
+ for path in [
+ self.text_encoder.onnx_path,
+ self.text_encoder_2.onnx_path,
+ self.transformer.onnx_path,
+ self.vae_decode.onnx_path,
+ ]
+ ):
+ self.export(use_onnx_subfunctions=use_onnx_subfunctions)
+
+ # Load compilation configuration
+ config_manager(self, config_source=compile_config)
+
+ # Calculate compressed latent dimension using utility function
+ cl, latent_height, latent_width = calculate_compressed_latent_dimension(
+ height, width, self.model.vae_scale_factor
+ )
+
+ # Prepare dynamic specialization updates based on image dimensions
+ specialization_updates = {
+ "transformer": {"cl": cl},
+ "vae_decoder": {
+ "latent_height": latent_height,
+ "latent_width": latent_width,
+ },
+ }
+
+ # Use generic utility functions for compilation
+ if parallel:
+ compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+ else:
+ compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the T5 text encoder for detailed semantic understanding.
+
+ T5 provides rich sequence embeddings that capture fine-grained text details,
+ complementing CLIP's global representation in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ max_sequence_length (int): Maximum token sequence length (default: 512)
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (prompt_embeds, inference_time)
+ - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096]
+ - inference_time (float): T5 encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts with padding and truncation
+ text_inputs = self.text_encoder_2.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.text_encoder_2.tokenizer.batch_decode(
+                untruncated_ids[:, max_sequence_length - 1 : -1]
+            )
+            logger.warning(
+                f"The following part of your input was truncated because `max_sequence_length` is set to "
+                f"{max_sequence_length} tokens: {removed_text}"
+            )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder_2.qpc_session is None:
+ self.text_encoder_2.qpc_session = QAICInferenceSession(
+ str(self.text_encoder_2.qpc_path), device_ids=device_ids
+ )
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_2_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model
+            ).astype(np.float32),
+ }
+ self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run T5 encoder inference and measure time
+ start_t5_time = time.perf_counter()
+ prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+ end_t5_time = time.perf_counter()
+ text_encoder_2_perf = end_t5_time - start_t5_time
+
+ # Duplicate embeddings for multiple images per prompt
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_encoder_2_perf
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the CLIP text encoder for global semantic representation.
+
+ CLIP provides pooled embeddings that capture high-level semantic meaning,
+ working alongside T5's detailed sequence embeddings in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (pooled_prompt_embeds, inference_time)
+ - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768]
+ - inference_time (float): CLIP encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ f"The following part of your input was truncated because CLIP can only handle sequences up to "
+ f"{self.tokenizer_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder.qpc_session is None:
+ self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size
+ ).astype(np.float32),
+ "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.int32),
+ }
+ self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run CLIP encoder inference and measure time
+ start_text_encoder_time = time.perf_counter()
+ aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+ end_text_encoder_time = time.perf_counter()
+ text_encoder_perf = end_text_encoder_time - start_text_encoder_time
+ # Extract pooled output (used for conditioning in Flux)
+ prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
+
+ # Duplicate embeddings for multiple images per prompt
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, text_encoder_perf
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ """
+ Encode text prompts using Flux's dual text encoder architecture.
+
+ Flux employs both CLIP and T5 encoders for comprehensive text understanding:
+ - CLIP provides pooled embeddings for global semantic conditioning
+ - T5 provides detailed sequence embeddings for fine-grained text control
+
+ Args:
+ prompt (str or List[str]): Primary prompt(s) for both encoders
+ prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. If None, uses primary prompt
+ num_images_per_prompt (int): Number of images to generate per prompt
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings
+ max_sequence_length (int): Maximum sequence length for T5 tokenization
+
+ Returns:
+ tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times)
+ - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096]
+ - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768]
+ - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3]
+ - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time]
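+
+        Example (illustrative; assumes the text encoders are already compiled and loaded on QAIC):
+            >>> embeds, pooled, text_ids, perf = pipeline.encode_prompt(
+            ...     prompt="A watercolor painting of a fox",
+            ...     num_images_per_prompt=1,
+            ... )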
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ # Use primary prompt for both encoders if secondary not provided
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # Encode with CLIP (returns pooled embeddings)
+ pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device_ids=self.text_encoder.device_ids,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ # Encode with T5 (returns sequence embeddings)
+ prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device_ids=self.text_encoder_2.device_ids,
+ )
+
+ # Create text position IDs (required by Flux transformer)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]
+
+ def __call__(
+ self,
+ height: int = 512,
+ width: int = 512,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ use_onnx_subfunctions: bool = False,
+ ):
+ """
+ Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware.
+
+ This is the main entry point for text-to-image generation. It orchestrates the complete Flux
+ diffusion pipeline optimized for Qualcomm AI Cloud devices.
+
+ Args:
+ height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512.
+ width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512.
+ prompt (str or List[str]): Primary text prompt(s) describing the desired image(s).
+ Required unless `prompt_embeds` is provided.
+ prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. If None, uses `prompt`.
+ negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid.
+ Only used when `true_cfg_scale > 1.0`.
+ negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`.
+ true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable
+ negative prompting. Default: 1.0 (disabled).
+ num_inference_steps (int, optional): Number of denoising steps. Default: 28.
+ timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`.
+ guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5.
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1.
+ generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility.
+ latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated.
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096].
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768].
+ negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings.
+ negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings.
+ output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent".
+ callback_on_step_end (Callable, optional): Callback function executed after each denoising step.
+ callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"].
+ max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512.
+ custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings.
+ parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False.
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False.
+
+ Returns:
+ QEffPipelineOutput: A dataclass containing:
+ - images: Generated image(s) in the format specified by `output_type`
+ - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE)
+
+ Raises:
+ ValueError: If input validation fails or parameters are incompatible.
+ RuntimeError: If compilation fails or QAIC devices are unavailable.
+ FileNotFoundError: If custom config file is specified but not found.
+
+ Example:
+ >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> result = pipeline(
+ ... prompt="A serene mountain landscape at sunset",
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=28,
+ ... guidance_scale=7.5
+ ... )
+ >>> result.images[0].save("mountain_sunset.png")
+ >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s")
+ """
+ device = self.model._execution_device
+
+        if height is None or width is None:
+            logger.warning("Height or width is None. Setting default values of 512 for both dimensions.")
+            height = height if height is not None else 512
+            width = width if width is not None else 512
+
+ self.compile(
+ compile_config=custom_config_path,
+ parallel=parallel_compile,
+ height=height,
+ width=width,
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ )
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(self)
+
+        # Step 1: Validate all inputs
+ self.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # Step 2: Determine batch size from inputs
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Step 3: Encode prompts with both text encoders
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Encode negative prompts if using true classifier-free guidance
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+                negative_text_ids,
+                _,  # encoder timings for the negative prompt (not reported)
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = self.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = self.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Calculate compressed latent dimension for transformer buffer allocation
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor)
+
+ # Initialize transformer inference session
+ if self.transformer.qpc_session is None:
+ self.transformer.qpc_session = QAICInferenceSession(
+ str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
+ )
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32),
+ }
+ self.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ self.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with self.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds)
+
+ # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.transformer_blocks)):
+ block = self.transformer.model.transformer_blocks[block_idx]
+ # Process through norm1 and norm1_context
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.single_transformer_blocks)):
+ block = self.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = self.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
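+                # These AdaLN projections are computed on the host in PyTorch at every
+                # step and passed to the compiled transformer as explicit inputs
+                # (adaln_emb, adaln_single_emb, adaln_out) in inputs_aic below.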
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": prompt_embeds.detach().numpy(),
+ "pooled_projections": pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.perf_counter()
+ outputs = self.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.perf_counter()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Update latents using scheduler (x_t -> x_t-1)
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch (workaround for MPS backend bug)
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Execute callback if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images (unless output_type is "latent")
+ if output_type == "latent":
+ image = latents
+ else:
+ # Unpack and denormalize latents
+ latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor)
+ latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
+
+ # Initialize VAE decoder inference session
+ if self.vae_decode.qpc_session is None:
+ self.vae_decode.qpc_session = QAICInferenceSession(
+ str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)}
+ self.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.perf_counter()
+ image = self.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.perf_counter()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = self.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
new file mode 100644
index 000000000..19e7701d4
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -0,0 +1,632 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+from diffusers.models.transformers.transformer_wan import WanTransformerBlock
+
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import (
+ AttentionTransform,
+ CustomOpsTransform,
+ NormalizationTransform,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformerBlock,
+)
+from QEfficient.transformers.models.pytorch_transforms import (
+ T5ModelTransform,
+)
+from QEfficient.utils import constants
+
+
+class QEffTextEncoder(QEFFBaseModel):
+ """
+ Wrapper for text encoder models with ONNX export and QAIC compilation capabilities.
+
+ This class handles text encoder models (CLIP, T5) with specific transformations and
+ optimizations for efficient inference on Qualcomm AI hardware. It applies custom
+ PyTorch and ONNX transformations to prepare models for deployment.
+
+ Attributes:
+        model (nn.Module): The wrapped text encoder model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying text encoder model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the text encoder wrapper.
+
+ Args:
+ model (nn.Module): The text encoder model to wrap (CLIP or T5)
+ """
+ super().__init__(model)
+ self.model = model
+
+ def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the text encoder.
+
+ Creates example inputs, dynamic axes specifications, and output names
+ tailored to the specific text encoder type (CLIP vs T5).
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # Create example input with max sequence length
+ example_inputs = {
+ "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64),
+ }
+
+ # Define which dimensions can vary at runtime
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+ # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output
+ if self.model.__class__.__name__ == "T5EncoderModel":
+ output_names = ["last_hidden_state"]
+ else:
+ output_names = ["last_hidden_state", "pooler_output"]
+ example_inputs["output_hidden_states"] = False
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the text encoder model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffUNet(QEFFBaseModel):
+ """
+ Wrapper for UNet models with ONNX export and QAIC compilation capabilities.
+
+ This class handles UNet models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. UNet is commonly used in
+ diffusion models for image generation tasks.
+
+ Attributes:
+ model (nn.Module): The wrapped UNet model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying UNet model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the UNet wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the UNet
+ """
+ super().__init__(model.unet)
+ self.model = model.unet
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the UNet model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffVAE(QEFFBaseModel):
+ """
+ Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation.
+
+ This class handles VAE models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion
+ pipelines for encoding images to latent space and decoding latents back to images.
+
+ Attributes:
+        model (nn.Module): The wrapped VAE model
+ type (str): VAE operation type ("encoder" or "decoder")
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying VAE model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module, type: str) -> None:
+ """
+ Initialize the VAE wrapper.
+
+ Args:
+            model (nn.Module): The VAE model to wrap
+ type (str): VAE operation type ("encoder" or "decoder")
+ """
+ super().__init__(model)
+ self.model = model
+
+ # To have different hashing for encoder/decoder
+ self.model.config["type"] = type
+
+ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the VAE decoder.
+
+ Args:
+ latent_height (int): Height of latent representation (default: 32)
+ latent_width (int): Width of latent representation (default: 32)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # VAE decoder takes latent representation as input
+ example_inputs = {
+ "latent_sample": torch.randn(bs, 16, latent_height, latent_width),
+ "return_dict": False,
+ }
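+        # Illustrative sizing (assuming the usual 8x VAE downsampling): a 512x512 output
+        # image corresponds to latent_height = latent_width = 512 / 8 = 64.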
+
+ output_names = ["sample"]
+
+ # All dimensions except channels can be dynamic
+ dynamic_axes = {
+ "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ ) -> str:
+ """
+ Export the VAE model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffFluxTransformerModel(QEFFBaseModel):
+ """
+ Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities.
+
+ This class handles Flux Transformer2D models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion
+ architecture instead of traditional UNet, with dual transformer blocks and adaptive layer
+ normalization (AdaLN) for conditioning.
+
+ Attributes:
+ model (nn.Module): The wrapped Flux transformer model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying Flux transformer model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the Flux transformer wrapper.
+
+ Args:
+ model (nn.Module): The Flux transformer model to wrap
+ """
+ super().__init__(model)
+
+ def get_onnx_params(
+ self,
+ batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+ seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH,
+ cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM,
+ ) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the Flux transformer.
+
+ Creates example inputs for all Flux-specific inputs including hidden states,
+ text embeddings, timestep conditioning, and AdaLN embeddings.
+
+ Args:
+            batch_size (int): Batch size for example inputs (default: ONNX_EXPORT_EXAMPLE_BATCH_SIZE)
+ seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH)
+ cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ example_inputs = {
+ # Latent representation of the image
+ "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32),
+ "encoder_hidden_states": torch.randn(
+ batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32
+ ),
+ "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32),
+ "timestep": torch.tensor([1.0], dtype=torch.float32),
+ "img_ids": torch.randn(cl, 3, dtype=torch.float32),
+ "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32),
+ # AdaLN embeddings for dual transformer blocks
+ # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_emb": torch.randn(
+ self.model.config["num_layers"],
+ constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # AdaLN embeddings for single transformer blocks
+ # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_single_emb": torch.randn(
+ self.model.config["num_single_layers"],
+ constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # Output AdaLN embedding
+ # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection
+ "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32),
+ }
+
+ output_names = ["output"]
+
+ # Define dynamic dimensions for runtime flexibility
+ dynamic_axes = {
+ "hidden_states": {0: "batch_size", 1: "cl"},
+ "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+ "pooled_projections": {0: "batch_size"},
+ "timestep": {0: "steps"},
+ "img_ids": {0: "cl"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """
+ Export the Flux transformer model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
+ use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
+ for better modularity and potential optimization
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+
+ if use_onnx_subfunctions:
+ export_kwargs = {
+ "export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock},
+ "use_onnx_subfunctions": True,
+ }
+
+ # Sort _use_default_values in config to ensure consistent hash generation during export
+ self.model.config["_use_default_values"].sort()
+
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ offload_pt_weights=False, # As weights are needed with AdaLN changes
+ **export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffWanUnifiedTransformer(QEFFBaseModel):
+ """
+ Wrapper for WAN Unified Transformer with ONNX export and QAIC compilation capabilities.
+
+ This class handles the unified WAN transformer model that combines high and low noise transformers
+ into a single model for efficient deployment. Based on the timestep shape, the model dynamically
+ selects between high and low noise transformers during inference.
+
+ The wrapper applies specific transformations and optimizations for efficient inference on
+ Qualcomm AI hardware, particularly for video diffusion models.
+
+ Attributes:
+ model (nn.Module): The QEffWanUnifiedWrapper model that combines high/low noise transformers
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, unified_transformer):
+ """
+        Initialize the Wan unified transformer wrapper.
+
+        Args:
+            unified_transformer (nn.Module): The unified Wan transformer (wrapper combining
+                the high- and low-noise transformers)
+ """
+ super().__init__(unified_transformer)
+ self.model = unified_transformer
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying Wan transformer model
+ """
+ return self.model.config.__dict__
+
+ def get_onnx_params(self):
+ """
+ Generate ONNX export configuration for the Wan transformer.
+
+        Creates example inputs for all Wan-specific inputs, including hidden states,
+        text embeddings, rotary position embeddings, and timestep conditioning.
+
+        Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ batch_size = constants.WAN_ONNX_EXPORT_BATCH_SIZE
+ example_inputs = {
+ # hidden_states = [ bs, in_channels, frames, latent_height, latent_width]
+ "hidden_states": torch.randn(
+ batch_size,
+ self.model.config.in_channels,
+ constants.WAN_ONNX_EXPORT_LATENT_FRAMES,
+ constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P,
+ constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P,
+ dtype=torch.float32,
+ ),
+ # encoder_hidden_states = [BS, seq len , text dim]
+ "encoder_hidden_states": torch.randn(
+ batch_size, constants.WAN_ONNX_EXPORT_SEQ_LEN, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32
+ ),
+ # Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs
+ "rotary_emb": torch.randn(
+ 2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32
+ ),
+ # Timestep embeddings: [batch_size=1, embedding_dim]
+ "temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32),
+ # Projected timestep embeddings: [batch_size=1, projection_dim, embedding_dim]
+ "timestep_proj": torch.randn(
+ batch_size,
+ constants.WAN_PROJECTION_DIM,
+ constants.WAN_TEXT_EMBED_DIM,
+ dtype=torch.float32,
+ ),
+ # Timestep parameter: Controls high/low noise transformer selection based on shape
+ "tsp": torch.ones(1, dtype=torch.int64),
+ }
+
+ output_names = ["output"]
+
+ dynamic_axes = {
+ "hidden_states": {
+ 0: "batch_size",
+ 1: "num_channels",
+ 2: "num_frames",
+ 3: "latent_height",
+ 4: "latent_width",
+ },
+ "timestep": {0: "steps"},
+ "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
+ "rotary_emb": {1: "cl"},
+ "tsp": {0: "model_type"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = {},
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """Export the Wan transformer model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
+ use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
+ for better modularity and potential optimization
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ if use_onnx_subfunctions:
+ export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True}
+
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ offload_pt_weights=True,
+ **export_kwargs,
+ )
+
+    def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
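+
+
+# Illustrative export flow for the unified Wan transformer (a sketch; loaded high- and
+# low-noise transformer modules are assumed, and QEffWanUnifiedWrapper is provided by
+# pipeline_utils.py -- see QEffWanPipeline.__init__ in pipeline_wan.py for the real wiring):
+#   >>> unified = QEffWanUnifiedWrapper(transformer_high, transformer_low)
+#   >>> qeff_wan = QEffWanUnifiedTransformer(unified)
+#   >>> inputs, dyn_axes, out_names = qeff_wan.get_onnx_params()
+#   >>> qeff_wan.export(inputs, out_names, dyn_axes, export_dir="./onnx")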
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..4bb305447
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,350 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import math
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import logger
+
+
+def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> Tuple[int, int, int]:
+ """
+    Calculate the compressed latent dimension and latent spatial dimensions.
+ Args:
+ height (int): Target image height in pixels
+ width (int): Target image width in pixels
+ vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux)
+
+    Returns:
+        Tuple[int, int, int]: (cl, latent_height, latent_width), where cl is the compressed
+            latent dimension used for transformer input buffer allocation
+ """
+ latent_height = height // vae_scale_factor
+ latent_width = width // vae_scale_factor
+ # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
+ cl = (latent_height * latent_width) // 4
+ return cl, latent_height, latent_width
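+
+# Illustrative arithmetic (values assumed for demonstration):
+#   height=1024, width=1024, vae_scale_factor=8
+#   -> latent_height = latent_width = 1024 // 8 = 128
+#   -> cl = (128 * 128) // 4 = 4096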
+
+
+def calculate_latent_dimensions_with_frames(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int,
+ vae_scale_factor_temporal: int,
+ patch_height: int,
+ patch_width: int,
+) -> Tuple[int, int, int, int]:
+ """
+ Calculate the latent dimensions for video generation models.
+
+    This method computes the compressed sequence length (cl), latent height,
+    latent width, and latent frames based on the target video dimensions,
+    VAE scale factors, and patch sizes.
+
+ Args:
+ height (int): Target video height in pixels
+ width (int): Target video width in pixels
+        num_frames (int): Target number of video frames
+ vae_scale_factor_spatial (int): spatial vae_scale_factor from model config
+ vae_scale_factor_temporal (int): temporal vae_scale_factor from model config
+ patch_height (int): patch_height from model config
+ patch_width (int): patch_width from model config
+
+ Returns:
+        tuple: (cl, latent_height, latent_width, latent_frames)
+            - cl (int): Compressed latent dimension for transformer input
+            - latent_height (int): Height in latent space
+            - latent_width (int): Width in latent space
+            - latent_frames (int): Frames in latent space
+
+ Mathematical Formula:
+ latent_height = height // vae_scale_factor_spatial
+ latent_width = width // vae_scale_factor_spatial
+ latent_frames = math.ceil(num_frames / vae_scale_factor_temporal)
+ cl = (latent_height // patch_height) * (latent_width // patch_width) * latent_frames
+
+ """
+ # Calculate latent space dimensions after VAE encoding
+ latent_height = height // vae_scale_factor_spatial
+ latent_width = width // vae_scale_factor_spatial
+ latent_frames = math.ceil(num_frames / vae_scale_factor_temporal)
+    cl = (latent_height // patch_height) * (latent_width // patch_width) * latent_frames
+ return cl, latent_height, latent_width, latent_frames
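+
+# Illustrative arithmetic (scale factors and patch sizes assumed for demonstration):
+#   height=480, width=832, num_frames=81,
+#   vae_scale_factor_spatial=8, vae_scale_factor_temporal=4, patch_height=patch_width=2
+#   -> latent_height=60, latent_width=104, latent_frames=ceil(81/4)=21
+#   -> cl = (60 // 2) * (104 // 2) * 21 = 30 * 52 * 21 = 32760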
+
+
+def config_manager(cls, config_source: Optional[str] = None):
+ """
+    Load the JSON-based compilation configuration for a diffusion pipeline.
+
+    Only JSON configuration files are supported. The loaded configuration is stored on
+    the pipeline instance as `custom_config` for use during compilation and execution.
+
+    Args:
+        cls: Pipeline instance on which to store the loaded configuration
+        config_source: Path to a JSON configuration file. If None, uses the pipeline's default config path.
+ """
+ if config_source is None:
+ config_source = cls.get_default_config_path()
+
+ if not isinstance(config_source, str):
+ raise ValueError("config_source must be a path to JSON configuration file")
+
+    # Ensure the configuration file exists before loading it
+ if not os.path.exists(config_source):
+ raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+ cls.custom_config = load_json(config_source)
+
+
+def set_module_device_ids(cls):
+ """
+ Set device IDs for each module based on the custom configuration.
+
+ Iterates through all modules in the pipeline and assigns device IDs
+ from the configuration file to each module's device_ids attribute.
+ """
+ config_modules = cls.custom_config["modules"]
+ for module_name, module_obj in cls.modules.items():
+ module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
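+
+# Illustrative shape of the JSON compile config consumed by config_manager,
+# set_module_device_ids, and the compile_modules_* helpers below (keys inferred
+# from this module; values are placeholders only):
+#   {
+#     "modules": {
+#       "transformer": {
+#         "specializations": {"batch_size": 1},      # or a list of dicts for unified models
+#         "compilation": {"num_cores": 16},          # kwargs forwarded to module.compile()
+#         "execute": {"device_ids": [0]}
+#       }
+#     }
+#   }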
+
+
+def compile_modules_parallel(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+ """
+
+ def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
+ """Prepare specializations and compile a single module."""
+ specializations = config["modules"][module_name]["specializations"].copy()
+ compile_kwargs = config["modules"][module_name]["compilation"]
+
+ if (
+ specialization_updates and module_name in specialization_updates
+ ): # Apply specialization updates if available
+ if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}]
+ for i, spec in enumerate(specializations):
+ spec.update(specialization_updates[module_name][i])
+ else:
+ specializations.update(specialization_updates[module_name])
+ specializations = [specializations]
+ else:
+ specializations = [specializations]
+ # Compile with prepared specializations
+ module_obj.compile(specializations=specializations, **compile_kwargs)
+
+ # Execute compilations in parallel
+ with ThreadPoolExecutor(max_workers=len(modules)) as executor:
+ futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()}
+
+ with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar:
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error(f"Compilation failed for {futures[future]}: {e}")
+ raise
+ pbar.update(1)
+
+
+def compile_modules_sequential(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules sequentially.
+
+ This function provides a generic way to compile diffusion pipeline modules
+ sequentially, which is the default behavior for backward compatibility.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+
+ """
+ for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"):
+ module_config = config["modules"]
+ specializations = module_config[module_name]["specializations"].copy()
+ compile_kwargs = module_config[module_name]["compilation"]
+
+ if (
+ specialization_updates and module_name in specialization_updates
+ ): # Apply specialization updates if available
+ if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}]
+ for i, spec in enumerate(specializations):
+ spec.update(specialization_updates[module_name][i])
+ else:
+ specializations.update(specialization_updates[module_name])
+ specializations = [specializations]
+ else:
+ specializations = [specializations]
+ # Compile with prepared specializations
+ module_obj.compile(specializations=specializations, **compile_kwargs)
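+
+# Illustrative call (a sketch; `qeff_transformer` and the loaded `custom_config`
+# dict are assumed to exist):
+#   compile_modules_sequential(
+#       modules={"transformer": qeff_transformer},
+#       config=custom_config,
+#       specialization_updates={"transformer": {"cl": 4096}},  # list of dicts for unified models
+#   )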
+
+
+@dataclass(frozen=True)
+class ModulePerf:
+ """
+ Data class to store performance metrics for a pipeline module.
+
+ Attributes:
+ module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder')
+ perf: Performance metric in seconds. Can be a single float for modules that run once,
+ or a list of floats for modules that run multiple times (e.g., transformer steps)
+ """
+
+ module_name: str
+    perf: Union[float, List[float]]
+
+
+@dataclass(frozen=True)
+class QEffPipelineOutput:
+ """
+ Data class to store the output of a QEfficient diffusion pipeline.
+
+ Attributes:
+ pipeline_module: List of ModulePerf objects containing performance metrics for each module
+ images: Generated images as either a list of PIL Images or numpy array
+ """
+
+    pipeline_module: List[ModulePerf]
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+ def __repr__(self):
+ output_str = "=" * 60 + "\n"
+ output_str += "QEfficient Diffusers Pipeline Inference Report\n"
+ output_str += "=" * 60 + "\n\n"
+
+ # Module-wise inference times
+ output_str += "Module-wise Inference Times:\n"
+ output_str += "-" * 60 + "\n"
+
+ # Calculate E2E time while iterating
+ e2e_time = 0
+ for module_perf in self.pipeline_module:
+ module_name = module_perf.module_name
+ inference_time = module_perf.perf
+
+ # Add to E2E time
+ e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time
+
+ # Format module name for display
+ display_name = module_name.replace("_", " ").title()
+
+ # Handle transformer specially as it has a list of times
+ if isinstance(inference_time, list) and len(inference_time) > 0:
+ total_time = sum(inference_time)
+ avg_time = total_time / len(inference_time)
+ output_str += f" {display_name:25s} {total_time:.4f} s\n"
+ output_str += f" - Total steps: {len(inference_time)}\n"
+ output_str += f" - Average per step: {avg_time:.4f} s\n"
+ output_str += f" - Min step time: {min(inference_time):.4f} s\n"
+ output_str += f" - Max step time: {max(inference_time):.4f} s\n"
+ else:
+ # Single inference time value
+ output_str += f" {display_name:25s} {inference_time:.4f} s\n"
+
+ output_str += "-" * 60 + "\n\n"
+
+ # Print E2E time after all modules
+ output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n"
+ output_str += "=" * 60 + "\n"
+
+ return output_str
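+
+# Illustrative construction (hypothetical timings):
+#   >>> report = QEffPipelineOutput(
+#   ...     pipeline_module=[ModulePerf("text_encoder", 0.12), ModulePerf("transformer", [0.80, 0.79])],
+#   ...     images=np.zeros((1, 64, 64, 3), dtype=np.uint8),
+#   ... )
+#   >>> print(report)  # prints the module-wise and end-to-end timing report above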
+
+
+# List of module names that require special handling during export
+# when use_onnx_subfunctions is enabled
+ONNX_SUBFUNCTION_MODULE = ["transformer"]
+
+
+class QEffWanUnifiedWrapper(nn.Module):
+ """
+ A wrapper class that combines WAN high and low noise transformers into a single unified transformer.
+
+ This wrapper dynamically selects between high and low noise transformers based on the timestep shape
+ in the ONNX graph during inference. This approach enables efficient deployment of both transformer
+ variants in a single model.
+
+ Attributes:
+ transformer_high(nn.Module): The high noise transformer component
+ transformer_low(nn.Module): The low noise transformer component
+ config: Configuration shared between both transformers (from high noise transformer)
+ """
+
+ def __init__(self, transformer_high, transformer_low):
+ super().__init__()
+ self.transformer_high = transformer_high
+ self.transformer_low = transformer_low
+ # Both high and low noise transformers share the same configuration
+ self.config = transformer_high.config
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ rotary_emb,
+ temb,
+ timestep_proj,
+ tsp,
+ attention_kwargs=None,
+ return_dict=False,
+ ):
+ # Condition based on timestep shape
+ is_high_noise = tsp.shape[0] == torch.tensor(1)
+
+ high_hs = hidden_states.detach()
+ ehs = encoder_hidden_states.detach()
+ rhs = rotary_emb.detach()
+ ths = temb.detach()
+ projhs = timestep_proj.detach()
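+
+        # Both transformer branches are executed so that the exported ONNX graph
+        # contains both sub-models; torch.where below then selects the active
+        # output at runtime based on the tsp-shape condition computed above.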
+
+ noise_pred_high = self.transformer_high(
+ hidden_states=high_hs,
+ encoder_hidden_states=ehs,
+ rotary_emb=rhs,
+ temb=ths,
+ timestep_proj=projhs,
+ attention_kwargs=attention_kwargs,
+ return_dict=return_dict,
+ )[0]
+
+ noise_pred_low = self.transformer_low(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ rotary_emb=rotary_emb,
+ temb=temb,
+ timestep_proj=timestep_proj,
+ attention_kwargs=attention_kwargs,
+ return_dict=return_dict,
+ )[0]
+
+ # Select based on timestep condition
+ noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low)
+ return noise_pred
diff --git a/QEfficient/diffusers/pipelines/wan/__init__.py b/QEfficient/diffusers/pipelines/wan/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/wan/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
new file mode 100644
index 000000000..edae438ae
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
@@ -0,0 +1,758 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+"""
+QEfficient WAN Pipeline Implementation
+
+This module provides an optimized implementation of the WAN pipeline
+for high-performance text-to-video generation on Qualcomm AI hardware.
+The pipeline supports WAN 2.2 architectures with unified transformer.
+
+TODO: 1. Move VAE and UMT5 text encoder to QAIC; presently running on CPU
+"""
+
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import WanPipeline
+
+from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ONNX_SUBFUNCTION_MODULE,
+ ModulePerf,
+ QEffPipelineOutput,
+ QEffWanUnifiedWrapper,
+ calculate_latent_dimensions_with_frames,
+ compile_modules_parallel,
+ compile_modules_sequential,
+ config_manager,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils import constants
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffWanPipeline:
+ """
+ QEfficient-optimized WAN pipeline for high-performance text-to-video generation on Qualcomm AI hardware.
+
+ This pipeline provides an optimized implementation of the WAN diffusion model
+ specifically designed for deployment on Qualcomm AI Cloud (QAIC) devices. It extends the original
+ HuggingFace WAN model with QEfficient-optimized components that can be exported to ONNX format
+ and compiled into Qualcomm Program Container (QPC) files for efficient video generation.
+
+ The pipeline supports the complete WAN workflow including:
+ - UMT5 text encoding for rich semantic understanding
+ - Unified transformer architecture: Combines multiple transformer stages into a single optimized model
+ - VAE decoding for final video output
+ - Performance monitoring and hardware optimization
+
+ Attributes:
+ text_encoder: UMT5 text encoder for semantic text understanding (TODO: QEfficient optimization)
+ unified_wrapper (QEffWanUnifiedWrapper): Wrapper combining transformer stages
+ transformer (QEffWanUnifiedTransformer): Optimized unified transformer for denoising
+ vae_decode: VAE decoder for latent-to-video conversion
+ modules (Dict[str, Any]): Dictionary of pipeline modules for batch operations
+ model (WanPipeline): Original HuggingFace WAN model reference
+ tokenizer: Text tokenizer for preprocessing
+ scheduler: Diffusion scheduler for timestep management
+
+ Example:
+ >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline
+ >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
+ >>> videos = pipeline(
+ ... prompt="A cat playing in a garden",
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... num_inference_steps=4
+ ... )
+        >>> # Save generated video (videos.images holds the decoded frames)
+        >>> from diffusers.utils import export_to_video
+        >>> export_to_video(videos.images[0], "generated_video.mp4")
+ """
+
+ _hf_auto_class = WanPipeline
+
+ def __init__(self, model, **kwargs):
+ """
+ Initialize the QEfficient WAN pipeline.
+
+ This pipeline provides an optimized implementation of the WAN text-to-video model
+ for deployment on Qualcomm AI hardware. It wraps the original HuggingFace WAN model
+ components with QEfficient-optimized versions that can be exported to ONNX and compiled
+ for QAIC devices.
+
+ Args:
+ model: Pre-loaded WanPipeline model with transformer and transformer_2 components
+ **kwargs: Additional keyword arguments including configuration parameters
+ """
+ # Store original model and configuration
+ self.model = model
+ self.kwargs = kwargs
+ self.custom_config = None
+
+ # Text encoder (TODO: Replace with QEfficient UMT5 optimization)
+ self.text_encoder = model.text_encoder
+
+        # Create unified transformer wrapper combining the dual-stage models (high- and low-noise DiTs)
+ self.unified_wrapper = QEffWanUnifiedWrapper(model.transformer, model.transformer_2)
+ self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper)
+
+ # VAE decoder for latent-to-video conversion
+ self.vae_decode = model.vae
+
+ # Store all modules in a dictionary for easy iteration during export/compile
+ # TODO: add text encoder, vae decoder on QAIC
+ self.modules = {"transformer": self.transformer}
+
+ # Copy tokenizers and scheduler from the original model
+ self.tokenizer = model.tokenizer
+ self.text_encoder.tokenizer = model.tokenizer
+ self.scheduler = model.scheduler
+ # Extract patch dimensions from transformer configuration
+ _, self.patch_height, self.patch_width = self.transformer.model.config.patch_size
+
+ @property
+ def do_classifier_free_guidance(self):
+ """
+ Determine if classifier-free guidance should be used.
+
+ Returns:
+ bool: True if CFG should be applied based on current guidance scales
+ """
+ return self._guidance_scale > 1.0 and (self._guidance_scale_2 is None or self._guidance_scale_2 > 1.0)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ **kwargs,
+ ):
+ """
+ Load a pretrained WAN model from HuggingFace Hub or local path and wrap it with QEfficient optimizations.
+
+ This class method provides a convenient way to instantiate a QEffWanPipeline from a pretrained
+ WAN model. It automatically loads the base WanPipeline model in float32 precision on CPU
+ and wraps all components with QEfficient-optimized versions for QAIC deployment.
+
+ Args:
+ pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier
+ or a local path to a saved WAN model directory. Should contain transformer, transformer_2,
+ text_encoder, and VAE components.
+ **kwargs: Additional keyword arguments passed to WanPipeline.from_pretrained().
+
+ Returns:
+ QEffWanPipeline: A fully initialized pipeline instance with QEfficient-optimized components
+ ready for export, compilation, and inference on QAIC devices.
+
+ Raises:
+ ValueError: If the model path is invalid or model cannot be loaded
+ OSError: If there are issues accessing the model files
+ RuntimeError: If model initialization fails
+
+ Example:
+ >>> # Load from HuggingFace Hub
+ >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
+ >>>
+ >>> # Load from local path
+ >>> pipeline = QEffWanPipeline.from_pretrained("/local/path/to/wan")
+ >>>
+ >>> # Load with custom cache directory
+ >>> pipeline = QEffWanPipeline.from_pretrained(
+ ... "wan-model-id",
+ ... cache_dir="/custom/cache/dir"
+ ... )
+ """
+ # Load the base WAN model in float32 on CPU for optimization
+ model = cls._hf_auto_class.from_pretrained(
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ device_map="cpu",
+ **kwargs,
+ )
+ return cls(
+ model=model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ **kwargs,
+ )
+
+ def export(
+ self,
+ export_dir: Optional[str] = None,
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """
+ Export all pipeline modules to ONNX format for deployment preparation.
+
+ This method systematically exports the unified transformer to ONNX format with
+ video-specific configurations including temporal dimensions, dynamic axes, and
+ optimization settings. The export process prepares the model for subsequent
+ compilation to QPC format for efficient inference on QAIC hardware.
+
+ Args:
+ export_dir (str, optional): Target directory for saving ONNX model files. If None,
+ uses the default export directory structure. The directory will be created
+ if it doesn't exist.
+ use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction
+ optimization for supported modules. This can optimize the graph structure
+ and improve compilation efficiency for complex models like the transformer.
+
+ Returns:
+ str: Absolute path to the export directory containing all ONNX model files.
+
+ Raises:
+ RuntimeError: If ONNX export fails for any module
+ OSError: If there are issues creating the export directory or writing files
+ ValueError: If module configurations are invalid
+
+ Example:
+ >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
+ >>> export_path = pipeline.export(
+ ... export_dir="/path/to/export",
+ ... use_onnx_subfunctions=True
+ ... )
+ """
+
+ # Export each module with video-specific parameters
+ for module_name, module_obj in self.modules.items():
+ # Get ONNX export configuration with video dimensions
+ example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
+
+ # Prepare export parameters
+ export_params = {
+ "inputs": example_inputs,
+ "output_names": output_names,
+ "dynamic_axes": dynamic_axes,
+ "export_dir": export_dir,
+ }
+
+ # Enable ONNX subfunctions for supported modules if requested
+ if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE:
+ export_params["use_onnx_subfunctions"] = True
+
+ module_obj.export(**export_params)
+
+ @staticmethod
+ def get_default_config_path():
+ """
+ Get the default configuration file path for WAN pipeline.
+
+ Returns:
+ str: Path to the default WAN configuration JSON file.
+ """
+ return os.path.join(os.path.dirname(__file__), "wan_config.json")
+
+ def compile(
+ self,
+ compile_config: Optional[str] = None,
+ parallel: bool = False,
+ height: int = constants.WAN_ONNX_EXPORT_HEIGHT_180P,
+ width: int = constants.WAN_ONNX_EXPORT_WIDTH_180P,
+ num_frames: int = constants.WAN_ONNX_EXPORT_FRAMES,
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """
+ Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware.
+
+ This method takes the ONNX paths of the transformer and compiles them into an optimized format
+ for inference using JSON-based configuration.
+
+ Args:
+ compile_config (str, optional): Path to a JSON configuration file containing
+ compilation settings, device mappings, and optimization parameters. If None,
+ uses the default configuration.
+ parallel (bool, default=False): Compilation mode selection:
+ - True: Compile modules in parallel using ThreadPoolExecutor for faster processing
+ - False: Compile modules sequentially for lower resource usage
+            height (int): Target video height in pixels (default: WAN_ONNX_EXPORT_HEIGHT_180P).
+            width (int): Target video width in pixels (default: WAN_ONNX_EXPORT_WIDTH_180P).
+            num_frames (int): Target number of video frames (default: WAN_ONNX_EXPORT_FRAMES).
+ use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX
+ subfunctions before compilation if not already exported.
+
+ Raises:
+ RuntimeError: If compilation fails for any module or if QAIC compiler is not available
+ FileNotFoundError: If ONNX models haven't been exported or config file is missing
+ ValueError: If configuration parameters are invalid
+ OSError: If there are issues with file I/O during compilation
+
+ Example:
+ >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
+ >>> # Sequential compilation with default config
+ >>> pipeline.compile(height=480, width=832, num_frames=81)
+ >>>
+ >>> # Parallel compilation with custom config
+ >>> pipeline.compile(
+ ... compile_config="/path/to/custom_config.json",
+ ... parallel=True,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81
+ ... )
+ """
+ # Ensure all modules are exported to ONNX before compilation
+ if any(
+ path is None
+ for path in [
+ self.transformer.onnx_path,
+ ]
+ ):
+ self.export(use_onnx_subfunctions=use_onnx_subfunctions)
+
+ # Load compilation configuration
+ config_manager(self, config_source=compile_config)
+
+ # Configure pipeline dimensions and calculate compressed latent parameters
+ cl, latent_height, latent_width, latent_frames = calculate_latent_dimensions_with_frames(
+ height,
+ width,
+ num_frames,
+ self.model.vae.config.scale_factor_spatial,
+ self.model.vae.config.scale_factor_temporal,
+ self.patch_height,
+ self.patch_width,
+ )
+ # Prepare dynamic specialization updates based on video dimensions
+ specialization_updates = {
+ "transformer": [
+ # high noise
+ {
+ "cl": cl, # Compressed latent dimension
+ "latent_height": latent_height, # Latent space height
+ "latent_width": latent_width, # Latent space width
+ "num_frames": latent_frames, # Latent frames
+ },
+ # low noise
+ {
+ "cl": cl, # Compressed latent dimension
+ "latent_height": latent_height, # Latent space height
+ "latent_width": latent_width, # Latent space width
+ "num_frames": latent_frames, # Latent frames
+ },
+ ]
+ }
+
+ # Use generic utility functions for compilation
+ if parallel:
+ compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+ else:
+ compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ use_onnx_subfunctions: bool = False,
+ parallel_compile: bool = True,
+ ):
+ """
+ Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware.
+
+ This is the main entry point for text-to-video generation. It orchestrates the complete WAN
+ diffusion pipeline optimized for Qualcomm AI Cloud devices.
+
+ Args:
+ prompt (str or List[str]): Primary text prompt(s) describing the desired video content.
+ Required unless `prompt_embeds` is provided.
+ negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid
+ in the generated video. Used with classifier-free guidance.
+ height (int, optional): Target video height in pixels. Must be divisible by VAE scale factor.
+ Default: 480.
+ width (int, optional): Target video width in pixels. Must be divisible by VAE scale factor.
+ Default: 832.
+ num_frames (int, optional): Number of video frames to generate. Must satisfy temporal
+ divisibility requirements. Default: 81.
+ num_inference_steps (int, optional): Number of denoising steps. More steps generally
+ improve quality but increase generation time. Default: 50.
+ guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.0.
+ guidance_scale_2 (float, optional): Guidance scale for low-noise stage in WAN 2.2.
+ If None, uses guidance_scale value.
+ num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Default: 1.
+ generator (torch.Generator or List[torch.Generator], optional): Random generator for
+ reproducible generation.
+ latents (torch.Tensor, optional): Pre-generated latent tensors. If None, random latents
+ are generated based on video dimensions.
+ prompt_embeds (torch.Tensor, optional): Pre-computed text embeddings from UMT5 encoder.
+ Shape: [batch, seq_len, hidden_dim].
+ negative_prompt_embeds (torch.Tensor, optional): Pre-computed negative text embeddings.
+ output_type (str, optional): Output format. Options: "np" (default), "pil", or "latent".
+ return_dict (bool, optional): Whether to return a dictionary or tuple. Default: True.
+ attention_kwargs (Dict[str, Any], optional): Additional attention arguments for transformer.
+ callback_on_step_end (Callable, optional): Callback function executed after each denoising step.
+ callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback.
+ Default: ["latents"].
+ max_sequence_length (int, optional): Maximum token sequence length for text encoder. Default: 512.
+ custom_config_path (str, optional): Path to custom JSON configuration file for compilation.
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions.
+ Default: False.
+ parallel_compile (bool, optional): Whether to compile modules in parallel. Default: True.
+
+ Returns:
+ QEffPipelineOutput: A dataclass containing:
+ - images: Generated video(s) in the format specified by `output_type`
+ - pipeline_module: Performance metrics for each pipeline component
+
+ Raises:
+ ValueError: If input validation fails or parameters are incompatible
+ RuntimeError: If compilation fails or QAIC devices are unavailable
+ FileNotFoundError: If custom config file is specified but not found
+
+ Example:
+ >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline
+ >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
+ >>> result = pipeline(
+ ... prompt="A cat playing in a sunny garden",
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... num_inference_steps=4,
+ ... guidance_scale=3.0
+ ... )
+            >>> # Save generated video (result.images holds the decoded frames)
+            >>> from diffusers.utils import export_to_video
+            >>> export_to_video(result.images[0], "cat_garden.mp4")
+ """
+ device = "cpu"
+
+ # Compile models with custom configuration if needed
+ self.compile(
+ compile_config=custom_config_path,
+ parallel=parallel_compile,
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ )
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(self)
+
+ # Step 1: Validate all inputs
+ self.model.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ # Ensure num_frames satisfies temporal divisibility requirements
+ if num_frames % self.model.vae.config.scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.model.vae.config.scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = (
+ num_frames // self.model.vae.config.scale_factor_temporal * self.model.vae.config.scale_factor_temporal
+ + 1
+ )
+ num_frames = max(num_frames, 1)
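+            # Rounding example (assuming scale_factor_temporal=4): num_frames=80 -> 80 // 4 * 4 + 1 = 81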
+
+ if self.model.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ # Initialize pipeline state
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2 if guidance_scale_2 is not None else guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # Step 2: Determine batch size from inputs
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Step 3: Encode input prompts using UMT5 text encoder
+ # TODO: Update UMT5 on QAIC
+ prompt_embeds, negative_prompt_embeds = self.model.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Convert embeddings to transformer dtype for compatibility
+ transformer_dtype = self.transformer.model.transformer_high.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # Step 4: Prepare timesteps for denoising process
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # Step 5: Prepare initial latent variables for video generation
+ num_channels_latents = self.transformer.model.config.in_channels
+
+ latents = self.model.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # Create mask for temporal processing (used in expand_timesteps mode)
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
+ # Step 6: Configure dual-stage processing for WAN 2.2
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # Calculate boundary timestep for stage switching in WAN 2.2
+ if self.model.config.boundary_ratio is not None:
+ boundary_timestep = self.model.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ # Step 7: Initialize QAIC inference session for transformer
+ if self.transformer.qpc_session is None:
+ self.transformer.qpc_session = QAICInferenceSession(
+ str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
+ )
+
+ # Calculate compressed latent dimension for transformer buffer allocation
+ cl, _, _, _ = calculate_latent_dimensions_with_frames(
+ height,
+ width,
+ num_frames,
+ self.model.vae.config.scale_factor_spatial,
+ self.model.vae.config.scale_factor_temporal,
+ self.patch_height,
+ self.patch_width,
+ )
+ # Allocate output buffer for QAIC inference
+ output_buffer = {
+ "output": np.random.rand(
+ batch_size,
+ cl, # Compressed latent dimension
+ constants.WAN_DIT_OUT_CHANNELS,
+ ).astype(np.int32),
+ }
+ self.transformer.qpc_session.set_buffers(output_buffer)
+ transformer_perf = []
+
+ # Step 8: Denoising loop with dual-stage processing
+ with self.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self._interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Determine which model to use based on boundary timestep
+ if boundary_timestep is None or t >= boundary_timestep:
+ # High-noise stage
+ current_model = self.transformer.model.transformer_high
+ current_guidance_scale = guidance_scale
+ model_type = torch.ones(1, dtype=torch.int64) # High-noise model indicator
+ else:
+ # Low-noise stage
+ current_model = self.transformer.model.transformer_low
+ current_guidance_scale = guidance_scale_2
+ model_type = torch.ones(2, dtype=torch.int64) # Low-noise model indicator
+
+ # Prepare latent input with proper dtype
+ latent_model_input = latents.to(transformer_dtype)
+
+ # Handle timestep expansion for temporal consistency
+ if self.model.config.expand_timesteps:
+ # Expand timesteps spatially for better temporal modeling
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # Standard timestep broadcasting
+ timestep = t.expand(latents.shape[0])
+
+ # Extract dimensions for patch processing
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ p_t, p_h, p_w = current_model.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ # Generate rotary position embeddings
+ rotary_emb = current_model.rope(latent_model_input)
+ rotary_emb = torch.cat(rotary_emb, dim=0)
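+                # The concatenation above stacks the (cos, sin) tuple returned by rope() into the
+                # [2, cl, 1, rotary_dim] layout used for the "rotary_emb" input of the exported graph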
+ ts_seq_len = None
+ timestep = timestep.flatten()
+
+ # Generate conditioning embeddings (time + text)
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
+ current_model.condition_embedder(
+ timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len
+ )
+ )
+
+ # Generate negative conditioning for classifier-free guidance
+ if self.do_classifier_free_guidance:
+ temb, timestep_proj, encoder_hidden_states_neg, encoder_hidden_states_image = (
+ current_model.condition_embedder(
+ timestep,
+ negative_prompt_embeds,
+ encoder_hidden_states_image=None,
+ timestep_seq_len=ts_seq_len,
+ )
+ )
+
+ # Reshape timestep projection for transformer input
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ # Prepare inputs for QAIC inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": encoder_hidden_states.detach().numpy(),
+ "rotary_emb": rotary_emb.detach().numpy(),
+ "temb": temb.detach().numpy(),
+ "timestep_proj": timestep_proj.detach().numpy(),
+ "tsp": model_type.detach().numpy(), # Transformer stage pointer
+ }
+
+ # Prepare negative inputs for classifier-free guidance
+ if self.do_classifier_free_guidance:
+ inputs_aic2 = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": encoder_hidden_states_neg.detach().numpy(),
+ "rotary_emb": rotary_emb.detach().numpy(),
+ "temb": temb.detach().numpy(),
+ "timestep_proj": timestep_proj.detach().numpy(),
+ }
+
+ # Run conditional prediction with caching context
+ with current_model.cache_context("cond"):
+ # QAIC inference for conditional prediction
+ start_transformer_step_time = time.perf_counter()
+ outputs = self.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.perf_counter()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+ print(f"DIT {i} time {end_transformer_step_time - start_transformer_step_time:.2f} seconds")
+
+ # Process transformer output
+ hidden_states = torch.tensor(outputs["output"])
+
+ # Reshape output from patches back to video format
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+
+ # Permute dimensions to reconstruct video tensor
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ # Run unconditional prediction for classifier-free guidance
+ if self.do_classifier_free_guidance: # Note: CFG is False for WAN Lightning
+ with current_model.cache_context("uncond"):
+ # QAIC inference for unconditional prediction
+ start_transformer_step_time = time.perf_counter()
+ outputs = self.transformer.qpc_session.run(inputs_aic2)
+ end_transformer_step_time = time.perf_counter()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ # Process unconditional output
+ hidden_states = torch.tensor(outputs["output"])
+
+ # Reshape unconditional output
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ noise_uncond = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ # Apply classifier-free guidance
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # Update latents using scheduler (x_t -> x_t-1)
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Execute callback if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ self._current_timestep = None
+
+ # Step 9: Decode latents to video
+        if output_type != "latent":
+ # Prepare latents for VAE decoding
+ latents = latents.to(self.vae_decode.dtype)
+
+ # Apply VAE normalization (denormalization)
+ latents_mean = (
+ torch.tensor(self.vae_decode.config.latents_mean)
+ .view(1, self.vae_decode.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae_decode.config.latents_std).view(
+ 1, self.vae_decode.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+
+ # TODO: Enable VAE on QAIC
+ # VAE Decode latents to video using CPU (temporary)
+ video = self.model.vae.decode(latents, return_dict=False)[0] # CPU fallback
+
+ # Post-process video for output
+ video = self.model.video_processor.postprocess_video(video.detach())
+ else:
+ video = latents
+
+ # Step 10: Collect performance metrics
+ perf_data = {
+ "transformer": transformer_perf, # Unified transformer (QAIC)
+ }
+
+ # Build performance metrics for output
+ perf_metrics = [ModulePerf(module_name=name, perf=perf_data[name]) for name in perf_data.keys()]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=video,
+ )
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 7da2300d6..de10c9b88 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv(
next tokens. For Speculative Decoding Target Language Model,
`return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
Decoding Draft Language Model and `return_pdfs`=False for regular model.
+ :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering
+ during decoding. Only works when `include_sampler`=True.
sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
@@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
@@ -442,6 +446,7 @@ def __init__(
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
activate: bool = True,
) -> None:
@@ -451,6 +456,7 @@ def __init__(
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
+ self.include_guided_decoding = include_guided_decoding
self.sampling_params = sampling_params
self._qpc_path = qpc_path # Store qpc_path for later use
@@ -461,7 +467,9 @@ def __init__(
# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
- session_inputs=set(self._session.input_names), include_sampler=include_sampler
+ session_inputs=set(self._session.input_names),
+ include_sampler=include_sampler,
+ include_guided_decoding=include_guided_decoding,
)
# Fetch the variables from the QPC
@@ -628,7 +636,7 @@ def prepare_decode_inputs(self):
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
- for op in Constants.SAMPLER_OPS:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
@@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
- for op in Constants.SAMPLER_OPS:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
@@ -811,7 +819,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
if self.comp_ctx_lengths_prefill is not None:
- self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ self.list_of_comp_ctx_lengths_prefill = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill
+ ]
prefill_ccl_id = 0
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
@@ -841,7 +851,9 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
)
def initialize_ccl(self, decode_inputs):
- self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ self.list_of_comp_ctx_lengths_decode = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode
+ ]
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id_initial = 0
@@ -1067,6 +1079,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
@@ -1082,6 +1095,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py
index b37fdc74a..adacc373e 100644
--- a/QEfficient/generation/vlm_generation.py
+++ b/QEfficient/generation/vlm_generation.py
@@ -36,6 +36,7 @@
write_io_files,
)
from QEfficient.utils import LRUCache
+from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger
@@ -93,6 +94,7 @@ def __init__(
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
+ include_guided_decoding: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
@@ -114,6 +116,7 @@ def __init__(
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
+ include_guided_decoding: Enable guided decoding in on-device sampling
sampling_params: Sampling parameters for on-device sampling
"""
# Validate required parameters
@@ -137,6 +140,7 @@ def __init__(
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
activate=False, # vision components need to be initialized first
)
@@ -309,10 +313,19 @@ def _execute_chunked_prefill(
chunk_image_idx = None
if self.comp_ctx_lengths_prefill is not None:
- self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ self.list_of_comp_ctx_lengths_prefill = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill
+ ]
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+ if self.include_sampler:
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
+ if decode_batch_id is not None:
+ lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+ else:
+ lang_inputs[op] = self.sampling_params[op]
+
for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -338,6 +351,11 @@ def _execute_chunked_prefill(
chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
+ if self.include_sampler:
+ chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
+ for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
+ chunk_inputs[op] = lang_inputs[op]
+
outputs = self._session.run(chunk_inputs)
if "image_idx_output" in outputs:
@@ -790,6 +808,7 @@ def generate_stream_tokens(
is_tlm=self.is_tlm,
include_sampler=self.include_sampler,
return_pdfs=self.return_pdfs,
+ include_guided_decoding=self.include_guided_decoding,
sampling_params=self.sampling_params,
)
diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py
index e69aebb2b..6c7173072 100644
--- a/QEfficient/peft/auto.py
+++ b/QEfficient/peft/auto.py
@@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with the active adapter to ONNX format.
@@ -291,10 +291,10 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
example_inputs,
output_names,
dynamic_axes,
- export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
+ do_constant_folding=False, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ **kwargs,
)
def compile(
diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py
index 64fa3f61c..8ff8335f5 100644
--- a/QEfficient/peft/lora/auto.py
+++ b/QEfficient/peft/lora/auto.py
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.
@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ **kwargs,
)
def generate(
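
The PEFT `export` methods above now accept `**kwargs` instead of an explicit `use_onnx_subfunctions` parameter, so existing callers stay source-compatible while new export options pass straight through to `_export`. A hedged usage sketch; the adapter id is hypothetical and the class is assumed to be importable as shown:

```python
from QEfficient import QEffAutoPeftModelForCausalLM

# Hypothetical adapter id; any PEFT adapter repo would do here.
model = QEffAutoPeftModelForCausalLM.from_pretrained("my-org/my-lora-adapter")

# use_onnx_subfunctions now travels through **kwargs down to the export call.
onnx_path = model.export(export_dir="./onnx", use_onnx_subfunctions=False)
```
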
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 62cc71a4c..faadaba6b 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls):
"""
if torch.onnx.is_in_onnx_export():
if cls.SUBFUNC_ENABLED:
+ # TODO: Remove this branch so it no longer returns 0; returning 0 here can hurt performance.
return 0
else:
return torch.iinfo(torch.int32).max
@@ -681,6 +682,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache
+ def write_only(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ k_out, v_out = key_states, value_states
+ else:
+ position_ids = cache_kwargs.get("position_ids")
+ is_sliding_layer = cache_kwargs.get("is_sliding")
+ _, _, ctx_len, _ = self.key_cache[layer_idx].shape
+ if is_sliding_layer:
+ kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
+ else:
+ kv_position_ids = position_ids
+
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(
+ self.value_cache[layer_idx], kv_position_ids, value_states
+ )
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+ return k_out, v_out
+
def update(
self,
key_states: torch.Tensor,
@@ -747,3 +779,92 @@ def update(
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
+
+ def full_cache_update_chunked(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ position_ids = cache_kwargs.get("position_ids")
+ batch_index = cache_kwargs.get("batch_index")
+ invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
+
+ # Scatter
+ if batch_index is not None:
+ if torch.onnx.is_in_onnx_export():
+ scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+ self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+ )
+ self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+ )
+ else:
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ # Gather
+ ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
+ ctx_indices = torch.arange(ctx_len)[None, None, ...]
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+ invalid_mask = ctx_indices > gather_limit
+ ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+ if batch_index is not None:
+ k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+ v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+ else:
+ k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+ v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+ return k_out, v_out
+
+ def sliding_window_update_chunked(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ position_ids = cache_kwargs.get("position_ids")
+ batch_index = cache_kwargs.get("batch_index")
+ invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
+
+ if batch_index is not None:
+ if torch.onnx.is_in_onnx_export():
+ scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+ self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+ )
+ self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+ self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+ )
+ else:
+ self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+ self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+
+ k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+ sliding_window_len = cache_kwargs.get("sliding_window")
+
+ # Gather
+ ctx_len = position_ids.shape[1] + sliding_window_len
+ ctx_indices = torch.arange(ctx_len)[None, None, ...]
+ first_pos_idx = position_ids[0][0]
+ add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0)
+ ctx_indices += add_idx
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+ invalid_mask = ctx_indices > gather_limit
+ ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+ if batch_index is not None:
+ k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+ v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+ else:
+ k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+ v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+ return k_out, v_out
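
For reference, the gather half of `full_cache_update_chunked` masks out cache positions beyond the chunk's maximum position id before reading. A standalone sketch of that indexing, with `torch.gather` standing in for `CtxGatherFunc` and `0` standing in for the invalid-index value:

```python
import torch

batch, heads, ctx_len, head_dim = 1, 2, 8, 4
k_cache = torch.randn(batch, heads, ctx_len, head_dim)
position_ids = torch.tensor([[0, 1, 2, 3]])  # current chunk covers positions 0..3

ctx_indices = torch.arange(ctx_len)[None, None, :]                   # [1, 1, ctx_len]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit                            # future positions
ctx_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)

idx = ctx_indices.unsqueeze(-1).expand(batch, heads, ctx_len, head_dim)
k_out = torch.gather(k_cache, 2, idx)                                # reads only valid context
```
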
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index 5337b44f5..47059d8dc 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -188,6 +188,9 @@
# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
+# This is for supporting modeling classes written specifically for prefill-only models
+SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}
+
# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index c91d2fe32..a6e451bec 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -921,7 +921,7 @@ def get_dummy_inputs(
)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
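
The `comp_ctx_lengths` dummy input switches from `torch.long` to `torch.int8` here (and in the other VLM modeling files below), matching the `np.zeros(length, dtype=np.int8)` runtime buffers introduced in the generation code. Only the buffer's length appears to matter, not its values; a small sketch under that assumption:

```python
import numpy as np
import torch

comp_ctx_lengths_prefill = [256, 512]  # illustrative CCL list

# Export-time dummy: values are placeholders; int8 keeps the example input small.
dummy_comp_ctx_lengths = torch.randint(0, 100, (40,), dtype=torch.int8)

# Runtime buffers: one zero-filled int8 array per compute-context-length.
list_of_comp_ctx_lengths_prefill = [np.zeros(length, dtype=np.int8) for length in comp_ctx_lengths_prefill]
```
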
diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index 84552aff4..3efe890b8 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
+import math
+import os
from typing import Callable, Optional, Union
import torch
@@ -30,8 +32,8 @@
from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-from QEfficient.utils import constants
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
+from QEfficient.utils.logging_utils import logger
class QEffGptOssExperts(GptOssExperts):
@@ -42,8 +44,8 @@ def __qeff_init__(self):
self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim))
-class QEffGptOssMLP(GptOssMLP):
- def alt_forward(self, hidden: torch.Tensor):
+class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP):
+ def forward(self, hidden: torch.Tensor):
B, S, H = hidden.shape
T = B * S
hidden = hidden.view(T, H)
@@ -78,7 +80,62 @@ def alt_forward(self, hidden: torch.Tensor):
up = (hidden @ W_u) + b_u # [T, I]
# Apply GptOss activation with clamping
- gate = gate.clamp(min=None, max=self.experts.limit)
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ intermediate = (up + 1) * glu # [T, I]
+
+ # Down projection
+ down_out = (intermediate @ W_d) + b_d # [T, H]
+
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+
+class QEffPrefillOnlyGptOssMLP(GptOssMLP):
+ def forward(self, hidden: torch.Tensor):
+ if os.environ.get("NUM_FFN_BLOCKS", None) is not None:
+ return self.blocked_ffn_forward(hidden)
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+ # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+
+ # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ # Gate and Up projections
+ gate = (hidden @ W_g) + b_g # [T, I]
+ up = (hidden @ W_u) + b_u # [T, I]
+
+ # Apply GptOss activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
# GLU activation
@@ -88,6 +145,165 @@ def alt_forward(self, hidden: torch.Tensor):
# Down projection
down_out = (intermediate @ W_d) + b_d # [T, H]
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+ def blocked_ffn_forward(self, hidden: torch.Tensor):
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+ # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+ # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+ # Calculate block size (last block should be handled with remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = T - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ tgb = hidden[qi : qi + real_q_len, :]
+ # Gate and Up projections
+ gate = (tgb @ W_g) + b_g # [block, I]
+ up = (tgb @ W_u) + b_u # [block, I]
+
+ # Apply GptOss activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ intermediate = (up + 1) * glu # [T, I]
+
+ # Down projection
+ down_out_block = (intermediate @ W_d) + b_d # [T, H]
+
+ outs.append(down_out_block)
+
+ down_out = torch.cat(outs, dim=0)
+
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+ def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor):
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+ # ------------------ allocate the output tensor -----
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+ # ------------------------- Expert computation loop -----------------------------
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+ # Calculate block size (last block should be handled with remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = T - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ tgb = hidden[qi : qi + real_q_len, :]
+ # Gate and Up projections
+
+ wg_col_shape = W_g.shape[1]
+ wg_num_blocks = math.ceil(wg_col_shape / 128)
+ last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128
+
+ intermediates = []
+ for i in range(wg_num_blocks):
+ if i == wg_num_blocks - 1:
+ cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:]
+ cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:]
+ else:
+ cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128]
+ cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128]
+
+ cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit)
+ cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha)
+ cur_intermediate = (cur_up + 1) * cur_glu
+ intermediates.append(cur_intermediate)
+
+ intermediate = torch.cat(intermediates, dim=-1)
+
+ downs = []
+ for i in range(wg_num_blocks):
+ if i == wg_num_blocks - 1:
+ downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:])
+ else:
+ downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128])
+
+ down_out_block = torch.cat(downs, dim=1)
+ outs.append(down_out_block)
+
+ down_out = torch.cat(outs, dim=0)
+
# Apply routing weights and accumulate
masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
expert_out += masked_down
@@ -95,6 +311,8 @@ def alt_forward(self, hidden: torch.Tensor):
# original shape [B, S, H]
return expert_out.view(B, S, H), router_logits
+
+class QEffGptOssMLP(GptOssMLP):
# ------------------- Gather based, weights as activation approach ---------------
def forward_weights_as_activation(self, hidden_states):
bs, seq_len, _ = hidden_states.shape
@@ -142,7 +360,6 @@ def forward_weights_as_activation(self, hidden_states):
# ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections ---------------
def forward(self, hidden_states):
- # print("Seperate Split, Up, Gate Projections")
bs, seq_len, _ = hidden_states.shape
hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size)
@@ -172,7 +389,7 @@ def forward(self, hidden_states):
up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1)
# Apply activation with clamping
- gate = gate.clamp(min=None, max=self.experts.limit)
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
# GLU activation
@@ -404,6 +621,283 @@ def eager_attention_forward(
return attn_output, attn_weights
+def eager_attention_forward_blocked(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ BS, NH, CL, DH = query.shape
+ target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (CL // target_blocks))
+ block_count = 0
+
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+ # Calculate block size (last block should be handled with remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ q_block = query[:, :, qi : qi + real_q_len, :]
+ scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+ curr_attn_weights = torch.where(
+ attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+ )
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+ curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+ )
+ combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+ curr_attn_weights = curr_attn_weights[..., :-1]
+ out_block = torch.matmul(curr_attn_weights, value_states)
+ outs.append(out_block)
+ output = torch.cat(outs, dim=2)
+
+ output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+ return output, output
+
+
+def opt_eager_attention_forward_blocked(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ BS, NH, CL, DH = query.shape
+ target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (CL // target_blocks))
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+ # Calculate block size (last block should be handled with remainder)
+
+ if block_idx == target_blocks - 1:
+ real_q_len = CL - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ if block_idx == 0:
+ kv_start_idx = 0
+ else:
+ kv_start_idx = qi - 128
+
+ q_block = query[:, :, qi : qi + real_q_len, :]
+ if kwargs.get("sliding_window"):
+ k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :]
+ v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :]
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len]
+ else:
+ k_block = key_states
+ v_block = value_states
+ attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
+
+ scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
+ curr_attn_weights = torch.where(
+ attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
+ )
+ sinks = module.sinks.reshape(1, -1, 1, 1).expand(
+ curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
+ )
+ combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+ curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
+ curr_attn_weights = curr_attn_weights[..., :-1]
+ out_block = torch.matmul(curr_attn_weights, v_block)
+ outs.append(out_block)
+ output = torch.cat(outs, dim=2)
+
+ output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
+ return output, output
+
+
+class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __qeff_init__(self):
+ self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ sliding_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "config": self.config,
+ "is_sliding": self.sliding_window is not None,
+ "sliding_window": self.sliding_window,
+ }
+ if self.sliding_window is not None:
+ key_states, value_states = past_key_value.sliding_window_update_chunked(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ else:
+ key_states, value_states = past_key_value.full_cache_update_chunked(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ if self.sliding_window is not None:
+ attention_mask = sliding_mask
+ # positive_pos_ids = torch.where(position_ids<0, 0, position_ids)
+ ctx_len = position_ids.shape[1] + self.sliding_window
+ ctx_indices = torch.arange(ctx_len)
+ first_pos_idx = position_ids[0][0]
+ add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0)
+ # start_idx = torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0)
+ # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window)
+ ctx_indices += add_idx
+ attention_mask = attention_mask[:, :, :, ctx_indices]
+
+ attention_interface: Callable = eager_attention_forward
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
+class QEffPrefillOnlyGptOssAttention(GptOssAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __qeff_init__(self):
+ self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ sliding_mask=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "config": self.config,
+ "is_sliding": self.sliding_window is not None,
+ "sliding_window": past_key_value.sliding_window_len,
+ }
+ if self.sliding_window is not None:
+ sliding_window_len = past_key_value.sliding_window_len
+ short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2])
+ read_idx = short_read_idx + torch.where(
+ position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0
+ )
+ # Export trick: offset the read indices by position_ids.max() - sliding_window + 1 once the window has slid, so the sliced cache keeps a static length.
+ k_cache = key_states[:, :, read_idx, :]
+ v_cache = value_states[:, :, read_idx, :]
+ else:
+ k_cache, v_cache = key_states, value_states
+ _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs)
+
+ if self.sliding_window is not None:
+ attention_mask = sliding_mask
+
+ if os.environ.get("ENABLE_OPT_SWA", "0") == "1":
+ attention_interface: Callable = opt_eager_attention_forward_blocked
+ else:
+ attention_interface: Callable = eager_attention_forward_blocked
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ s_aux=self.sinks, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
class QEffGptOssAttention(GptOssAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -429,8 +923,9 @@ def forward(
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024)
+ if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached", None)):
+ max_seq_len_cached = 32 * 1024
+ cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -511,7 +1006,6 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
- # alth, _ = self.mlp.alt_forward(hidden_states)
hidden_states = hidden_states.reshape(residual.shape)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
@@ -525,6 +1019,97 @@ def forward(
return outputs
+class QEffPrefillOnlyGptOssModel(GptOssModel):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
+ causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len)
+ sliding_mask = _create_causal_mask(
+ position_ids=position_ids,
+ target_length=past_key_values.max_cache_len,
+ sliding_window=self.config.sliding_window,
+ )
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ batch_index=batch_index,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ sliding_mask=sliding_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if return_legacy_cache:
+ past_key_values = past_key_values.to_legacy_cache()
+
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
class QEffGptOssModel(GptOssModel):
def forward(
self,
@@ -578,7 +1163,6 @@ def forward(
)
hidden_states = inputs_embeds
- # position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -708,15 +1292,15 @@ def forward(
router_logits=outputs.router_logits,
)
- def get_pkv_dynamic_axes(
- self,
- ):
+ def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False):
pkv_dynamic_axes = []
for layer_type in self.config.layer_types:
- if layer_type == "sliding_attention":
- pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"})
- elif layer_type == "full_attention":
- pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"})
+ if layer_type == "sliding_attention" and not retain_full_kv:
+ pkv_dynamic_axes.append(
+ {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
+ )
+ else:
+ pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"})
return pkv_dynamic_axes
def get_specializations(
@@ -724,10 +1308,14 @@ def get_specializations(
batch_size: int,
prefill_seq_len: int,
ctx_len: int,
+ **kwargs,
):
batch_size = batch_size if batch_size else 1
- prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN
- ctx_len = ctx_len if ctx_len else constants.CTX_LEN
+ if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len:
+ ctx_len = prefill_seq_len
+ logger.warning(
+ f"Overriding ctx_len to {prefill_seq_len}; a ctx_len different from prefill_seq_len is not yet supported for prefill_only models."
+ )
specializations = [
{
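
The prefill-only GPT-OSS paths above read their blocking factors from environment variables at forward time, so these must be set before export or compilation is triggered. The values below are illustrative only:

```python
import os

os.environ["NUM_Q_BLOCKS"] = "32"    # query blocking in eager_attention_forward_blocked (prefill_seq_len // 128)
os.environ["NUM_FFN_BLOCKS"] = "4"   # switches QEffPrefillOnlyGptOssMLP.forward to blocked_ffn_forward
os.environ["ENABLE_OPT_SWA"] = "1"   # selects opt_eager_attention_forward_blocked for sliding-window layers
```
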
diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index 85c331aa8..b47db7eda 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -321,7 +321,7 @@ def get_dummy_inputs(
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 7a2f687fe..834ee8880 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -1225,7 +1225,7 @@ def get_dummy_inputs(
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
inputs = {}
if kv_offload:
diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py
index d5f5ee920..abdb77ea5 100644
--- a/QEfficient/transformers/models/llava/modeling_llava.py
+++ b/QEfficient/transformers/models/llava/modeling_llava.py
@@ -181,7 +181,7 @@ def get_dummy_inputs(
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1)
diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
index 878d04a45..627f7393e 100755
--- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py
+++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
@@ -241,7 +241,7 @@ def get_dummy_inputs(
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1)
diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py
index 89e19c65b..d2149b6bd 100644
--- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py
+++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py
@@ -315,7 +315,7 @@ def get_dummy_inputs(
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index a3cb4273d..74de1c6c1 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -967,7 +967,7 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
inputs = {}
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 8edc1f3f0..3e5acf4bb 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -5,10 +5,11 @@
#
# ----------------------------------------------------------------------------
+import os
import warnings
from pathlib import Path
from time import perf_counter
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
import numpy as np
import torch
@@ -37,13 +38,20 @@
get_compilation_dims,
)
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
-from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
+from QEfficient.transformers.modeling_utils import (
+ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
+ SPECIALIZED_PREFILL_ONLY_MODEL_ARCH,
+)
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
CustomOpsTransform,
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
+ PrefillOnlyChunkedTransform,
+ PrefillOnlyTransform,
+ RevertPrefillKeepAttentionTransform,
+ RevertPrefillOnlyTransform,
SamplerTransform,
SpDTransform,
VlmKVOffloadTransform,
@@ -62,6 +70,7 @@
)
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs
class QEFFTransformersBase(QEFFBaseModel):
@@ -124,21 +133,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying HuggingFace model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
class MultimodalUtilityMixin:
"""
@@ -316,7 +310,7 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -353,7 +347,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -603,15 +597,7 @@ def __init__(self, model: nn.modules, **kwargs):
self.model = model.get_qeff_vision_encoder()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
- def export(
- self,
- inputs,
- output_names,
- dynamic_axes,
- export_dir=None,
- offload_pt_weights=True,
- use_onnx_subfunctions: bool = False,
- ):
+ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
"""
Exports the vision encoder component to ONNX format.
@@ -641,7 +627,7 @@ def export(
dynamic_axes,
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -701,21 +687,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying vision encoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -749,7 +720,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
- def __init__(self, model, qaic_config, **kwargs):
+ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.
@@ -763,7 +734,7 @@ def __init__(self, model, qaic_config, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
- super().__init__(model, **kwargs)
+ super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model = model.get_qeff_language_decoder()
self.model.qaic_config = qaic_config
self.hash_params["qeff_auto_class"] = self.__class__.__name__
@@ -771,15 +742,7 @@ def __init__(self, model, qaic_config, **kwargs):
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
- def export(
- self,
- inputs,
- output_names,
- dynamic_axes,
- export_dir=None,
- offload_pt_weights=True,
- use_onnx_subfunctions: bool = False,
- ):
+ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
"""
Exports the language decoder component to ONNX format.
@@ -809,7 +772,7 @@ def export(
dynamic_axes,
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -869,21 +832,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying language decoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -924,16 +872,16 @@ def __init__(
----------
model : nn.Module
The full HuggingFace multimodal model.
+ qaic_config : dict, optional
+ A dictionary for QAIC-specific configurations.
**kwargs :
- Additional keyword arguments. `full_batch_size` is not supported here.
-
- Raises
- ------
- NotImplementedError
- If `full_batch_size` is provided.
+ Additional keyword arguments.
"""
if kwargs.pop("full_batch_size", None):
- raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+ continuous_batching = True
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
self.model = model
self.config = model.config
@@ -945,21 +893,11 @@ def __init__(
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None
-
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
+ # ---Sampling---
+ # Note: SamplerTransform should be applied after all other transforms
+ # are done. The role of the sampler is to just add nodes at the output of the
+ # previous transform function.
+ self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
@@ -1070,6 +1008,19 @@ def export(
kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
)
output_names = self.model.get_output_names(kv_offload=True)
+ if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
+ "include_sampler", False
+ ):
+ logits_index = output_names["lang"].index("logits")
+ output_names["lang"][logits_index] = "next_tokens"
+ inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
+ example_inputs=inputs["lang"],
+ output_names=output_names["lang"],
+ dynamic_axes=dynamic_axes["lang"],
+ continuous_batching=self.continuous_batching,
+ vocab_size=self.model.language_model.config.vocab_size,
+ qaic_config=self.lang_model.model.qaic_config,
+ )
self.vision_model.export(
inputs["vision"],
@@ -1186,17 +1137,14 @@ def compile(
# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
- if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
- logger.warning(
- "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
- )
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
+ logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
-
# For supporting VLLM and Disaggregated with CCL
- if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
@@ -1302,6 +1250,7 @@ def generate(
generation_len: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
+ **kwargs,
) -> Union[torch.Tensor, np.ndarray]:
"""
Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1362,6 +1311,7 @@ def generate(
comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
image_height=image_height,
image_width=image_width,
+ **kwargs,
)
# Call generate method
@@ -1510,7 +1460,9 @@ def kv_offload_generate(
lang_session.set_buffers(vision_outputs)
if self.comp_ctx_lengths_prefill is not None:
- list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ list_of_comp_ctx_lengths_prefill = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill
+ ]
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
@@ -1559,7 +1511,9 @@ def kv_offload_generate(
# Decode loop
if self.comp_ctx_lengths_decode is not None:
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
- list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ list_of_comp_ctx_lengths_decode = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode
+ ]
max_position_id = np.max(lang_inputs["position_ids"])
ccl_id_initial = 0
ccl_id = ccl_id_initial
@@ -1644,10 +1598,15 @@ def __init__(
Raises
------
NotImplementedError
- If `full_batch_size` is provided.
+ If `full_batch_size` is provided or `include_sampler` is True.
"""
if kwargs.pop("full_batch_size", None):
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+ if qaic_config is not None and qaic_config.pop("include_sampler", False):
+ raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)
self.model.qaic_config = qaic_config
@@ -1834,17 +1793,14 @@ def compile(
# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
- if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
- logger.warning(
- "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
- )
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
+ logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
-
# For supporting VLLM and Disaggregated with CCL
- if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
@@ -2048,7 +2004,9 @@ def cloud_ai_100_generate(
inputs["image_idx"] = np.array([[0]])
if self.comp_ctx_lengths_prefill is not None:
- list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ list_of_comp_ctx_lengths_prefill = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill
+ ]
prefill_ccl_id = 0
inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
@@ -2089,7 +2047,9 @@ def cloud_ai_100_generate(
# Decode loop
if self.comp_ctx_lengths_decode is not None:
- list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ list_of_comp_ctx_lengths_decode = [
+ np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode
+ ]
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
max_position_id = np.max(inputs["position_ids"])
ccl_id_initial = 0
@@ -2131,21 +2091,6 @@ def cloud_ai_100_generate(
),
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -2279,6 +2224,8 @@ def from_pretrained(
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
+ qaic_config : dict, optional
+ A dictionary for QAIC-specific configurations.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.
@@ -2306,7 +2253,6 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
@@ -2359,11 +2305,30 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+ def prefill(
+ self,
+ enable: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = False,
+ ):
+ if enable:
+ if enable_chunking:
+ self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
+ else:
+ self.model, tf = PrefillOnlyTransform.apply(self.model)
+
+ else:
+ if retain_full_kv:
+ self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model)
+ else:
+ self.model, tf = RevertPrefillOnlyTransform.apply(self.model)
+
def __init__(
self,
model: nn.Module,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
+ max_seq_len_cached: Optional[int] = None,
**kwargs,
):
"""
@@ -2383,6 +2348,8 @@ def __init__(
- **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
For Speculative Decoding Target Language Models, this is always True.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+ - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+ during decoding. Only works when include_sampler=True.
- **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation.
**kwargs :
Additional keyword arguments passed to the base class constructor.
@@ -2411,6 +2378,7 @@ def __init__(
# Set use_cache=True to get KV values as output during ONNX export
model.config.use_cache = True
+ setattr(model.config, "max_seq_len_cached", max_seq_len_cached)
super().__init__(model, qaic_config=qaic_config, **kwargs)
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
@@ -2423,6 +2391,7 @@ def __init__(
if qaic_config:
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
+ self.hash_params["max_seq_len_cached"] = max_seq_len_cached
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
@@ -2437,21 +2406,6 @@ def __init__(
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying Causal Language Model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
@@ -2462,6 +2416,7 @@ def from_pretrained(
pretrained_model_name_or_path,
continuous_batching: bool = False,
qaic_config: Optional[dict] = None,
+ max_seq_len_cached: Optional[int] = None,
*args,
**kwargs,
):
@@ -2491,6 +2446,8 @@ def from_pretrained(
and ``return_pdfs=False`` for regular model.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
The values provided in ``top_ks`` tensor must be less than this maximum limit.
+ - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+ during decoding. Only works when include_sampler=True.
*args :
Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
@@ -2525,7 +2482,6 @@ def from_pretrained(
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
        # This is to support models that should be classified into a different auto class, but transformers loads them via this class
-
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
model,
@@ -2540,6 +2496,7 @@ def from_pretrained(
continuous_batching=continuous_batching,
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
+ max_seq_len_cached=max_seq_len_cached,
**kwargs,
)
@@ -2555,7 +2512,56 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str:
+ def get_seq_len_and_handle_specialized_prefill_model(
+ self, prefill_seq_len: Optional[int] = None, enable_chunking=False
+ ) -> int:
+ self.hash_params["prefill_only"] = True
+ if enable_chunking:
+ self.hash_params["chunking"] = True
+ return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+
+ num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
+ if num_q_blocks is None:
+ block_size = 128
+ if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
+ raise ValueError(
+                    f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}, "
+                    f"or the `NUM_Q_BLOCKS` environment variable must be set. "
+                    f"Received: prefill_seq_len={prefill_seq_len}"
+ )
+
+ num_q_blocks = prefill_seq_len // block_size
+ logger.warning(
+                f"Setting NUM_Q_BLOCKS={num_q_blocks} for attention Q-blocking in the prefill-only model; set the `NUM_Q_BLOCKS` environment variable to override"
+ )
+ os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
+ num_q_blocks = int(num_q_blocks)
+
+ num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None)
+ num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks
+ min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks
+ if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0:
+ raise ValueError(
+                f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}; tried to set seq_len={min_seq_len} for export, "
+                "but seq_len is not divisible by num_ffn_blocks or num_q_blocks. Try changing the values."
+ )
+
+ self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks
+ self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks
+ self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0")
+ return (
+ min_seq_len
+ if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ )
+
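A worked example of the Q-blocking arithmetic above (non-chunking path), with hypothetical values and NUM_Q_BLOCKS unset in the environment:

prefill_seq_len = 4096
block_size = 128
num_q_blocks = prefill_seq_len // block_size      # 32; the method also writes this into the NUM_Q_BLOCKS env var
num_ffn_blocks = 8                                # taken from NUM_FFN_BLOCKS, if that env var is set
min_seq_len = max(num_q_blocks, num_ffn_blocks)   # 32
# mirrors the divisibility check above; both conditions hold here
assert min_seq_len % num_q_blocks == 0 and min_seq_len % num_ffn_blocks == 0
# the returned export seq_len is then max(min_seq_len, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)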
+ def export(
+ self,
+ export_dir: Optional[str] = None,
+ prefill_only: Optional[bool] = False,
+ prefill_seq_len: Optional[int] = None,
+ **kwargs,
+ ) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -2581,6 +2587,39 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
kv_cache_shape = get_padding_shape_from_config(
self.model.config, fbs if self.continuous_batching else bs, seq_len
)
+ enable_chunking = kwargs.get("enable_chunking", False)
+ if prefill_only:
+ if not enable_chunking and self.continuous_batching:
+ raise NotImplementedError(
+                    "Looks like you are trying to run prefix caching without chunking; this feature is not available yet!"
+ )
+ self.prefill(enable=True, enable_chunking=enable_chunking)
+ self.hash_params.pop("retain_full_kv", None)
+ seq_len = (
+ self.get_seq_len_and_handle_specialized_prefill_model(
+ prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
+ )
+ if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
+ else seq_len
+ )
+ kv_cache_shape[2] = (
+ seq_len + (self.model.config.sliding_window if hasattr(self.model.config, "sliding_window") else 0)
+ if enable_chunking
+ else seq_len
+ )
+ else:
+ self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
+ self.hash_params.pop("prefill_only", None)
+ self.hash_params.pop("NUM_Q_BLOCKS", None)
+ self.hash_params.pop("NUM_FFN_BLOCKS", None)
+ self.hash_params.pop("ENABLE_OPT_SWA", None)
+ self.hash_params.pop("chunking", None)
+ if kwargs.get("retain_full_kv", False):
+ kv_cache_shape[2] = seq_len + (
+ self.model.config.sliding_window if hasattr(self.model.config, "sliding_window") else 0
+ )
+ self.hash_params["retain_full_kv"] = True
+
example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
@@ -2591,7 +2630,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
"position_ids": {0: "batch_size", 1: "seq_len"},
}
if self.comp_ctx_lengths_prefill is not None:
- example_inputs["comp_ctx_lengths"] = torch.randint(0, 512, (512,), dtype=torch.long)
+ example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8)
dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
@@ -2629,7 +2668,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
else:
# HACK: create common function for this including above if condition code
pkv_dynamic_axes = (
- self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes
+ self.model.get_pkv_dynamic_axes(
+ retain_full_kv=kwargs.get("retain_full_kv", False)
+ or (prefill_only and kwargs.get("enable_chunking", False)),
+ continuous_batching=self.continuous_batching,
+ )
+ if hasattr(self.model, "get_pkv_dynamic_axes")
+ else pkv_dynamic_axes
)
pkv_dynamic_axes = (
[pkv_dynamic_axes] * self.model.config.num_hidden_layers
@@ -2638,7 +2683,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
)
for i in range(self.num_layers):
- pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size"
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
@@ -2654,100 +2698,24 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
- example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+ example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
example_inputs=example_inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
+ continuous_batching=self.continuous_batching,
+ vocab_size=self.model.config.vocab_size,
+ qaic_config=self.model.qaic_config,
)
-
return self._export(
example_inputs,
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
offload_pt_weights=kwargs.get("offload_pt_weights", True),
+ prefill_only=prefill_only,
)
- def get_sampling_inputs_and_outputs(
- self,
- example_inputs: Dict[str, torch.Tensor],
- output_names: List[str],
- dynamic_axes: Dict[str, Dict[int, str]],
- ):
- """
- Updates the example inputs, output names, and dynamic axes to include
- parameters relevant for on-device sampling during ONNX export.
-
- Parameters
- ----------
- example_inputs : Dict[str, torch.Tensor]
- Current dictionary of example inputs.
- output_names : List[str]
- Current list of output names.
- dynamic_axes : Dict[str, Dict[int, str]]
- Current dictionary of dynamic axes configurations.
-
- Returns
- -------
- Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
- Updated example inputs, output names, and dynamic axes including
- sampling-related parameters.
- """
- bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
- fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
-
- example_inputs["last_accepted_output_tokens"] = torch.zeros(
- (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
- )
- dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
-
- example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
- (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
- )
- dynamic_axes["past_repetition_penalty_buffer"] = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- }
- output_names.append("past_repetition_penalty_buffer_RetainedState")
-
- example_inputs["repetition_penalties"] = (
- torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
- )
- dynamic_axes["repetition_penalties"] = {0: "batch_size"}
-
- example_inputs["past_presence_penalty_buffer"] = torch.zeros(
- (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
- )
- dynamic_axes["past_presence_penalty_buffer"] = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- }
- output_names.append("past_presence_penalty_buffer_RetainedState")
-
- example_inputs["presence_penalties"] = (
- torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
- )
- dynamic_axes["presence_penalties"] = {0: "batch_size"}
-
- example_inputs["temperatures"] = (
- torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
- )
- dynamic_axes["temperatures"] = {0: "batch_size"}
-
- max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
- example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
- dynamic_axes["top_ks"] = {0: "batch_size"}
-
- example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
- dynamic_axes["top_ps"] = {0: "batch_size"}
-
- example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
- dynamic_axes["min_ps"] = {0: "batch_size"}
-
- example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
- dynamic_axes["random_numbers"] = {0: "batch_size"}
-
- return example_inputs, output_names, dynamic_axes
-
def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
@@ -2756,6 +2724,7 @@ def build_prefill_specialization(
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
+ **kwargs,
):
"""
Builds a dictionary representing a compilation specialization for the prefill phase.
@@ -2778,11 +2747,17 @@ def build_prefill_specialization(
Dict[str, Union[int, str]]
A dictionary defining the prefill specialization.
"""
+ if prefill_seq_len == 1 and self.continuous_batching:
+ exec_batch_size = full_batch_size
+ else:
+ exec_batch_size = 1 if self.continuous_batching else batch_size
+
if hasattr(self.model, "get_specializations"):
spec = self.model.get_specializations(
- batch_size=1 if self.continuous_batching else batch_size,
+ batch_size=exec_batch_size,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
+ **kwargs,
)[0]
else:
spec = {
@@ -2810,6 +2785,7 @@ def build_decode_specialization(
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
+ **kwargs,
):
"""
Builds a dictionary representing a compilation specialization for the decode phase.
@@ -2880,6 +2856,9 @@ def compile(
num_speculative_tokens: Optional[int] = None,
prefill_only: Optional[bool] = None,
use_onnx_subfunctions: bool = False,
+ offload_pt_weights: Optional[bool] = True,
+ enable_chunking: Optional[bool] = False,
+ retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -2960,19 +2939,30 @@ def compile(
If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models.
"""
+ if prefill_only is None or not prefill_only:
+ if self.continuous_batching and full_batch_size is None:
+ raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
+ if kv_cache_batch_size and not full_batch_size:
+ raise ValueError(
+ "KV caching requires continuous batching. Please set `full_batch_size` and "
+ "enable `continuous_batching=True` in `from_pretrained`."
+ )
+ else:
+ if self.continuous_batching:
+ if not isinstance(kv_cache_batch_size, int):
+ raise ValueError(
+                        "Please pass a valid integer for kv_cache_batch_size, as continuous_batching is enabled for this prefill-only model"
+ )
# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
- if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
- logger.warning(
- "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
- )
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
+ logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
-
# For supporting VLLM and Disaggregated with CCL
- if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
+ elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
if isinstance(comp_ctx_lengths_prefill, str):
import ast
@@ -2987,7 +2977,7 @@ def compile(
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
- self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
# --- Validation ---
@@ -2997,15 +2987,6 @@ def compile(
if self.is_tlm:
num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)
- if self.continuous_batching and full_batch_size is None:
- raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
-
- if kv_cache_batch_size and not full_batch_size:
- raise ValueError(
- "KV caching requires continuous batching. Please set `full_batch_size` and "
- "enable `continuous_batching=True` in `from_pretrained`."
- )
-
if (
self.model.qaic_config is not None
and self.model.qaic_config.get("include_sampler", False)
@@ -3014,15 +2995,23 @@ def compile(
):
raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")
+ if kv_cache_batch_size and prefill_only is not None and prefill_only:
+ logger.warning(
+                "kv_cache_batch_size will be ignored when prefill_only is set to True, unless this is a GptOss model"
+ )
+
# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
# --- Specializations ---
specializations = []
if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            # TODO: we are handling the decode-only case inside the prefill call, which is misleading
if self.comp_ctx_lengths_prefill is not None:
# Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
for i in range(0, len(self.comp_ctx_lengths_prefill)):
+ if prefill_only or enable_chunking:
+ raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
specializations.append(
self.build_prefill_specialization(
prefill_seq_len=prefill_seq_len,
@@ -3042,6 +3031,8 @@ def compile(
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
+ prefill_only=prefill_only,
+ enable_chunking=enable_chunking,
)
)
@@ -3069,6 +3060,7 @@ def compile(
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
+ prefill_only=prefill_only,
)
if decode_spec:
specializations.append(decode_spec)
@@ -3081,7 +3073,6 @@ def compile(
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
qpc_path = self._compile(
onnx_path=onnx_path,
compile_dir=compile_dir,
@@ -3096,6 +3087,10 @@ def compile(
aic_num_cores=num_cores,
mxint8_kv_cache=mxint8_kv_cache,
use_onnx_subfunctions=use_onnx_subfunctions,
+ prefill_only=prefill_only,
+ offload_pt_weights=offload_pt_weights,
+ enable_chunking=enable_chunking,
+ retain_full_kv=retain_full_kv,
**compiler_options,
)
@@ -3287,7 +3282,7 @@ def get_model_config(self) -> dict:
"""
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -3315,7 +3310,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
@@ -3663,7 +3658,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
def get_model_config(self) -> dict:
return self.model.config.__dict__
- def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Exports the model to ``ONNX`` format using ``torch.onnx.export``.
@@ -3691,7 +3686,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
- use_onnx_subfunctions=use_onnx_subfunctions,
+ use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
)
def compile(
diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py
index 7bfa58fc0..b686e6aed 100644
--- a/QEfficient/transformers/models/molmo/modeling_molmo.py
+++ b/QEfficient/transformers/models/molmo/modeling_molmo.py
@@ -972,7 +972,7 @@ def get_dummy_inputs(
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 21a867eb5..b978b6193 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -197,6 +197,10 @@
Starcoder2ForCausalLM,
Starcoder2Model,
)
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
@@ -238,6 +242,7 @@
QEffGemma3Attention,
QEffGemma3CustomRMSNormAIC,
QEffGemma3DecoderLayer,
+ QEffGemma3DecoderWrapper,
QEffGemma3ForCausalLMModel,
QEffGemma3ForConditionalGeneration,
QEffGemma3TextModel,
@@ -261,6 +266,11 @@
QEffGptOssForCausalLM,
QEffGptOssMLP,
QEffGptOssModel,
+ QEffPrefillOnlyChunkedGptOssAttention,
+ QEffPrefillOnlyChunkedGptOssMLP,
+ QEffPrefillOnlyGptOssAttention,
+ QEffPrefillOnlyGptOssMLP,
+ QEffPrefillOnlyGptOssModel,
)
from QEfficient.transformers.models.gptj.modeling_gptj import (
QEffGPTJAttention,
@@ -292,6 +302,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
+ QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -303,6 +314,7 @@
QEffLlamaRotaryEmbedding,
)
from QEfficient.transformers.models.llama4.modeling_llama4 import (
+ QEffLlama4DecoderWrapper,
QEffLlama4ForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlama4Router,
@@ -315,9 +327,11 @@
QEffLlama4VisionModel,
)
from QEfficient.transformers.models.llava.modeling_llava import (
+ QEFFLlavaDecoderWrapper,
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
+ QEffLlavaNextDecoderWrapper,
QEffLlavaNextForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
@@ -395,6 +409,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
+ QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -417,6 +432,10 @@
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
+from QEfficient.transformers.models.t5.modeling_t5 import (
+ QEffT5Attention,
+ QEffT5LayerNorm,
+)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
@@ -634,6 +653,39 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
return model, transformed
+class PrefillOnlyTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffGptOssAttention: QEffPrefillOnlyGptOssAttention,
+ QEffGptOssMLP: QEffPrefillOnlyGptOssMLP,
+ }
+
+
+class PrefillOnlyChunkedTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
+ }
+
+
+class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+ QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+ QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
+ QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
+ }
+
+
+class RevertPrefillOnlyTransform(ModuleMappingTransform):
+ _module_mapping = {
+ **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
+ **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()},
+ }
+
+
class SpDTransform:
"""
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
@@ -707,14 +759,20 @@ class SamplerTransform:
_module_mapping = {
QEffFalconForCausalLM,
QEffGemmaForCausalLM,
+ QEffGemma3DecoderWrapper,
QEffGPT2LMHeadModel,
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
+ QEffInternDecoderWrapper,
QEffLlamaForCausalLM,
+ QEffLlama4DecoderWrapper,
+ QEFFLlavaDecoderWrapper,
+ QEffLlavaNextDecoderWrapper,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
+ QEffQwen_2_5_vl_DecoderWrapper,
}
@classmethod
@@ -808,6 +866,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
+class T5ModelTransform(ModuleMappingTransform):
+ # supported architectures
+ _module_mapping = {
+ T5Attention: QEffT5Attention,
+ T5LayerNorm: QEffT5LayerNorm,
+ }
+
+
class PoolingTransform:
"""
Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 63e046600..21d2e026e 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -992,7 +992,7 @@ def get_dummy_inputs(
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
if comp_ctx_lengths is not None:
- lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
inputs = {}
if kv_offload:
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/transformers/models/t5/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py
new file mode 100644
index 000000000..f54201465
--- /dev/null
+++ b/QEfficient/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,145 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from transformers import EncoderDecoderCache
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
+
+
+class QEffT5LayerNorm(T5LayerNorm):
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
+ div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+ variance = div_first.pow(2).sum(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
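The rescaled variance above (divide by sqrt(d) before squaring and summing) equals the usual RMSNorm mean-of-squares; a minimal standalone check:

import torch

h = torch.randn(2, 8)
d = h.shape[-1]
v_rescaled = (h * torch.rsqrt(torch.tensor(d, dtype=torch.float32))).pow(2).sum(-1, keepdim=True)
v_standard = h.pow(2).mean(-1, keepdim=True)   # classic RMSNorm variance
assert torch.allclose(v_rescaled, v_standard, atol=1e-6)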
+class QEffT5Attention(T5Attention):
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = key_value_states is not None
+
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+        # Check if an encoder-decoder model is being used. Otherwise we'll get a `DynamicCache`
+ if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k(current_states)
+ value_states = self.v(current_states)
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ if position_bias is None:
+ key_length = key_states.shape[-2]
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
+ )
+ if past_key_value is not None: # This block is where the patch applies
+ position_bias = position_bias[:, :, -1:, :] # Added by patch
+
+ if mask is not None:
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
+ position_bias = position_bias + causal_mask
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+
+ # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py
index dfadc00ef..dc2308e99 100644
--- a/QEfficient/transformers/quantizers/__init__.py
+++ b/QEfficient/transformers/quantizers/__init__.py
@@ -5,6 +5,6 @@
#
# -----------------------------------------------------------------------------
-from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
-__all__ = ["replace_transformers_quantizers"]
+__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"]
diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py
index 96846e712..5c86b6355 100644
--- a/QEfficient/transformers/sampler/sampler.py
+++ b/QEfficient/transformers/sampler/sampler.py
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):
probs: torch.FloatTensor = None
next_tokens: torch.IntTensor = None
+ vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
+ image_idx: Optional[torch.IntTensor] = None # for VLMs
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -47,7 +49,6 @@ def prefill_path(
positions_mask = (position_ids[:, :1] != zero_tensor).view(-1, 1)
mul_value = CtxScatterFuncCB3D.apply(mul_value, batch_index, zero_tensor, positions_mask)
past_repetition_penalty_buffer *= mul_value
- past_presence_penalty_buffer *= mul_value
# Mask out-of-bounds or invalid position_ids or input_ids
input_ids = torch.where(position_ids == -1, torch.iinfo(torch.int32).max, input_ids)
@@ -59,6 +60,9 @@ def prefill_path(
input_ids,
torch.ones(input_ids.shape, dtype=torch.bool),
)
+
+ mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool)
+ past_presence_penalty_buffer *= mul_value
return past_repetition_penalty_buffer, past_presence_penalty_buffer
@@ -103,6 +107,7 @@ def sampler_forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -112,6 +117,8 @@ def sampler_forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[int] = None,
+ vision_embeds: Optional[torch.FloatTensor] = None,
+ image_idx: Optional[torch.IntTensor] = None,
last_accepted_output_tokens: Optional[torch.Tensor] = None, # (batch_size, spec_length or less)
past_repetition_penalty_buffer: Optional[torch.Tensor] = None,
repetition_penalties: Optional[torch.Tensor] = None,
@@ -122,11 +129,15 @@ def sampler_forward(
top_ps: Optional[torch.Tensor] = None,
min_ps: Optional[torch.Tensor] = None,
random_numbers: Optional[torch.Tensor] = None,
+ token_bitmasks: Optional[torch.Tensor] = None,
) -> Union[Tuple, SamplerOutput]:
r"""
Perform the sampling of next tokens on the QAIC device (instead of the host)
and return the next tokens and/or probability distributions.
+ The vision_embeds and image_idx parameters are optional
+ and are used only for VLMs when supported by the original forward function.
+
Args:
last_accepted_output_tokens (`torch.Tensor`, *optional*):
Output tokens accepted by the Speculative Decoding Draft Language Model.
@@ -169,21 +180,43 @@ def sampler_forward(
random_numbers (`torch.Tensor`, *optional*):
Sampling parameter that represents the random seeds to use for random sampling.
Must be in [-1, 1].
- """
- outputs = self.old_forward(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- batch_index=batch_index,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
+ token_bitmasks (`torch.Tensor`, *optional*):
+ Boolean mask used to guide token-level filtering during decoding. Each
+ element of this tensor indicates whether the corresponding token should be
+ kept (1) or masked (0). Shape: (batch_size, vocab_size)
+ """
+ if vision_embeds is not None:
+ forward_kwargs = dict(
+ input_ids=input_ids,
+ vision_embeds=vision_embeds,
+ position_ids=position_ids,
+ image_idx=image_idx,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ )
+ if batch_index is not None:
+ forward_kwargs["batch_index"] = batch_index
+
+ logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
+ outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
+ if position_ids.dim() == 3: # For models using m-rope
+ position_ids = position_ids[0]
+ else:
+ outputs = self.old_forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ batch_index=batch_index,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
logits = outputs.get("logits", None)
assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
@@ -197,6 +230,13 @@ def sampler_forward(
batch_index = torch.arange(batch_size).view(-1, 1)
batch_index_reshaped = batch_index.view(-1)
+
+ # Guided decoding
+ if token_bitmasks is not None and (token_bitmasks != 1).any():
+ assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding"
+ # Mask logits where token_bitmasks is 0 with -inf
+ logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
+
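A minimal sketch of the bitmask semantics above, with toy shapes:

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])    # (batch_size, vocab_size)
token_bitmasks = torch.tensor([[1, 0, 1, 0]])    # keep only tokens 0 and 2
masked = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
# masked == [[1.0, -65504.0, 3.0, -65504.0]]; the later greedy/top-k steps can no longer pick masked tokens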
# Prefill
past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
input_ids=input_ids,
@@ -224,17 +264,6 @@ def sampler_forward(
is_prefill, past_presence_penalty_buffer_prefill, past_presence_penalty_buffer_decode
)
- # Greedy Sampling
- greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1)
- if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False):
- return SamplerOutput(
- probs=None,
- next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits
- past_key_values=outputs.past_key_values,
- past_repetition_penalty_buffer=past_repetition_penalty_buffer,
- past_presence_penalty_buffer=past_presence_penalty_buffer,
- )
-
# Repetition Penalty
if (repetition_penalties != 1.0).any():
past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped].repeat(
@@ -253,6 +282,19 @@ def sampler_forward(
) # (batch_size * spec_length, vocab_size)
logits -= presence_penalties * past_presence_penalty_buffer_selected
+ # Greedy Sampling
+ greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1)
+ if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False):
+ return SamplerOutput(
+ probs=None,
+ next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits
+ vision_embeds=outputs.get("vision_embeds", None),
+ image_idx=outputs.get("image_idx", None),
+ past_key_values=outputs.get("past_key_values", None),
+ past_repetition_penalty_buffer=past_repetition_penalty_buffer,
+ past_presence_penalty_buffer=past_presence_penalty_buffer,
+ )
+
# TODO: Frequency Penalty
# Temperature Scaling
@@ -300,9 +342,8 @@ def sampler_forward(
) # (batch_size, spec_length, vocab_size)
# Random Sampling
- topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids)
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
- y = topk_probs_asc + gumbel_noise
+ y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids)
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)
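The change above leans on the Gumbel-Max trick: adding Gumbel(0, 1) noise to unnormalized (log-)scores and taking the argmax yields a sample from the corresponding softmax distribution, so the explicit softmax over topk_values_asc can be dropped. A standalone sketch:

import torch

scores = torch.tensor([0.5, 1.5, 0.2])   # unnormalized log-weights (e.g. masked/scaled logits)
u = torch.rand_like(scores)              # U(0, 1)
gumbel = -torch.log(-torch.log(u))       # Gumbel(0, 1) noise
sample = torch.argmax(scores + gumbel)   # distributed as Categorical(softmax(scores))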
@@ -314,7 +355,9 @@ def sampler_forward(
return SamplerOutput(
probs=probs,
next_tokens=next_tokens, # Return sampled next tokens instead of logits
- past_key_values=outputs.past_key_values,
+ vision_embeds=outputs.get("vision_embeds", None),
+ image_idx=outputs.get("image_idx", None),
+ past_key_values=outputs.get("past_key_values", None),
past_repetition_penalty_buffer=past_repetition_penalty_buffer,
past_presence_penalty_buffer=past_presence_penalty_buffer,
)
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index 49f0ad30b..3d6583f85 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -16,7 +16,6 @@
create_model_params,
custom_format_warning,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
get_num_layers_from_config,
get_num_layers_vlm,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index 131a7fc26..26bae7a34 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -12,7 +12,6 @@
import subprocess
import xml.etree.ElementTree as ET
from dataclasses import dataclass
-from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import requests
@@ -27,9 +26,8 @@
PreTrainedTokenizerFast,
)
-from QEfficient.utils.cache import QEFF_HOME
from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants
-from QEfficient.utils.hash_utils import create_export_hash, json_serializable
+from QEfficient.utils.hash_utils import json_serializable
from QEfficient.utils.logging_utils import logger
@@ -532,61 +530,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
"""
model_params = copy.deepcopy(kwargs)
model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
- model_params["config"] = qeff_model.model.config.to_diff_dict()
model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
model_params["applied_transform_names"] = qeff_model._transform_names()
return model_params
-def export_wrapper(func):
- def wrapper(self, *args, **kwargs):
- export_dir = kwargs.get("export_dir", None)
- parent_dir = self.model_architecture or self.model_name
- export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name))
-
- # PREPROCESSING OF PARAMETERS
-
- # Get the original signature
- original_sig = inspect.signature(func)
-
- # Remove 'self' from parameters
- params = list(original_sig.parameters.values())[1:] # skip 'self'
- new_sig = inspect.Signature(params)
-
- # Bind args and kwargs to the new signature
- bound_args = new_sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
-
- # Get arguments as a dictionary
- all_args = bound_args.arguments
-
- export_hash, filtered_hash_params = create_export_hash(
- model_params=self.hash_params,
- output_names=all_args.get("output_names"),
- dynamic_axes=all_args.get("dynamic_axes"),
- export_kwargs=all_args.get("export_kwargs", None),
- onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
- use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False),
- )
-
- export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
- kwargs["export_dir"] = export_dir
- self.export_hash = export_hash
-
- # _EXPORT CALL
- onnx_path = func(self, *args, **kwargs)
-
- # POST-PROCESSING
- # Dump JSON file with hashed parameters
- hashed_params_export_path = export_dir / "hashed_export_params.json"
- create_json(hashed_params_export_path, filtered_hash_params)
- logger.info("Hashed parameters exported successfully.")
-
- return onnx_path
-
- return wrapper
-
-
def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
"""
    Executes the given command using subprocess.
diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py
index 0d6a078f6..cc259ee36 100644
--- a/QEfficient/utils/check_ccl_specializations.py
+++ b/QEfficient/utils/check_ccl_specializations.py
@@ -5,40 +5,160 @@
#
# -----------------------------------------------------------------------------
+from typing import List, Tuple
+
+from QEfficient.utils import constants
+from QEfficient.utils.logging_utils import logger
+
+
+# Better performance when the context length is a multiple of 1024, so map CL to the next multiple of 1024
+def next_multiple_of_1024(n: int) -> int:
+ """Ceil 'n' to the next multiple of 1024."""
+ if n <= 0:
+ return 0
+ return ((n + 1023) // 1024) * 1024
+
+
+def floor_to_1000(n: int) -> int:
+ """Floor 'n' to the nearest lower multiple of 1000."""
+ if n <= 0:
+ return 0
+ return (n // 1000) * 1000
+
+
+def is_power_of_two(n: int) -> bool:
+ """Return True if n is a power of two (n > 0 and n & (n - 1) == 0)."""
+ return n > 0 and (n & (n - 1)) == 0
+
+
+def build_doubling_list(start: int, limit: int, max_elements: int, last_value: int = None) -> List[int]:
+ """
+ Build a STRICT doubling list: {start, start*2, start*4, ...} up to 'limit',
+    collecting at most 'max_elements' values.
+    If 'last_value' is provided, ensure the last element equals it by appending to or replacing the final element.
+ """
+ values: List[int] = []
+ if max_elements <= 0 or start <= 0 or limit <= 0:
+ return values
+
+ element = start
+ while element <= limit and len(values) < max_elements:
+ values.append(element)
+ element *= 2
+
+    if last_value is not None and (not values or values[-1] != last_value):
+ if len(values) < max_elements:
+ values.append(last_value)
+ else:
+ values[-1] = last_value
+ return values[:max_elements]
+
+
+def automatic_ccl_generation(
+ ctx_len: int,
+ prefill_seq_len: int,
+) -> Tuple[List[int], List[int], int]:
+ """
+ Automatic Compute-Context-Length Lists Generation
+ Purpose:
+ Compute decode and prefill CCL lists based on an input context length (CL),
+ prefill sequence length, and optional pre-specified lists.
+ """
+ # Handle non-positive CL
+ if ctx_len <= 0:
+ mapped_cl = next_multiple_of_1024(1)
+ seq = [mapped_cl]
+ return seq, seq, mapped_cl
+
+ mapped_cl = next_multiple_of_1024(ctx_len)
+
+ # Early small-ctx_len case for identical lists
+ if mapped_cl <= constants.CCL_START_CTX_LEN:
+ seq = [mapped_cl]
+ return seq, seq, mapped_cl
+
+ # To limit the number of elements in CCL list, the starting point will be calculated based on context length
+ for upper_bound, (decode_start, prefill_start) in constants.CCL_START_MAP.items():
+ if mapped_cl <= upper_bound:
+ break
+
+ if prefill_seq_len > 1:
+ # ---- Decode: strict doubling up to mapped_cl, then enforce last = mapped_cl
+ decode_list = build_doubling_list(
+ start=decode_start, limit=mapped_cl, max_elements=constants.CCL_MAX_ELEMENTS_LISTS, last_value=mapped_cl
+ )
+
+ # ---- Prefill:
+ if is_power_of_two(mapped_cl):
+ # STRICT doubling only, bounded by mapped_cl
+ prefill_list = build_doubling_list(
+ start=prefill_start, limit=mapped_cl, max_elements=constants.CCL_MAX_ELEMENTS_LISTS
+ )
+ else:
+ # Doubles bounded by mapped_cl, but last must equal floor_to_1000(mapped_cl)
+ prefill_last = floor_to_1000(mapped_cl)
+ prefill_list = build_doubling_list(
+ start=prefill_start,
+ limit=mapped_cl,
+ max_elements=constants.CCL_MAX_ELEMENTS_LISTS,
+ last_value=prefill_last,
+ )
+
+ return prefill_list, decode_list, mapped_cl
+
+ elif prefill_seq_len == 1:
+ # When prefill_seq_len=1 such as in MoE models, prefilling and decoding processes can use the same specializations and we can double the length of ccl lists.
+ # Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 2*constants.CCL_MAX_ELEMENTS_LISTS.
+ max_elems = 2 * constants.CCL_MAX_ELEMENTS_LISTS
+
+ if mapped_cl < constants.CCL_START_CTX_LEN:
+ seq = [mapped_cl]
+ return seq, seq, mapped_cl
+
+ limit = min(mapped_cl, constants.CCL_START_CTX_LEN * (2 ** (max_elems - 1)))
+
+ seq_list = build_doubling_list(
+ start=constants.CCL_START_CTX_LEN, limit=limit, max_elements=max_elems, last_value=mapped_cl
+ )
+
+ return seq_list, seq_list, mapped_cl
+ else:
+ logger.warning("prefill_seq_len cannot be less than 1!")
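A worked example of the generation above, assuming this patch is applied and using the constants it introduces (CCL_START_CTX_LEN=4096, CCL_MAX_ELEMENTS_LISTS=5, CCL_START_MAP[32768] == (4096, 4000)):

from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation

prefill_ccl, decode_ccl, mapped_cl = automatic_ccl_generation(ctx_len=6000, prefill_seq_len=128)
# mapped_cl   -> 6144            (6000 ceiled to the next multiple of 1024)
# decode_ccl  -> [4096, 6144]    (doubling from 4096, last element forced to mapped_cl)
# prefill_ccl -> [4000, 6000]    (6144 is not a power of two, so the last element is floored to 6000)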
+
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
- if ccl_prefill is None or ccl_decode is None:
- return None, None
-
- if ctx_len is None:
- raise TypeError("`ctx_len` is required when loading the model with CCL.")
-
- if prefill_seq_len == 1:
- # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
- ccl_union_all = sorted(set(ccl_prefill + ccl_decode))
- ccl_union_all = [min(x, ctx_len) for x in ccl_union_all]
- return ccl_union_all, ccl_union_all
-
- # Step 1: Cap values to ctx_len
- ccl_prefill = [min(x, ctx_len) for x in ccl_prefill]
- ccl_decode = [min(x, ctx_len) for x in ccl_decode]
-
- # Step 2: Remove duplicates within each list
- ccl_prefill = list(set(ccl_prefill))
- ccl_decode = list(set(ccl_decode))
-
- # Step 3: Ensure no overlap between ccl_prefill and ccl_decode
- updated_prefill = []
- for val in ccl_prefill:
- while val in ccl_decode or val in updated_prefill:
- val -= 1
- if val < 0:
- break # Prevent negative values
- if val >= 0:
- updated_prefill.append(val)
-
- # Step 4: Sort both lists
- updated_prefill.sort()
- ccl_decode.sort()
-
- return updated_prefill, ccl_decode
+ # Automatic CCL generation: If both ccl_prefill and ccl_decode are None
+ if ccl_prefill is None and ccl_decode is None:
+ # Generate optimized context length lists for prefill and decode based on ctx_len
+ # Due to compiler limitations, ccl_prefill and ccl_decode must have distinct values
+ ccl_prefill, ccl_decode, ctx_len = automatic_ccl_generation(ctx_len, prefill_seq_len)
+ else:
+ if prefill_seq_len == 1:
+ if ccl_prefill is not None and ccl_decode is not None:
+ # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
+ ccl_union_all = sorted(set([min(x, ctx_len) for x in ccl_prefill + ccl_decode]))
+ ccl_prefill = ccl_union_all
+ ccl_decode = ccl_union_all
+ else:
+ if ccl_prefill:
+ ccl_prefill = sorted({min(x, ctx_len) for x in (ccl_prefill)})
+ if ccl_decode:
+ ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})
+
+ if ccl_prefill is not None and ccl_decode is not None:
+ tmp_prefill = ccl_prefill
+ ccl_prefill = []
+ for val in tmp_prefill:
+ while val in ccl_decode or val in ccl_prefill:
+ val -= 1
+ if val < 0:
+ break # Prevent negative values
+ if val >= 0:
+ ccl_prefill.append(val)
+ ccl_prefill.sort()
+
+ logger.info("CCL Configuration:")
+ logger.info(f" - Prefill context lengths: {ccl_prefill}")
+ logger.info(f" - Decode context lengths: {ccl_decode}")
+ logger.info(f" - Max context length: {ctx_len}")
+ return ccl_prefill, ccl_decode, ctx_len
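And a worked example of the user-specified path, where an overlapping value is decremented so the prefill and decode lists stay distinct (again assuming this patch is applied):

from QEfficient.utils.check_ccl_specializations import process_ccl_specializations

prefill_ccl, decode_ccl, ctx_len = process_ccl_specializations(
    ccl_prefill=[2048, 4096], ccl_decode=[4096, 8192], ctx_len=8192, prefill_seq_len=128
)
# prefill_ccl -> [2048, 4095]   (4096 collides with the decode list, so it is decremented by one)
# decode_ccl  -> [4096, 8192]
# ctx_len     -> 8192           (unchanged when the lists are user-provided)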
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index e0b003422..d0318ac3e 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -88,7 +88,7 @@ def get_models_dir():
SIZE_THRESHOLD_DEFAULT = 1024
-COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"]
+COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-compile-only"]
DEFAULT_AIC_HW_VERSION = "ai100"
ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100
@@ -144,6 +144,39 @@ def get_models_dir():
# Molmo Constants
MOLMO_IMAGE_HEIGHT = 536
MOLMO_IMAGE_WIDTH = 354
+# Flux Transformer Constants
+FLUX_ONNX_EXPORT_SEQ_LENGTH = 256
+FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096
+FLUX_ADALN_HIDDEN_DIM = 3072
+FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context
+FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3
+FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM
+
+# Wan Transformer Constants
+WAN_TEXT_EMBED_DIM = 5120
+WAN_PROJECTION_DIM = 6
+WAN_ONNX_EXPORT_BATCH_SIZE = 1
+WAN_ONNX_EXPORT_FRAMES = 81
+WAN_ONNX_EXPORT_LATENT_FRAMES = 21
+WAN_ONNX_EXPORT_SEQ_LEN = 512
+WAN_ONNX_EXPORT_ROTARY_DIM = 128
+WAN_DIT_OUT_CHANNELS = 64
+# Wan dims for 180p
+WAN_ONNX_EXPORT_CL_180P = 5040
+WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 24
+WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 40
+WAN_ONNX_EXPORT_HEIGHT_180P = 192
+WAN_ONNX_EXPORT_WIDTH_180P = 320
+
+# For automatic CCL list generation: to limit the number of elements in each CCL list, the starting point is chosen based on the context length
+CCL_START_MAP = {
+ 32768: (4096, 4000),
+ 65536: (8192, 8000),
+ float("inf"): (16384, 16000),
+}
+# Maximum number of elements allowed in the comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic list generation.
+CCL_MAX_ELEMENTS_LISTS = 5
+CCL_START_CTX_LEN = 4096
class Constants:
diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py
new file mode 100644
index 000000000..638f55921
--- /dev/null
+++ b/QEfficient/utils/export_utils.py
@@ -0,0 +1,219 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import copy
+import inspect
+import re
+import warnings
+from pathlib import Path
+from typing import Dict
+
+from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform
+from QEfficient.transformers.cache_utils import InvalidIndexProvider
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
+from QEfficient.utils.cache import QEFF_HOME
+from QEfficient.utils.hash_utils import create_export_hash
+from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+
+
+def export_wrapper(func):
+ """
+ Decorator for export methods that orchestrates the complete export lifecycle.
+
+ Responsibilities:
+    1. Setup ONNX subfunction environment (if enabled)
+    2. Prepare export directory structure
+    3. Generate reproducible hash for export configuration
+    4. Execute the wrapped export function
+    5. Save export metadata
+    6. Cleanup subfunction environment (if enabled)
+
+ Args:
+ func: The export method to wrap (typically _export)
+
+ Returns:
+ Wrapped function with complete export lifecycle management
+ """
+
+ def wrapper(self, *args, **kwargs):
+ # 1. Setup ONNX subfunctions if requested
+ if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False):
+ args, kwargs = _setup_onnx_subfunctions(self, args, kwargs)
+
+ # 2. Prepare export directory
+ export_dir = _prepare_export_directory(self, kwargs)
+
+ # 3. Generate hash and finalize export directory path
+ export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func)
+ export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
+ kwargs["export_dir"] = export_dir
+ self.export_hash = export_hash
+
+ # 4. Execute the actual export
+ onnx_path = func(self, *args, **kwargs)
+
+ # 5. Save export metadata
+ _save_export_metadata(export_dir, filtered_hash_params)
+
+ # 6. Always cleanup subfunctions if they were setup
+ if use_onnx_subfunctions:
+ _cleanup_onnx_subfunctions(self)
+
+ return onnx_path
+
+ return wrapper
+
+
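A hedged sketch of where this decorator sits; the docstring above names `_export` as the typical target, while the class name, signature, and attribute values below are illustrative assumptions only:

class QEFFSomeModel:                 # hypothetical stand-in for a QEff modeling class
    model_architecture = None
    model_name = "some-model"        # the wrapper also reads self.hash_params and self.model
    hash_params = {}

    @export_wrapper
    def _export(self, example_inputs, output_names, dynamic_axes, export_dir=None, **kwargs):
        ...                          # the actual torch.onnx.export logic lives in the wrapped method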
+def _prepare_export_directory(qeff_model, kwargs) -> Path:
+ """
+ Prepare and return the base export directory path.
+
+ Args:
+ qeff_model: The QEff model instance
+ kwargs: Keyword arguments containing optional export_dir
+
+ Returns:
+ Path object for the base export directory
+ """
+ export_dir = kwargs.get("export_dir", None)
+ parent_dir = qeff_model.model_architecture or qeff_model.model_name
+ return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name))
+
+
+def _generate_export_hash(qeff_model, args, kwargs, func):
+ """
+ Generate export hash from model parameters and export arguments.
+
+ The hash ensures reproducibility and prevents conflicts between
+ different export configurations.
+
+ Args:
+ qeff_model: The QEff model instance
+ args: Positional arguments to the export function
+ kwargs: Keyword arguments to the export function
+ func: The export function being wrapped
+
+ Returns:
+ Tuple of (export_hash: str, filtered_hash_params: dict)
+ """
+ # Extract function signature
+ original_sig = inspect.signature(func)
+ params = list(original_sig.parameters.values())[1:] # Skip 'self'
+ new_sig = inspect.Signature(params)
+ # Bind all arguments
+ bound_args = new_sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ all_args = bound_args.arguments
+
+ # Use the model's current configuration for hashing to ensure any post-load modifications are captured
+ # TODO: Replace with get_model_config property of modeling classes and remove the if-else
+ # Determine the config dict to use, preferring .to_diff_dict() if available
+ if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.config.to_diff_dict()
+ elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.model.config.to_diff_dict()
+ else:
+ config_val = qeff_model.model.config
+
+ copy_of_hash_params = copy.deepcopy(qeff_model.hash_params)
+ copy_of_hash_params.update(
+ {
+ "config": config_val,
+ }
+ )
+ # Generate hash from relevant parameters
+ export_hash, filtered_hash_params = create_export_hash(
+ model_params=copy_of_hash_params,
+ output_names=all_args.get("output_names"),
+ dynamic_axes=all_args.get("dynamic_axes"),
+ export_kwargs=all_args.get("export_kwargs", None),
+ onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
+ )
+
+ return export_hash, filtered_hash_params
+
+
+def _setup_onnx_subfunctions(qeff_model, args, kwargs):
+ """
+ Setup ONNX subfunction export environment.
+
+ This function prepares the model and environment for exporting with
+ ONNX subfunctions enabled. It:
+ - Applies necessary torch patches
+ - Modifies output names for subfunction compatibility
+ - Adds subfunction-specific ONNX transforms
+ - Updates export kwargs with module classes
+
+    Args:
+        qeff_model: The QEff model instance
+        args: Positional export arguments (output names may be rewritten)
+        kwargs: Export keyword arguments (modified in-place)
+
+    Returns:
+        Tuple of (args, kwargs) updated for subfunction export.
+    """
+ warnings.warn(
+ "The subfunction feature is experimental. Please note that using compile "
+ "consecutively with and without subfunction may produce inconsistent results."
+ )
+
+ # Apply torch patches for subfunction support
+ apply_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = True
+ # Transform output names for subfunction compatibility
+ if "output_names" in kwargs:
+ kwargs["output_names"] = [
+ re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"]
+ ]
+ else:
+ args = list(args)
+ args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]]
+ args = tuple(args)
+ # Add subfunction-specific ONNX transforms
+ qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
+ qeff_model._onnx_transforms.append(CustomOpTransform)
+
+    # TODO: Handle this in the modeling class QEFFTransformersBase and remove it from here. Refer to the diffusers implementation.
+ kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
+ return args, kwargs
+
+
+def _cleanup_onnx_subfunctions(qeff_model):
+ """
+ Cleanup ONNX subfunction export environment.
+
+ Restores the model and environment to pre-subfunction state by:
+ - Undoing torch patches
+ - Resetting InvalidIndexProvider flag
+ - Restoring original ONNX transforms list
+
+ Args:
+ qeff_model: The QEff model instance
+
+    Note:
+        This function is called in a finally block by ``export_wrapper`` so that
+        cleanup runs even if the export itself fails.
+ """
+ # Undo torch patches
+ undo_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = False
+ qeff_model._onnx_transforms.remove(RenameFunctionOutputsTransform)
+ qeff_model._onnx_transforms.remove(CustomOpTransform)
+
+
+def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict):
+ """
+ Save export metadata to JSON file for reproducibility.
+
+ Args:
+ export_dir: Directory where the export was saved
+ filtered_hash_params: Dictionary of parameters used for hashing
+ """
+ # Import here to avoid circular dependency
+ from QEfficient.utils._utils import create_json
+
+ hashed_params_path = export_dir / "hashed_export_params.json"
+ create_json(hashed_params_path, filtered_hash_params)
+ logger.info("Hashed parameters exported successfully.")
diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py
index 948b72e6a..10e6686d0 100644
--- a/QEfficient/utils/hash_utils.py
+++ b/QEfficient/utils/hash_utils.py
@@ -14,7 +14,8 @@
def json_serializable(obj):
if isinstance(obj, set):
- return sorted(obj)
+ # Convert set to a sorted list of strings for consistent hashing
+ return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj])
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
@@ -55,8 +56,6 @@ def create_export_hash(**kwargs):
export_params = {}
export_params["output_names"] = kwargs.get("output_names")
export_params["dynamic_axes"] = kwargs.get("dynamic_axes")
- if kwargs.get("use_onnx_subfunctions"):
- export_params["use_onnx_subfunctions"] = True
export_hash_params["export_params"] = export_params
export_kwargs = kwargs.get("export_kwargs")
@@ -68,5 +67,4 @@ def create_export_hash(**kwargs):
export_hash_params.update(onnx_transform_kwargs)
if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict):
export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict()
-
return hash_dict_params(export_hash_params), export_hash_params
diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py
index 6fb1b326f..82a0843bc 100644
--- a/QEfficient/utils/sampler_utils.py
+++ b/QEfficient/utils/sampler_utils.py
@@ -5,13 +5,18 @@
#
# -----------------------------------------------------------------------------
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
+import torch
+
+from QEfficient.utils import constants
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger
-def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool:
+def validate_sampler_inputs(
+ session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None
+) -> bool:
"""
Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling.
@@ -28,7 +33,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[
ValueError if partial support is detected or if user intent conflicts with QPC capabilities.
"""
- sampler_inputs = Constants.SAMPLER_INPUTS
+ sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set())
count = len(sampler_inputs & session_inputs)
session_includes_sampler = True
@@ -56,3 +61,92 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[
)
return session_includes_sampler
+
+
+def get_sampling_inputs_and_outputs(
+ example_inputs: Dict[str, torch.Tensor],
+ output_names: List[str],
+ dynamic_axes: Dict[str, Dict[int, str]],
+ continuous_batching: bool,
+ vocab_size: int,
+ qaic_config: Dict,
+):
+ """
+ Updates the example inputs, output names, and dynamic axes to include
+ parameters relevant for on-device sampling during ONNX export.
+
+ Parameters
+ ----------
+ example_inputs : Dict[str, torch.Tensor]
+ Current dictionary of example inputs.
+ output_names : List[str]
+ Current list of output names.
+ dynamic_axes : Dict[str, Dict[int, str]]
+ Current dictionary of dynamic axes configurations.
+ continuous_batching : bool
+ Whether this model will be used for continuous batching in the future.
+ vocab_size: int
+ Vocabulary size for this model.
+ qaic_config : Dict
+ QAIC config dictionary.
+
+ Returns
+ -------
+ Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+ Updated example inputs, output names, and dynamic axes including
+ sampling-related parameters.
+ """
+ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+ fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+ seq_len: int = example_inputs["input_ids"].shape[-1]
+
+ example_inputs["last_accepted_output_tokens"] = torch.zeros((bs, seq_len), dtype=torch.int64)
+ dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+ example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+ (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+ )
+ dynamic_axes["past_repetition_penalty_buffer"] = {
+ 0: "full_batch_size" if continuous_batching else "batch_size",
+ }
+ output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+ example_inputs["repetition_penalties"] = (
+ torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+ )
+ dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+ example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+ (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+ )
+ dynamic_axes["past_presence_penalty_buffer"] = {
+ 0: "full_batch_size" if continuous_batching else "batch_size",
+ }
+ output_names.append("past_presence_penalty_buffer_RetainedState")
+
+ example_inputs["presence_penalties"] = (
+ torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+ )
+ dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+ example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+ dynamic_axes["temperatures"] = {0: "batch_size"}
+
+ max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+ example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+ dynamic_axes["top_ks"] = {0: "batch_size"}
+
+ example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+ dynamic_axes["top_ps"] = {0: "batch_size"}
+
+ example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+ dynamic_axes["min_ps"] = {0: "batch_size"}
+
+ example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
+ dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+ if qaic_config.get("include_guided_decoding", False):
+ example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool)
+ dynamic_axes["token_bitmasks"] = {0: "batch_size"}
+
+ return example_inputs, output_names, dynamic_axes
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..9e58da61d
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md
new file mode 100644
index 000000000..2a3c1605f
--- /dev/null
+++ b/examples/diffusers/flux/README.md
@@ -0,0 +1,243 @@
+# FLUX.1-schnell Image Generation Examples
+
+This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs.
+
+## Overview
+
+FLUX.1-schnell is a fast, distilled version of the FLUX.1 text-to-image model optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation.
+
+## Files
+
+- **`flux_1_schnell.py`** - Basic example showing simple image generation
+- **`flux_1_shnell_custom.py`** - Advanced example with customization options
+- **`flux_config.json`** - Configuration file for pipeline modules
+
+## Quick Start
+
+### Basic Usage
+
+The simplest way to generate images with FLUX.1-schnell:
+
+```python
+from QEfficient import QEffFluxPipeline
+import torch
+
+# Initialize pipeline
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate image
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Save image
+output.images[0].save("girl_laughing.png")
+```
+
+Run the basic example:
+```bash
+python flux_1_schnell.py
+```
+
+## Advanced Customization
+
+The `flux_1_shnell_custom.py` example demonstrates several advanced features:
+
+### 1. Custom Model Components
+
+You can provide custom text encoders, transformers, and tokenizers:
+
+```python
+pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ text_encoder=custom_text_encoder,
+ transformer=custom_transformer,
+ tokenizer=custom_tokenizer,
+)
+```
+
+### 2. Custom Scheduler
+
+Replace the default scheduler with your own:
+
+```python
+pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+```
+
+### 3. Reduce Model Layers for Faster Inference
+
+Trade quality for speed by reducing transformer blocks:
+
+```python
+original_blocks = pipeline.transformer.model.transformer_blocks
+org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+pipeline.transformer.model.config['num_layers'] = 1
+pipeline.transformer.model.config['num_single_layers'] = 1
+```
+
+### 4. Pre-compile with Custom Configuration
+
+Compile the model separately before generation:
+
+```python
+pipeline.compile(
+ compile_config="examples/diffusers/flux/flux_config.json",
+ height=512,
+ width=512,
+ use_onnx_subfunctions=False
+)
+```
+
+### 5. Runtime Configuration
+
+Use custom configuration during generation:
+
+```python
+output = pipeline(
+ prompt="A girl laughing",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+```
+
+Run the advanced example:
+```bash
+python flux_1_shnell_custom.py
+```
+
+## Configuration File
+
+The `flux_config.json` file controls compilation and execution settings for each pipeline module:
+
+### Module Structure
+
+The configuration includes four main modules:
+
+1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence)
+2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence)
+3. **transformer** - Core diffusion transformer model
+4. **vae_decoder** - Decodes latents to images
+
+### Configuration Parameters
+
+Each module has three sections:
+
+#### Specializations
+- `batch_size`: Batch size for inference
+- `seq_len`: Sequence length for text encoders
+- `steps`: Number of inference steps (transformer only)
+- `channels`: Number of channels (VAE decoder only)
+
+#### Compilation
+- `onnx_path`: Path to pre-exported ONNX model (null for auto-export)
+- `compile_dir`: Directory for compiled artifacts (null for auto-generation)
+- `mdp_ts_num_devices`: Number of devices for model data parallelism
+- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication
+- `convert_to_fp16`: Convert model to FP16 precision
+- `aic_num_cores`: Number of AI cores to use
+- `mos`: Multi-output streaming (transformer only)
+- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only)
+- `aic-enable-depth-first`: Enable depth-first compilation (VAE only)
+
+#### Execute
+- `device_ids`: List of device IDs to use (null for auto-selection)
+
+### Example Configuration Snippet
+
+```json
+{
+ "transformer": {
+ "specializations": {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation": {
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+}
+```
+
+## Key Parameters
+
+### Generation Parameters
+
+- **`prompt`** (str): Text description of the image to generate
+- **`height`** (int): Output image height in pixels (default: 1024)
+- **`width`** (int): Output image width in pixels (default: 1024)
+- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell)
+- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell)
+- **`max_sequence_length`** (int): Maximum text sequence length (256 recommended)
+- **`generator`** (torch.Generator): Random seed for reproducibility
+- **`parallel_compile`** (bool): Enable parallel compilation of modules
+- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental)
+
+### Performance Tuning
+
+- **Faster inference**: Reduce `num_inference_steps` or model layers
+- **Better quality**: Increase `num_inference_steps` or use full model
+- **Memory optimization**: Adjust `mdp_ts_num_devices` in config
+- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16` (see the sketch below)
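+
+Putting a few of these together, a quick, lower-fidelity test run might look like the sketch below. It reuses the pipeline call from the Quick Start; only the resolution and step count are reduced, so treat the exact values as illustrative rather than recommended settings.
+
+```python
+# Faster, lower-quality test generation (illustrative values)
+output = pipeline(
+    prompt="A laughing girl",
+    height=512,                  # smaller resolution -> less compute
+    width=512,
+    guidance_scale=0.0,          # schnell is tuned for guidance-free sampling
+    num_inference_steps=2,       # fewer denoising steps than the recommended 4
+    max_sequence_length=256,
+    generator=torch.manual_seed(42),
+    parallel_compile=True,
+    use_onnx_subfunctions=False,
+)
+```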
+
+## Output
+
+The pipeline returns an output object containing:
+- `images`: List of generated PIL Image objects
+- Performance metrics (timing information)
+
+Example output:
+```python
+print(output) # Displays performance information
+image = output.images[0] # Access the generated image
+image.save("output.png") # Save to disk
+```
+
+## Hardware Requirements
+
+- Qualcomm Cloud AI 100 accelerator
+- Sufficient memory for model compilation and execution
+- Multiple devices recommended for optimal transformer performance (see `mdp_ts_num_devices`)
+
+## Notes
+
+- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0`
+- The transformer module benefits most from multi-device parallelism
+- ONNX subfunction export (`use_onnx_subfunctions=True`) is experimental; it may improve compile time but is not recommended for production use
+- Custom configurations allow fine-tuning for specific hardware setups
+
+## Troubleshooting
+
+- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices`
+- **Slow compilation**: Enable `parallel_compile=True`
+- **Quality issues**: Ensure using recommended parameters (4 steps, guidance_scale=0.0)
+- **Device errors**: Check `device_ids` in config or set to `null` for auto-selection
+
+## References
+
+- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+- [QEfficient Documentation](../../../README.md)
+- [Diffusers Pipeline Guide](../../README.md)
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..46f26bb6b
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,45 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1-schnell Image Generation Example
+
+This example demonstrates how to use the QEffFluxPipeline to generate images
+using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a
+fast, distilled version of the FLUX.1 text-to-image model optimized for
+speed with minimal quality loss.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# Initialize the FLUX.1-schnell pipeline from pretrained weights
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate an image from a text prompt
+# use_onnx_subfunctions=True would enable experimental modular ONNX export for faster compilation; it is disabled here
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Extract the generated image from the output
+image = output.images[0]
+
+# Save the generated image to disk
+image.save("girl_laughing.png")
+
+# Print the output object (contains perf info)
+print(output)
diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py
new file mode 100644
index 000000000..201ebe659
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_shnell_custom.py
@@ -0,0 +1,113 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1 Schnell Custom Configuration Example
+
+This example demonstrates how to customize the FLUX.1 model with various options:
+1. Custom image dimensions (height/width)
+2. Custom transformer model and text encoder
+3. Custom scheduler configuration
+4. Reduced model layers for faster inference
+5. Custom compilation settings
+6. Custom runtime configuration via JSON config file
+
+Use this example to learn how to fine-tune FLUX.1 for your specific needs.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+
+# Option 1: Basic initialization with default parameters
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+# Option 2: Advanced initialization with custom modules
+# Uncomment and modify to use your own custom components:
+#
+# pipeline = QEffFluxPipeline.from_pretrained(
+# "black-forest-labs/FLUX.1-schnell",
+# text_encoder=custom_text_encoder, # Your custom CLIP text encoder
+# transformer=custom_transformer, # Your custom transformer model
+# tokenizer=custom_tokenizer, # Your custom tokenizer
+# )
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up image generation.
+#
+# Trade-off: Faster inference but potentially lower image quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only the first transformer block:
+#
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# pipeline.transformer.model.config['num_layers'] = 1
+# pipeline.transformer.model.config['num_single_layers'] = 1
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip image generation and only prepare the model
+#
+# NOTE-1: If compile_config is not specified, the default configuration from
+# QEfficient/diffusers/pipelines/flux/flux_config.json will be used
+#
+# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations (experimental, so not recommended for production).
+#         This feature improves export performance by breaking the model down into smaller,
+#         more manageable ONNX functions, which can lead to improved compile times.
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(
+# compile_config="examples/diffusers/flux/flux_config.json",
+# height=512,
+# width=512,
+# use_onnx_subfunctions=False
+# )
+
+# ============================================================================
+# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate an image using the configured pipeline.
+#
+# Note: Use of custom_config_path provides flexibility to set device_ids for each
+# module, so you can skip the separate pipeline.compile() step.
+
+output = pipeline(
+ prompt="A laughing girl",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+image = output.images[0]
+# Save the generated image to disk
+image.save("laughing_girl.png")
+print(output)
diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/examples/diffusers/flux/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/examples/diffusers/wan/README.md b/examples/diffusers/wan/README.md
new file mode 100644
index 000000000..b90bf3908
--- /dev/null
+++ b/examples/diffusers/wan/README.md
@@ -0,0 +1,249 @@
+# WAN 2.2 Text-to-Video Generation Examples
+
+This directory contains examples demonstrating how to use the QEffWanPipeline to generate videos using the WAN 2.2 text-to-video model with Lightning LoRA optimization.
+
+## Overview
+
+WAN 2.2 is a text-to-video diffusion model that uses dual-stage processing for high-quality video generation. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient video generation with Lightning LoRA for fast 4-step inference.
+
+## Files
+
+- **`wan_lightning.py`** - Complete example with Lightning LoRA for fast video generation
+- **`wan_config.json`** - Configuration file for transformer module compilation
+
+## Quick Start
+
+### Basic Usage
+
+The simplest way to generate videos with WAN 2.2 Lightning:
+
+### 1. Load Model
+```python
+from QEfficient import QEffWanPipeline
+import torch
+from diffusers.utils import export_to_video
+
+# Initialize pipeline
+pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+```
+
+### 2. Lightning LoRA Integration
+
+Load high and low noise LoRA adapters for fast 4-step generation:
+
+```python
+from huggingface_hub import hf_hub_download
+from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
+import safetensors.torch
+
+# Download Lightning LoRAs
+high_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors",
+)
+low_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
+)
+
+# Load and apply LoRAs
+def load_wan_lora(path: str):
+ return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path))
+
+pipeline.transformer.model.transformer_high.load_lora_adapter(
+ load_wan_lora(high_noise_lora_path), adapter_name="high_noise"
+)
+pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0])
+
+pipeline.transformer.model.transformer_low.load_lora_adapter(
+ load_wan_lora(low_noise_lora_path), adapter_name="low_noise"
+)
+pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0])
+```
+
+
+### 3. Compile API
+
+To compile the model for desired resolution:
+
+```python
+# Compile with custom configuration
+pipeline.compile(
+ compile_config="examples/diffusers/wan/wan_config.json",
+ parallel=True,
+ height=480,
+ width=832,
+ num_frames=81,
+ use_onnx_subfunctions=False,
+)
+```
+
+### 4. Generate video
+```python
+output = pipeline(
+ prompt="A cat playing in a sunny garden",
+ num_frames=81,
+ height=480,
+ width=832,
+ guidance_scale=1.0,
+ num_inference_steps=4,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Export video
+frames = output.images[0]
+export_to_video(frames, "cat_garden.mp4", fps=16)
+```
+
+Run the Lightning example:
+```bash
+python wan_lightning.py
+```
+
+## Advanced Customization
+
+
+### 1. Reduce Model Layers for Faster Inference
+
+
+```python
+# Reduce to 2 layers for faster inference
+pipeline.transformer.model.transformer_high.config.num_layers = 2
+pipeline.transformer.model.transformer_low.config.num_layers = 2
+
+original_blocks = pipeline.transformer.model.transformer_high.blocks
+org_blocks = pipeline.transformer.model.transformer_low.blocks
+
+pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList(
+ [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)]
+)
+pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList(
+ [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)]
+)
+```
+
+### 2. To Run with Blocking
+
+Use environment variables to enable attention blocking:
+
+```bash
+# For 180p Generation (192x320) with HKV blocking
+ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python wan_lightning.py
+
+# For 480p Generation (480x832) with HQKV blocking
+ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python wan_lightning.py
+
+# For 720p Generation (720x1280) with HQKV blocking
+ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python wan_lightning.py
+```
+
+### Blocking Modes
+
+Head blocking is common to all modes:
+
+- **`kv`**: Block key-value processing (along with Head blocking)
+- **`q`**: Block query processing (along with Head blocking)
+- **`qkv`**: Block query, key, and value (along with Head blocking)
+- **`default`**: Head-only blocking
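+
+If you prefer to drive these settings from Python rather than the shell, a minimal sketch is shown below. It mirrors the 480p invocation above and assumes the variables are read by QEfficient when the transformer is compiled, so set them before compiling or running the pipeline.
+
+```python
+import os
+
+# Equivalent to: ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2
+os.environ["ATTENTION_BLOCKING_MODE"] = "qkv"
+os.environ["head_block_size"] = "16"
+os.environ["num_kv_blocks"] = "21"
+os.environ["num_q_blocks"] = "2"
+```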
+
+
+## Configuration File
+
+The `wan_config.json` file controls compilation settings for the transformer module:
+
+### Module Structure
+
+The configuration includes dual specializations for WAN's high and low noise models:
+
+```json
+{
+ "transformer": {
+ "specializations":[
+ {
+ "batch_size":"1",
+ "cl":"5040",
+ "latent_height":"24",
+ "latent_width":"40",
+ "model_type":"1",
+ "num_channels":"16",
+ "num_frames":"21",
+ "sequence_length":"512",
+ "steps":"1"
+ },
+ {
+ "batch_size":"1",
+ "cl":"5040",
+ "latent_height":"24",
+ "latent_width":"40",
+ "model_type":"2",
+ "num_channels":"16",
+ "num_frames":"21",
+ "sequence_length":"512",
+ "steps":"1"
+ }
+ ]
+  }
+}
+```
+
+### Configuration Parameters
+
+#### Specializations
+- `batch_size`: Batch size for inference
+- `num_channels`: Number of latent channels (16 for WAN)
+- `num_frames`: Number of latent frames (21 for 81 input frames)
+- `latent_height`/`latent_width`: Latent space dimensions
+- `cl`: Compressed latent dimension for transformer
+- `sequence_length`: Sequence length of the text encoder (512)
+- `model_type`: 1 for high noise model, 2 for low noise model
+
+#### Compilation
+- `mdp_ts_num_devices`: Number of devices for model parallelism (16 recommended)
+- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication
+- `convert_to_fp16`: Convert model to FP16 precision
+- `aic_num_cores`: Number of AI cores to use (16 recommended)
+- `mos`: Degree of weight splitting done across cores (1 is recommended)
+- `mdts_mos`: Degree of weight splitting done across multi-device tensor slices (1 is recommended)
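+
+These values live under `modules.transformer.compilation` in `wan_config.json`. A minimal sketch for adjusting one of them and compiling with the edited copy follows; the output filename is arbitrary and only `compile_config` is an actual pipeline parameter.
+
+```python
+import json
+
+with open("examples/diffusers/wan/wan_config.json") as f:
+    cfg = json.load(f)
+
+# For example, run the transformer across 8 devices instead of 16
+cfg["modules"]["transformer"]["compilation"]["mdp_ts_num_devices"] = 8
+
+with open("wan_config_8dev.json", "w") as f:
+    json.dump(cfg, f, indent=2)
+
+pipeline.compile(
+    compile_config="wan_config_8dev.json",
+    parallel=True,
+    height=480,
+    width=832,
+    num_frames=81,
+    use_onnx_subfunctions=False,
+)
+```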
+
+## Key Parameters
+
+### Generation Parameters
+
+- **`prompt`** (str): Text description of the video to generate
+- **`num_frames`** (int): Number of video frames (default: 81)
+- **`height`** (int): Output video height in pixels (default: 480)
+- **`width`** (int): Output video width in pixels (default: 832)
+- **`guidance_scale`** (float): Guidance scale for high noise stage (1.0 for Lightning)
+- **`guidance_scale_2`** (float): Guidance scale for low noise stage (1.0 for Lightning)
+- **`num_inference_steps`** (int): Number of denoising steps (4 for Lightning)
+- **`generator`** (torch.Generator): Random seed for reproducibility
+- **`parallel_compile`** (bool): Enable parallel compilation of modules
+- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export
+
+
+## Output
+
+The pipeline returns an output object containing:
+- `images`: List of video frames as PIL Image objects
+- Performance metrics (timing information)
+
+Example output:
+```python
+print(output) # Displays performance information
+frames = output.images[0] # Access the generated video frames
+export_to_video(frames, "output.mp4", fps=16) # Export to MP4
+```
+
+## Notes
+
+- WAN 2.2 Lightning is optimized for 4-step generation with `guidance_scale=1.0`
+- The transformer uses dual-stage processing (high/low noise models)
+- Attention blocking is essential for higher resolutions (480p+)
+
+
+## References
+
+- [WAN 2.2 Model Card](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
+- [Lightning LoRA](https://huggingface.co/lightx2v/Wan2.2-Lightning)
+- [QEfficient Documentation](../../../README.md)
diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json
new file mode 100644
index 000000000..7e752ba14
--- /dev/null
+++ b/examples/diffusers/wan/wan_config.json
@@ -0,0 +1,37 @@
+{
+ "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)",
+ "model_type": "wan",
+ "modules": {
+ "transformer": {
+ "specializations": [
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 1
+ },
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 2
+ }
+ ],
+ "compilation": {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 16,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts_mos": 1
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/examples/diffusers/wan/wan_lightning.py b/examples/diffusers/wan/wan_lightning.py
new file mode 100644
index 000000000..691da651f
--- /dev/null
+++ b/examples/diffusers/wan/wan_lightning.py
@@ -0,0 +1,62 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+import safetensors.torch
+import torch
+from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
+from diffusers.utils import export_to_video
+from huggingface_hub import hf_hub_download
+
+from QEfficient import QEffWanPipeline
+
+# Load the pipeline
+pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+
+# Download the LoRAs
+high_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors",
+)
+low_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
+)
+
+
+# LoRA conversion
+def load_wan_lora(path: str):
+ return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path))
+
+
+# Load into the transformers
+pipeline.transformer.model.transformer_high.load_lora_adapter(
+ load_wan_lora(high_noise_lora_path), adapter_name="high_noise"
+)
+pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0])
+pipeline.transformer.model.transformer_low.load_lora_adapter(
+ load_wan_lora(low_noise_lora_path), adapter_name="low_noise"
+)
+pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0])
+
+
+prompt = "In a warmly lit living room, an elderly man with gray hair sits in a wooden armchair adorned with a blue cushion. He wears a gray cardigan over a white shirt, engrossed in reading a book. As he turns the pages, he subtly adjusts his posture, ensuring his glasses stay in place. He then removes his glasses, holding them in his hand, and turns his head to the right, maintaining his grip on the book. The soft glow of a bedside lamp bathes the scene, creating a calm and serene atmosphere, with gentle shadows enhancing the intimate setting."
+
+output = pipeline(
+ prompt=prompt,
+ num_frames=81,
+ guidance_scale=1.0,
+ guidance_scale_2=1.0,
+ num_inference_steps=4,
+ generator=torch.manual_seed(0),
+ custom_config_path="examples/diffusers/wan/wan_config.json",
+ height=480,
+ width=832,
+ use_onnx_subfunctions=True,
+ parallel_compile=True,
+)
+frames = output.images[0]
+export_to_video(frames, "output_t2v.mp4", fps=16)
+print(output)
diff --git a/examples/diffusers/wan/wan_lightning_custom.py b/examples/diffusers/wan/wan_lightning_custom.py
new file mode 100644
index 000000000..a60d57bb6
--- /dev/null
+++ b/examples/diffusers/wan/wan_lightning_custom.py
@@ -0,0 +1,162 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+Wan2.2-Lightning Custom Configuration Example
+
+This example demonstrates how to customize the Wan2.2-Lightning model with various options:
+1. Custom video dimensions (height/width) and frame count
+2. Custom scheduler configuration
+3. Reduced model layers for faster inference
+4. Custom compilation settings
+5. Custom runtime configuration via JSON config file
+6. LoRA adapter loading and configuration
+
+Use this example to learn how to tune Wan2.2-Lightning for your specific video generation needs.
+"""
+
+import safetensors.torch
+import torch
+from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
+from diffusers.utils import export_to_video
+from huggingface_hub import hf_hub_download
+
+from QEfficient import QEffWanPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+
+# Option 1: Basic initialization with default parameters
+pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+
+# ============================================================================
+# LORA ADAPTER LOADING FOR LIGHTNING MODEL
+# ============================================================================
+# Download and load Lightning LoRA adapters for faster inference
+
+# Download the LoRAs from Hugging Face Hub
+high_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors",
+)
+low_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
+)
+
+
+# LoRA conversion utility function
+def load_wan_lora(path: str):
+ """Convert and load WAN LoRA weights from safetensors format."""
+ return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path))
+
+
+# Load LoRA adapters into the high and low noise transformers
+pipeline.transformer.model.transformer_high.load_lora_adapter(
+ load_wan_lora(high_noise_lora_path), adapter_name="high_noise"
+)
+pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0])
+
+pipeline.transformer.model.transformer_low.load_lora_adapter(
+ load_wan_lora(low_noise_lora_path), adapter_name="low_noise"
+)
+pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0])
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up video generation.
+#
+# Trade-off: Faster inference but potentially lower video quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only a subset of transformer layers:
+#
+# # Configure for a 1-layer model (faster inference)
+# pipeline.transformer.model.transformer_high.config.num_layers = 1
+# pipeline.transformer.model.transformer_low.config.num_layers = 1
+#
+# # Reduce high noise transformer blocks
+# original_blocks = pipeline.transformer.model.transformer_high.blocks
+# pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList(
+# [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)]
+# )
+#
+# # Reduce low noise transformer blocks
+# org_blocks = pipeline.transformer.model.transformer_low.blocks
+# pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList(
+# [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)]
+# )
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip video generation and only prepare the model
+#
+# NOTE-1: If compile_config is not specified, the default configuration from
+# QEfficient/diffusers/pipelines/wan/wan_config.json will be used
+#
+# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations
+# This feature improves export performance by breaking down the model into smaller,
+# more manageable ONNX functions, which can lead to improved compile time.
+#
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(
+# compile_config="examples/diffusers/wan/wan_config.json",
+# parallel=True,
+# height=480,
+# width=832,
+# num_frames=81,
+# use_onnx_subfunctions=True
+# )
+
+# ============================================================================
+# VIDEO GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate a video using the configured pipeline.
+#
+# Note: Use of custom_config_path provides flexibility to set device_ids for each
+# module, so you can skip the separate pipeline.compile() step.
+
+# Custom prompt for video generation
+prompt = "A cat wearing a hat walking through a magical forest with glowing mushrooms and fireflies dancing around, cinematic lighting, high quality"
+
+# Alternative video dimensions for different use cases and their corresponding blocking settings
+# height=192, width=320 # ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python3 examples/diffusers/wan/wan_lightning.py
+# height=480, width=832 # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python3 examples/diffusers/wan/wan_lightning.py
+# height=720, width=1280 # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python3 examples/diffusers/wan/wan_lightning.py
+
+output = pipeline(
+ prompt=prompt,
+ num_frames=81, # Number of video frames to generate
+ guidance_scale=1.0, # Primary guidance scale
+ guidance_scale_2=1.0, # Secondary guidance scale for dual guidance
+ num_inference_steps=4, # Lightning model uses fewer steps
+ generator=torch.manual_seed(42), # For reproducible results
+ custom_config_path="examples/diffusers/wan/wan_config.json",
+ height=480,
+ width=832,
+ use_onnx_subfunctions=True, # Enable ONNX optimizations
+ parallel_compile=False, # Set to True for parallel compilation
+)
+
+# Extract generated frames and export to video
+frames = output.images[0]
+export_to_video(frames, "custom_wan_lightning_output.mp4", fps=16)
+print(output)
diff --git a/examples/disagg_serving/README.md b/examples/disagg_serving/README.md
new file mode 100644
index 000000000..fcf665357
--- /dev/null
+++ b/examples/disagg_serving/README.md
@@ -0,0 +1,31 @@
+# We should be using disaggregated serving for the GPT-OSS model for best performance
+ - The GPT-OSS model has a total_experts/experts_per_tok ratio of 128/4 for the 120b variant and 32/4 for the 20b variant
+ - In the prefill-only model we use a "read all experts exactly once" strategy
+ - In the decode-only model we treat the expert weights as activations, i.e. only the chosen experts are read
+
+# Prefill-only model
+## Blocking: default behaviour when `prefill_only=True` in compile API
+  - `NUM_Q_BLOCKS=` sets the number of Q blocks in attention
+  - `NUM_FFN_BLOCKS=` sets the number of blocks in the FFN
+  - `ENABLE_OPT_SWA="0"` or `"1"` disables or enables optimized SWA; when enabled, only the valid KVs for a given block are used in attention, reducing MACs
+  - prefix_caching is not supported in this mode
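+
+A minimal sketch of driving these knobs from Python is shown below; the environment variable values are illustrative, and the compile arguments mirror `gpt_oss_disagg_mode.py` in this directory.
+
+```python
+import os
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+# Blocking knobs for the prefill-only model (illustrative values; read at compile time)
+os.environ["NUM_Q_BLOCKS"] = "4"
+os.environ["NUM_FFN_BLOCKS"] = "4"
+os.environ["ENABLE_OPT_SWA"] = "1"
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=256,
+    ctx_len=256,
+    num_cores=16,
+    prefill_only=True,
+    use_onnx_subfunctions=True,
+)
+```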
+
+## Chunking: pass `enable_chunking=True` and `prefill_only=True` in compile API
+  - Optimized SWA, i.e. reading only the valid KV as per the diagonal attention mask, is enabled by default for this version
+  - This model can be used for prefix_caching by passing `kv_cache_batch_size=` in the compile API
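+
+For this chunked variant, the same compile call simply gains the flags named above (the `kv_cache_batch_size` value is illustrative):
+
+```python
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=256,
+    ctx_len=256,
+    num_cores=16,
+    prefill_only=True,
+    enable_chunking=True,
+    kv_cache_batch_size=4,   # enables prefix caching
+    use_onnx_subfunctions=True,
+)
+```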
+
+# Decode-only model
+## Retain sliding-window length of KV for sliding-window layers: default behaviour when `prefill_seq_len=1` in compile API
+  - This reduces the amount of DDR used by the model
+  - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+## Full KV for sliding-window layers: pass `retain_full_kv=True` along with `prefill_seq_len=1` in compile API
+  - This uses more DDR, since we retain ctx_len KV even for sliding-window layers, but only sliding-window-length KV is read in attention
+  - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed
+  - This is intended for the multi-turn chat use case, where we run prefill -> decode and then reuse the combined prefill and decode cache to run prefill again, so the full KV must be retained for sliding-window layers
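+
+A decode-only compile with continuous batching might look like the following sketch; the batch sizes are illustrative and the argument names follow the bullets above and `gpt_oss_disagg_mode.py`.
+
+```python
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
+    "openai/gpt-oss-20b", continuous_batching=True
+)
+decode_qpc_path = qeff_model.compile(
+    prefill_seq_len=1,        # decode-only
+    ctx_len=256,
+    num_cores=16,
+    full_batch_size=4,        # required with continuous batching
+    kv_cache_batch_size=4,    # optional
+    retain_full_kv=True,      # keep full ctx_len KV for sliding-window layers (multi-turn chat)
+)
+```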
+
+
+NOTE:
+* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it there
+* The 120B model needs an NPI (node precision info) file; two versions, with and without subfunctions, are provided here. Pass the path via `node_precision_info=`
+* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise compilation times are too high. With this option the model is expected to export and then fail during compile because it needs the assert SDK, so run the compilation manually by pasting the command printed in the error
+
diff --git a/examples/disagg_serving/gpt_oss_disagg_mode.py b/examples/disagg_serving/gpt_oss_disagg_mode.py
new file mode 100644
index 000000000..fd0d5b045
--- /dev/null
+++ b/examples/disagg_serving/gpt_oss_disagg_mode.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b"  # the weights do not need to be converted to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+all_outputs = []
+# Run prefill
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 256
+CTX_LEN = 256
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+
+# Initialize variables specific to request
+# Calculate the max generation length.
+max_gen_len = CTX_LEN - position_ids.max()
+generation_len = max_gen_len
+
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+config = qeff_model.model.config
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
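+# Build zero-filled KV placeholders: even-indexed (sliding-window) layers get a
+# sliding_window-length cache, odd-indexed (full-attention) layers get a
+# PREFILL_SEQ_LEN-length cache; shapes follow the (1, 8, cache_len, 64) layout used below.
+# (These placeholders are removed again before the prefill session run.)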
+past_key_values = []
+for i in range(config.num_hidden_layers):
+ cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+inputs["past_key_values"] = past_key_values
+
+
+decode_qpc_path = qeff_model.compile(
+ prefill_seq_len=1,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ offload_pt_weights=False,
+)
+prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ use_onnx_subfunctions=True,
+)
+
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+
+logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+prefill_session.set_buffers({"logits": logits_out_placeholder})
+inputs.pop("past_key_values")
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+st = time.time()
+qpc_out = prefill_session.run(inputs)
+print(f"time for prefill_run={time.time() - st} sec\n")
+
+decode_session = QAICInferenceSession(decode_qpc_path)
+decode_session.set_buffers({"logits": logits_out_placeholder})
+
+decode_inputs = {
+ "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
+ "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
+}
+print("pos_id for decodee", decode_inputs["position_ids"])
+
+all_outputs.append(decode_inputs["input_ids"][0][0])
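+# Hand the prefill KV over to the decode session. For sliding-window (even-indexed)
+# layers whose ring-buffer cache has already wrapped around, rotate the retained KV so
+# that entries are in chronological order; other layers are passed through unchanged.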
+for i in range(config.num_hidden_layers):
+ if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
+ k = qpc_out[f"past_key.{i}_RetainedState"]
+ v = qpc_out[f"past_value.{i}_RetainedState"]
+ mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
+ decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
+ decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2)
+ else:
+ decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+ decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+
+st = time.time()
+decode_out = decode_session.run(decode_inputs)
+print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
+decode_session.skip_buffers(
+ [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")]
+)
+pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+st = time.time()
+for i in range(generation_len - 2):
+ loop_decode_inputs = {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+ }
+ all_outputs.append(loop_decode_inputs["input_ids"][0][0])
+ decode_out = decode_session.run(loop_decode_inputs)
+ pos_id += 1
+
+
+print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}")
+print(all_outputs)
+print(tokenizer.decode(all_outputs))
diff --git a/examples/disagg_serving/subfunction_120b_npi.yaml b/examples/disagg_serving/subfunction_120b_npi.yaml
new file mode 100644
index 000000000..762703d58
--- /dev/null
+++ b/examples/disagg_serving/subfunction_120b_npi.yaml
@@ -0,0 +1,27 @@
+FP32NodeInstanceNames:
+ - CustomRMSNorm_58
+ - onnx::Shape_1033777
+ - CustomRMSNorm_349
+ - hidden.127
+ - CustomRMSNorm_27448
+ - onnx::Shape_1066066
+ - CustomRMSNorm_27709
+ - hidden.131
+ - CustomRMSNorm_54808
+ - onnx::Shape_878
+ - CustomRMSNorm_55105
+ - hidden
+ - hidden_states.259
+ - Add_348
+ - Add_347
+ - onnx::Add_1034099
+ - hidden_states.267
+ - Add_27708
+ - onnx::Add_1066358
+ - Add_27707
+ - hidden_states.3
+ - Add_55104
+ - onnx::Add_1209
+ - Add_55103
+ - /model/norm/CustomRMSNorm
+ - /model/norm/CustomRMSNorm_output_0
\ No newline at end of file
diff --git a/examples/disagg_serving/without_subfunc_npi_120b.yaml b/examples/disagg_serving/without_subfunc_npi_120b.yaml
new file mode 100644
index 000000000..ec6cf034f
--- /dev/null
+++ b/examples/disagg_serving/without_subfunc_npi_120b.yaml
@@ -0,0 +1,148 @@
+FP32NodeInstanceNames:
+ - /model/layers.0/Add_1_output_0
+ - /model/layers.0/Add_output_0
+ - /model/layers.0/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.1/Add_1_output_0
+ - /model/layers.1/Add_output_0
+ - /model/layers.1/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.10/Add_1_output_0
+ - /model/layers.10/Add_output_0
+ - /model/layers.10/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.11/Add_1_output_0
+ - /model/layers.11/Add_output_0
+ - /model/layers.11/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.12/Add_1_output_0
+ - /model/layers.12/Add_output_0
+ - /model/layers.12/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.13/Add_1_output_0
+ - /model/layers.13/Add_output_0
+ - /model/layers.13/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.14/Add_1_output_0
+ - /model/layers.14/Add_output_0
+ - /model/layers.14/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.15/Add_1_output_0
+ - /model/layers.15/Add_output_0
+ - /model/layers.15/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.16/Add_1_output_0
+ - /model/layers.16/Add_output_0
+ - /model/layers.16/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.17/Add_1_output_0
+ - /model/layers.17/Add_output_0
+ - /model/layers.17/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.18/Add_1_output_0
+ - /model/layers.18/Add_output_0
+ - /model/layers.18/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.19/Add_1_output_0
+ - /model/layers.19/Add_output_0
+ - /model/layers.19/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.2/Add_1_output_0
+ - /model/layers.2/Add_output_0
+ - /model/layers.2/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.20/Add_1_output_0
+ - /model/layers.20/Add_output_0
+ - /model/layers.20/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.21/Add_1_output_0
+ - /model/layers.21/Add_output_0
+ - /model/layers.21/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.22/Add_1_output_0
+ - /model/layers.22/Add_output_0
+ - /model/layers.22/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.23/Add_1_output_0
+ - /model/layers.23/Add_output_0
+ - /model/layers.23/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.24/Add_1_output_0
+ - /model/layers.24/Add_output_0
+ - /model/layers.24/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.25/Add_1_output_0
+ - /model/layers.25/Add_output_0
+ - /model/layers.25/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.26/Add_1_output_0
+ - /model/layers.26/Add_output_0
+ - /model/layers.26/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.27/Add_1_output_0
+ - /model/layers.27/Add_output_0
+ - /model/layers.27/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.28/Add_1_output_0
+ - /model/layers.28/Add_output_0
+ - /model/layers.28/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.29/Add_1_output_0
+ - /model/layers.29/Add_output_0
+ - /model/layers.29/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.3/Add_1_output_0
+ - /model/layers.3/Add_output_0
+ - /model/layers.3/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.30/Add_1_output_0
+ - /model/layers.30/Add_output_0
+ - /model/layers.30/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.31/Add_1_output_0
+ - /model/layers.31/Add_output_0
+ - /model/layers.31/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.32/Add_1_output_0
+ - /model/layers.32/Add_output_0
+ - /model/layers.32/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.33/Add_1_output_0
+ - /model/layers.33/Add_output_0
+ - /model/layers.33/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.34/Add_1_output_0
+ - /model/layers.34/Add_output_0
+ - /model/layers.34/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.35/Add_1_output_0
+ - /model/layers.35/Add_output_0
+ - /model/norm/Add_output_0
+ - /model/layers.35/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.4/Add_1_output_0
+ - /model/layers.4/Add_output_0
+ - /model/layers.4/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.5/Add_1_output_0
+ - /model/layers.5/Add_output_0
+ - /model/layers.5/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.6/Add_1_output_0
+ - /model/layers.6/Add_output_0
+ - /model/layers.6/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.7/Add_1_output_0
+ - /model/layers.7/Add_output_0
+ - /model/layers.7/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.8/Add_1_output_0
+ - /model/layers.8/Add_output_0
+ - /model/layers.8/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/layers.9/Add_1_output_0
+ - /model/layers.9/Add_output_0
+ - /model/layers.9/input_layernorm/CustomRMSNorm_output_0
+ - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0
+ - /model/norm/CustomRMSNorm_output_0
+
\ No newline at end of file
diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py
new file mode 100644
index 000000000..363e2806c
--- /dev/null
+++ b/examples/gpt_oss_disagg_mode_with_chunking.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+# Run prefill
+config = AutoConfig.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+PREFILL_SEQ_LEN = 128
+CTX_LEN = 128 * 3
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+
+decode_qpc_path = qeff_model.compile(
+ prefill_seq_len=1,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
+ retain_full_kv=True,
+)
+
+
+# The following compile call errors out by default; run the printed command manually, then pass the generated QPC path as prefill_qpc_path, commenting out lines 55-68.
+# prefill_qpc_path = "provide path here"
+prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ enable_chunking=True,
+ use_onnx_subfunctions=True,
+)
+
+
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+generation_len = CTX_LEN - position_ids.max()
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+inputs.pop("past_key_values", None)
+inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+
+
+decode_session = QAICInferenceSession(decode_qpc_path)
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+
+all_outputs = []
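+# Chunked prefill: feed the prompt PREFILL_SEQ_LEN tokens at a time, carrying the retained
+# KV state from one chunk into the next as past_key/past_value inputs.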
+for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ ins = time.time()
+ qpc_out = prefill_session.run(chunk_inputs)
+ print(f"time for this run={time.time() - ins}")
+    for j in range(config.num_hidden_layers):
+        inputs[f"past_key.{j}"] = qpc_out[f"past_key.{j}_RetainedState"]
+        inputs[f"past_value.{j}"] = qpc_out[f"past_value.{j}_RetainedState"]
+
+all_outputs.append(np.argmax(qpc_out["logits"]))
+decode_inputs = {
+ "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
+ "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
+}
+for i in range(config.num_hidden_layers):
+ decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+ decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+
+st = time.time()
+decode_out = decode_session.run(decode_inputs)
+print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
+all_outputs.append(np.argmax(decode_out["logits"]))
+pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+loop_decode_inputs = {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+}
+
+for i in range(config.num_hidden_layers):
+ loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+ loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+
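+# Decode loop: each run returns the updated KV cache as *_RetainedState outputs,
+# which is fed back as past_key/past_value inputs for the next step.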
+st = time.time()
+for i in range(generation_len - 2):
+ decode_out = decode_session.run(loop_decode_inputs)
+ all_outputs.append(np.argmax(decode_out["logits"]))
+ pos_id += 1
+    for j in range(config.num_hidden_layers):
+        loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"]
+        loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"]
+
+ loop_decode_inputs.update(
+ {
+ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+ "position_ids": pos_id,
+ }
+ )
+ft = time.time()
+
+print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
+print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")
diff --git a/examples/performance/compute_context_length/README.md b/examples/performance/compute_context_length/README.md
index 9f1d29b9a..2115251e2 100644
--- a/examples/performance/compute_context_length/README.md
+++ b/examples/performance/compute_context_length/README.md
@@ -37,11 +37,22 @@ python basic_inference.py \
--model-name meta-llama/Llama-3.2-1B \
--prompt "Hello, how are you?" \
--ctx-len 1024 \
+ --ccl-enabled \
--comp-ctx-lengths-prefill "256,500" \
--comp-ctx-lengths-decode "512,1024" \
--generation-len 100
```
+To generate the CCL lists automatically, omit both CCL lists and pass only the `--ccl-enabled` flag:
+```bash
+python basic_inference.py \
+ --model-name meta-llama/Llama-3.2-1B \
+ --prompt "Hello, how are you?" \
+ --ctx-len 1024 \
+ --ccl-enabled \
+ --generation-len 100
+```
+
### Vision-Language Models
Run VLM inference with CCL:
@@ -55,11 +66,22 @@ python vlm_inference.py \
--model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
--query "Describe this image" \
--image-url "https://..." \
+ --ccl-enabled \
--comp-ctx-lengths-prefill "4096" \
--comp-ctx-lengths-decode "6144,8192" \
--ctx-len 8192
```
+To generate the CCL lists automatically, omit both CCL lists and pass only the `--ccl-enabled` flag:
+```bash
+python vlm_inference.py \
+ --model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --query "Describe this image" \
+ --image-url "https://..." \
+ --ccl-enabled \
+ --ctx-len 8192
+```
+
## Available Examples
### Text-Only Models
diff --git a/examples/performance/compute_context_length/basic_inference.py b/examples/performance/compute_context_length/basic_inference.py
index 4533c47e8..6e8c045fb 100644
--- a/examples/performance/compute_context_length/basic_inference.py
+++ b/examples/performance/compute_context_length/basic_inference.py
@@ -54,13 +54,13 @@ def main():
parser.add_argument(
"--comp-ctx-lengths-prefill",
type=lambda x: [int(i) for i in x.split(",")],
- default="256,500",
+ default=None,
help="Comma-separated list of context lengths for prefill phase (e.g., '256,500')",
)
parser.add_argument(
"--comp-ctx-lengths-decode",
type=lambda x: [int(i) for i in x.split(",")],
- default="512,1024",
+ default=None,
help="Comma-separated list of context lengths for decode phase (e.g., '512,1024')",
)
parser.add_argument(
@@ -107,11 +107,7 @@ def main():
args = parser.parse_args()
print(f"Loading model: {args.model_name}")
- print("CCL Configuration:")
- print(f" - Prefill context lengths: {args.comp_ctx_lengths_prefill}")
- print(f" - Decode context lengths: {args.comp_ctx_lengths_decode}")
- print(f" - Max context length: {args.ctx_len}")
- print(f" - Continuous batching: {args.continuous_batching}")
+ print(f"Continuous batching: {args.continuous_batching}")
# Load model with CCL configuration
model = QEFFAutoModelForCausalLM.from_pretrained(
diff --git a/examples/performance/compute_context_length/gemma3.py b/examples/performance/compute_context_length/gemma3.py
index d9672b9e3..1dcec5c81 100644
--- a/examples/performance/compute_context_length/gemma3.py
+++ b/examples/performance/compute_context_length/gemma3.py
@@ -21,14 +21,16 @@
processor = AutoProcessor.from_pretrained(model_id)
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
comp_ctx_lengths_prefill = [3072]
comp_ctx_lengths_decode = [4096, ctx_len]
@@ -40,7 +42,7 @@
attn_implementation="eager",
kv_offload=True,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
diff --git a/examples/performance/compute_context_length/gpt_oss.py b/examples/performance/compute_context_length/gpt_oss.py
index 39a5d48ed..92bef9148 100644
--- a/examples/performance/compute_context_length/gpt_oss.py
+++ b/examples/performance/compute_context_length/gpt_oss.py
@@ -12,16 +12,17 @@
model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 4096
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
# In moe models like gpt-oss, since prefill_seq_len=1 both comp_ctx_lengths_prefill and comp_ctx_lengths_decode can share similar lists.
-# Set the list of ccl during prefilling and decoding processes
comp_ctx_lengths_prefill = comp_ctx_lengths_decode = [1024, ctx_len]
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
diff --git a/examples/performance/compute_context_length/granite_vision.py b/examples/performance/compute_context_length/granite_vision.py
index 6dd38395c..ef5dc3a51 100644
--- a/examples/performance/compute_context_length/granite_vision.py
+++ b/examples/performance/compute_context_length/granite_vision.py
@@ -98,6 +98,7 @@ def run_model(
num_devices = 4
ctx_len = 8192
ccl_enabled = True
+ # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length.
comp_ctx_lengths_prefill = [5500]
comp_ctx_lengths_decode = [6144, ctx_len]
diff --git a/examples/performance/compute_context_length/internvl.py b/examples/performance/compute_context_length/internvl.py
index 19bcf4bc1..02e965e0d 100644
--- a/examples/performance/compute_context_length/internvl.py
+++ b/examples/performance/compute_context_length/internvl.py
@@ -263,6 +263,7 @@ def run_intern_on_aic(
ctx_len = 8192
ccl_enabled = True
+ # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length.
comp_ctx_lengths_prefill = [4096]
comp_ctx_lengths_decode = [6144, ctx_len]
diff --git a/examples/performance/compute_context_length/llama4.py b/examples/performance/compute_context_length/llama4.py
index 8cdbd70a1..a867e1bd3 100644
--- a/examples/performance/compute_context_length/llama4.py
+++ b/examples/performance/compute_context_length/llama4.py
@@ -18,14 +18,16 @@
config.vision_config.num_hidden_layers = 2
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
# Set the list of ccl during prefilling process
comp_ctx_lengths_prefill = [3072]
# Set the list of ccl during decoding process
@@ -37,7 +39,7 @@
kv_offload=True,
config=config,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
diff --git a/examples/performance/compute_context_length/llama4_cb.py b/examples/performance/compute_context_length/llama4_cb.py
index ffbbff67f..f97160693 100644
--- a/examples/performance/compute_context_length/llama4_cb.py
+++ b/examples/performance/compute_context_length/llama4_cb.py
@@ -20,14 +20,16 @@
processor = AutoProcessor.from_pretrained(model_id)
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 4096
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
# Set the list of ccl during prefilling process
comp_ctx_lengths_prefill = [3072]
# Set the list of ccl during decoding process
@@ -42,7 +44,7 @@
config=config,
continuous_batching=True,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
@@ -69,7 +71,7 @@
kv_offload=True,
config=config,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
diff --git a/examples/performance/compute_context_length/llama4_multi_image.py b/examples/performance/compute_context_length/llama4_multi_image.py
index fd513fe45..314aa49b3 100644
--- a/examples/performance/compute_context_length/llama4_multi_image.py
+++ b/examples/performance/compute_context_length/llama4_multi_image.py
@@ -18,14 +18,16 @@
config.vision_config.num_hidden_layers = 2
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
# Set the list of ccl during prefilling process
comp_ctx_lengths_prefill = [5376]
# Set the list of ccl during decoding process
@@ -37,7 +39,7 @@
kv_offload=True,
config=config,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
diff --git a/examples/performance/compute_context_length/mistral3.py b/examples/performance/compute_context_length/mistral3.py
index 3763fbcde..a773ddfd9 100644
--- a/examples/performance/compute_context_length/mistral3.py
+++ b/examples/performance/compute_context_length/mistral3.py
@@ -101,6 +101,7 @@ def run_model(
num_cores = 16
num_devices = 4
ccl_enabled = True
+ # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length.
comp_ctx_lengths_prefill = [4096]
comp_ctx_lengths_decode = [6144, ctx_len]
diff --git a/examples/performance/compute_context_length/molmo.py b/examples/performance/compute_context_length/molmo.py
index b5f1f50e6..8d773f5fe 100644
--- a/examples/performance/compute_context_length/molmo.py
+++ b/examples/performance/compute_context_length/molmo.py
@@ -19,15 +19,17 @@
# config.num_hidden_layers = 2
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
# load the model
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
comp_ctx_lengths_prefill = [3072] # None #
comp_ctx_lengths_decode = [4096, 8192] # None #
@@ -37,7 +39,7 @@
trust_remote_code=True,
config=config,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
diff --git a/examples/performance/compute_context_length/qwen2_5_vl.py b/examples/performance/compute_context_length/qwen2_5_vl.py
index 20960b6a9..5a6818930 100644
--- a/examples/performance/compute_context_length/qwen2_5_vl.py
+++ b/examples/performance/compute_context_length/qwen2_5_vl.py
@@ -23,14 +23,16 @@
config.text_config.num_hidden_layers = 2
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
comp_ctx_lengths_prefill = [4096] # None #
comp_ctx_lengths_decode = [6144, ctx_len] # None #
@@ -40,7 +42,7 @@
kv_offload=True,
config=config,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
diff --git a/examples/performance/compute_context_length/qwen2_5_vl_cb.py b/examples/performance/compute_context_length/qwen2_5_vl_cb.py
index fc330e14e..c247a1e58 100644
--- a/examples/performance/compute_context_length/qwen2_5_vl_cb.py
+++ b/examples/performance/compute_context_length/qwen2_5_vl_cb.py
@@ -20,14 +20,16 @@
config.text_config.num_hidden_layers = 4
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 8192
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
comp_ctx_lengths_prefill = [4096]
comp_ctx_lengths_decode = [6144, ctx_len]
@@ -38,7 +40,7 @@
config=config,
continuous_batching=True,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
diff --git a/examples/performance/compute_context_length/qwen3moe.py b/examples/performance/compute_context_length/qwen3moe.py
index b53a28362..93849fa5a 100644
--- a/examples/performance/compute_context_length/qwen3moe.py
+++ b/examples/performance/compute_context_length/qwen3moe.py
@@ -17,15 +17,17 @@
"""
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
-## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length.
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
-## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
+## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value.
ctx_len = 1024
prefill_seq_len = 1
+ccl_enabled = True
+# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
# In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same ccl specializations.
comp_ctx_lengths_prefill = comp_ctx_lengths_decode = [256, 512, ctx_len]
@@ -33,7 +35,7 @@
model_name,
continuous_batching=False,
qaic_config={
- "ccl_enabled": True,
+ "ccl_enabled": ccl_enabled,
},
)
@@ -49,6 +51,5 @@
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
)
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)
diff --git a/examples/performance/compute_context_length/vlm_inference.py b/examples/performance/compute_context_length/vlm_inference.py
index 876daa3e6..294632fe3 100644
--- a/examples/performance/compute_context_length/vlm_inference.py
+++ b/examples/performance/compute_context_length/vlm_inference.py
@@ -58,10 +58,6 @@ def run_model(
"""
print(f"Loading model: {model_name}")
print(f"KV offload (Dual QPC mode): {kv_offload}")
- print("CCL Configuration:")
- print(f" - Prefill context lengths: {comp_ctx_lengths_prefill}")
- print(f" - Decode context lengths: {comp_ctx_lengths_decode}")
- print(f" - Max context length: {ctx_len}")
## STEP 1: Load the Processor and Model
@@ -186,13 +182,13 @@ def main():
parser.add_argument(
"--comp-ctx-lengths-prefill",
type=lambda x: [int(i) for i in x.split(",")],
- default="4096",
+ default=None,
help="Comma-separated list of context lengths for prefill phase (e.g., '4096')",
)
parser.add_argument(
"--comp-ctx-lengths-decode",
type=lambda x: [int(i) for i in x.split(",")],
- default="6144,8192",
+ default=None,
help="Comma-separated list of context lengths for decode phase (e.g., '6144,8192')",
)
parser.add_argument(
diff --git a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py
index 6cc72b715..da9c5b43b 100644
--- a/examples/performance/on_device_sampling.py
+++ b/examples/performance/on_device_sampling.py
@@ -21,6 +21,7 @@ def main(args, **kwargs):
include_sampler = None
return_pdfs = None
max_top_k_ids = None
+ include_guided_decoding = None
sampling_params = None
bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size
if args.override_qaic_config is not None:
@@ -28,6 +29,8 @@ def main(args, **kwargs):
if include_sampler is not None:
return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true"
max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
+ np.random.seed(int(args.random_number))
+ include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true"
sampling_params = {
"repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
"presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -36,7 +39,9 @@ def main(args, **kwargs):
"top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
"top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
"min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
- "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1),
+ "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype(
+ np.float32
+ ),
}
qaic_config = {
k: v
@@ -44,13 +49,12 @@ def main(args, **kwargs):
"include_sampler": include_sampler,
"return_pdfs": return_pdfs,
"max_top_k_ids": max_top_k_ids,
+ "include_guided_decoding": include_guided_decoding,
}.items()
if v is not None
}
print("qaic_config:")
pprint(qaic_config)
- print("sampling_params:")
- pprint(sampling_params)
# Load model with On Device Sampler enabled
qeff_model = AutoModelForCausalLM.from_pretrained(
@@ -60,6 +64,19 @@ def main(args, **kwargs):
)
print(f"{args.model_name} optimized for AI 100 \n", qeff_model)
+ if include_guided_decoding:
+ # Ideally this should come from a logits processor like xgrammar, but for the sake of the
+ # example, we generate a random bitmask
+ sampling_params.update(
+ {
+ "token_bitmasks": np.tile(
+ np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1)
+ )
+ }
+ )
+ print("sampling_params:")
+ pprint(sampling_params)
+
# Compile the model for inference
generated_qpc_path = qeff_model.compile(
prefill_seq_len=args.prompt_len,
@@ -88,6 +105,7 @@ def main(args, **kwargs):
generation_len=args.generation_len,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
+ include_guided_decoding=include_guided_decoding,
sampling_params=sampling_params,
)
@@ -106,14 +124,14 @@ def main(args, **kwargs):
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
- --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
- --top-k 54720 \
+ --top-k 54 \
--top-p 0.89 \
--min-p 0.6 \
- --random-number 0.26
+ --random-number 26
2. For non-continuous batching:
python3.10 examples/on_device_sampling.py \
@@ -126,14 +144,34 @@ def main(args, **kwargs):
--num-cores 16 \
--mxint8-kv-cache \
--mxfp6-matmul \
- --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \
+ --repetition-penalty 1.9 \
+ --presence-penalty 0.8 \
+ --temperature 0.67 \
+ --top-k 54 \
+ --top-p 0.89 \
+ --min-p 0.6 \
+ --random-number 26
+
+ 3. With guided decoding:
+ python3.10 examples/on_device_sampling.py \
+ --model-name 'meta-llama/Llama-3.1-8B' \
+ --prompt-len 128 \
+ --ctx-len 256 \
+ --generation-len 20 \
+ --full-batch-size 2 \
+ --device-group [0,1,2,3] \
+ --num-cores 16 \
+ --mxint8-kv-cache \
+ --mxfp6-matmul \
+ --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \
--repetition-penalty 1.9 \
--presence-penalty 0.8 \
--temperature 0.67 \
- --top-k 54720 \
+ --top-k 54 \
--top-p 0.89 \
--min-p 0.6 \
- --random-number 0.26
+ --random-number 26
"""
parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling")
diff --git a/pyproject.toml b/pyproject.toml
index 8e179ab4a..9da98f71d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,9 +20,10 @@ classifiers = [
requires-python = ">=3.8,<3.11"
dependencies = [
"transformers==4.55.0",
+ "diffusers== 0.35.1",
"huggingface-hub==0.34.0",
"hf_transfer==0.1.9",
- "peft==0.13.2",
+ "peft==0.17.0",
"datasets==2.20.0",
"fsspec==2023.6.0",
"multidict==6.0.4",
@@ -39,6 +40,9 @@ dependencies = [
"fire",
"py7zr",
"torchmetrics==1.7.0",
+ "ftfy==6.3.1",
+ "imageio==2.37.2",
+ "imageio-ffmpeg==0.6.0",
"torch==2.7.0; platform_machine=='aarch64'",
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
"torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
@@ -52,7 +56,6 @@ dependencies = [
test = ["pytest","pytest-mock"]
docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"]
quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"]
-
[build-system]
requires = ["setuptools>=62.0.0"]
build-backend = "setuptools.build_meta"
@@ -74,3 +77,16 @@ target-version = "py310"
addopts = "-W ignore -s -v"
junit_logging = "all"
doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
+markers = [
+ "on_qaic: marks tests as requiring QAIC hardware",
+ "diffusion_models: marks tests for diffusion models",
+ "wan: marks tests for WAN model",
+ "flux: marks tests for Flux model",
+ "regular: marks regular tests",
+ "nightly: marks nightly tests",
+ "multimodal: marks multimodal tests",
+ "qnn: marks QNN tests",
+ "cli: marks CLI tests",
+ "finetune: marks finetune tests",
+ "vllm: marks vLLM tests"
+]
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index 134770638..3420c025b 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -41,7 +41,7 @@ pipeline {
mkdir -p $PWD/Non_cli_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic &&
- pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml &&
+ pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml &&
junitparser merge tests/tests_log1.xml tests/tests_log.xml &&
deactivate"
'''
@@ -58,7 +58,7 @@ pipeline {
mkdir -p $PWD/Non_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_qaic &&
- pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log2.xml &&
+ pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2.xml &&
junitparser merge tests/tests_log2.xml tests/tests_log.xml &&
deactivate"
'''
@@ -67,9 +67,9 @@ pipeline {
}
}
}
- stage('QAIC MultiModal Tests') {
+ stage('QAIC MultiModal Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
@@ -77,16 +77,34 @@ pipeline {
mkdir -p $PWD/Non_cli_qaic_multimodal &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic_multimodal &&
- pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log6.xml &&
+ pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log6.xml &&
junitparser merge tests/tests_log6.xml tests/tests_log.xml &&
deactivate"
'''
}
}
}
+ stage('QAIC Diffusion Models Tests') {
+ steps {
+ timeout(time: 120, unit: 'MINUTES') {
+ sh '''
+ sudo docker exec ${BUILD_TAG} bash -c "
+ cd /efficient-transformers &&
+ . preflight_qeff/bin/activate &&
+ mkdir -p $PWD/Non_cli_qaic_diffusion &&
+ export TOKENIZERS_PARALLELISM=false &&
+ export QEFF_HOME=$PWD/Non_cli_qaic_diffusion &&
+ export HF_HUB_CACHE=/huggingface_hub &&
+ pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml &&
+ junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml &&
+ deactivate"
+ '''
+ }
+ }
+ }
stage('Inference Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
#source /qnn_sdk/bin/envsetup.sh &&
@@ -162,12 +180,13 @@ pipeline {
// }
stage('Finetune CLI Tests') {
steps {
- timeout(time: 5, unit: 'MINUTES') {
+ timeout(time: 20, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
. preflight_qeff/bin/activate &&
pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl &&
+ pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cpu &&
mkdir -p $PWD/cli_qaic_finetuning &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/cli_qaic_finetuning &&
diff --git a/scripts/memory_profiling/README.md b/scripts/memory_profiling/README.md
new file mode 100644
index 000000000..efb995815
--- /dev/null
+++ b/scripts/memory_profiling/README.md
@@ -0,0 +1,199 @@
+# QEfficient Memory Profiling
+
+A memory profiling solution for QEfficient workflows with manual operation marking.
+
+
+
+## Quick Start
+
+```python
+from profiler import QEffMemoryProfiler
+from QEfficient import QEFFAutoModelForCausalLM
+from transformers import AutoTokenizer
+
+# Initialize profiler with verbose output to see detailed memory tracking information
+profiler = QEffMemoryProfiler(verbose=True)
+# Start monitoring memory usage - this begins tracking memory consumption
+profiler.start_monitoring()
+
+# Mark the start of the model-loading operation; marked operations define the stage-wise partitioning of the output graph
+profiler.mark_operation("Loading model")
+
+model = QEFFAutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+
+# Mark the export operation
+profiler.mark_operation("Export")
+model.export()
+
+# Mark the compilation operation
+profiler.mark_operation("Compile")
+model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16)
+
+# Mark the text generation operation
+profiler.mark_operation("Generation")
+output = model.generate(prompts=["Hello world"], tokenizer=tokenizer, generation_len=100)
+
+# Stop memory monitoring and generate reports
+profiler.stop_monitoring()
+
+# Print a detailed memory usage report to the console showing peak memory and operation-wise breakdown (optional)
+print(profiler.get_memory_report())
+
+# Generate a visual graph of memory usage over time and save it as an image file
+profiler.generate_memory_graph("profile.png")
+```
+
+## Configuration
+
+### Basic Configuration
+
+```python
+profiler = QEffMemoryProfiler(
+ sampling_interval=0.1, # Sample every 100ms
+ output_file="my_profile.png", # Custom output file
+ verbose=True, # Enable detailed logging
+ enable_cpu_monitoring=True, # Monitor CPU usage
+ enable_disk_monitoring=True, # Monitor disk I/O
+)
+```
+
+### Manual Operation Marking
+
+```python
+profiler = QEffMemoryProfiler()
+profiler.start_monitoring()
+
+# Manual operation marking
+profiler.mark_operation("Custom Operation 1")
+# ... your code ...
+
+profiler.mark_operation("Custom Operation 2")
+# ... more code ...
+
+profiler.stop_monitoring()
+```
+
+## API Reference
+
+### QEffMemoryProfiler
+
+#### Constructor Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `sampling_interval` | `float` | `0.05` | Time between samples (seconds) |
+| `output_file` | `str` | `"qeff_memory_profile.png"` | Output file path |
+| `verbose` | `bool` | `False` | Enable verbose logging |
+| `enable_cpu_monitoring` | `bool` | `True` | Monitor CPU usage |
+| `enable_disk_monitoring` | `bool` | `True` | Monitor disk I/O |
+
+#### Methods
+
+- **`start_monitoring()`**: Start background monitoring
+- **`stop_monitoring()`**: Stop monitoring and mark completion
+- **`mark_operation(name: str)`**: Manually mark operation start
+- **`get_memory_report() -> str`**: Generate comprehensive text report
+- **`generate_memory_graph(filename: str)`**: Create visualization
+- **`stop_and_save(filename: str) -> str`**: Convenience method to stop and save
+
+#### Properties
+
+- **`peak_rss`**: Peak RSS memory usage (MB)
+- **`peak_operation`**: Operation during peak memory
+- **`samples`**: List of collected profiling samples
+- **`operations`**: List of marked operations with timestamps
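+
+For example, after `stop_monitoring()` the peak metrics can be read directly from these properties (a minimal sketch based on the fields above):
+
+```python
+print(f"Peak RSS: {profiler.peak_rss:.1f} MB during '{profiler.peak_operation}'")
+print(f"Collected {len(profiler.samples)} samples across {len(profiler.operations)} operations")
+```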
+
+## Operation Types
+
+The profiler supports marking these common QEfficient operations:
+
+- **Model Loading**: `from_pretrained`, `AutoModel`, `AutoTokenizer`
+- **Export**: `model.export()`, ONNX transforms, PyTorch transforms
+- **Compilation**: `model.compile()`, QNN compilation
+- **Generation**: `model.generate()`, inference execution
+- **Cleanup**: Memory cleanup, garbage collection
+
+## Output
+
+### Console Report
+```
+QEFFICIENT PERFORMANCE MONITORING REPORT
+============================================================
+Peak Memory Usage:
+  • RSS (Physical): 18.7 GB at 14:23:45
+  • Peak during: Compilation
+
+Memory Statistics:
+  • Current RSS: 16.2 GB (Delta: +15.8 GB)
+  • Duration: 185.3 seconds
+  • Operations: 4
+
+QEfficient Operations Timeline:
+ 1. 0.0s - Model Loading (25.2s) [+8.2 GB]
+ 2. 25.2s - Export (15.4s) [+2.1 GB]
+ 3. 40.6s - Compilation (120.8s) [+6.3 GB] <- Peak
+ 4. 161.4s - Generation (18.7s) [+1.2 GB]
+```
+
+### Visualization
+
+The profiler generates a comprehensive 4-panel visualization:
+
+1. **Memory Timeline**: RSS usage with colored operation phases
+2. **CPU Usage**: CPU utilization with performance zones
+3. **Disk I/O**: Read/write activity per operation phase
+4. **Phase Duration**: Timing analysis with duration labels
+
+#### Sample Output
+
+
+
+*Example memory profiling output showing QEfficient workflow phases including model loading, ONNX transforms, compilation, and generation phases with detailed memory, CPU, and disk I/O metrics.*
+
+## Advanced Usage
+
+
+### Accessing Raw Data
+
+```python
+# Get synchronized data arrays
+data = profiler.get_synchronized_data()
+timestamps = data['timestamps']
+memory_usage = data['rss_memory']
+cpu_usage = data['cpu_usage']
+
+# Access individual samples
+for sample in profiler.samples:
+ print(f"Time: {sample.timestamp}, RSS: {sample.rss_mb} MB")
+```
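+
+Each entry in `profiler.samples` is a `ProfileSample` dataclass, so per-sample fields such as `cpu_percent`, `disk_read_mb`, and `disk_write_mb` are also available.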
+
+## Integration Examples
+
+### With Existing QEfficient Scripts
+
+```python
+# Add to existing QEfficient workflow
+profiler = QEffMemoryProfiler(output_file="workflow_profile.png")
+profiler.start_monitoring()
+
+# Existing QEfficient code unchanged
+model = QEFFAutoModelForCausalLM.from_pretrained(model_name)
+# ... rest of workflow ...
+
+# Add at end
+report = profiler.stop_and_save()
+print(report)
+```
+
+
+## Limitations
+
+### Disk I/O Tracking
+
+**Subprocess I/O Limitation**: Disk I/O tracking captures parent process I/O only. Subprocess I/O (e.g., compilation reading ONNX files via `subprocess.run()`) is not captured due to Linux I/O accounting limitations. During compilation phases, expect lower I/O readings than actual file operations performed by subprocesses.
+
+## Compatibility
+
+- **Python**: 3.7+
+- **Dependencies**: `psutil`, `matplotlib`, `numpy`
diff --git a/scripts/memory_profiling/__init__.py b/scripts/memory_profiling/__init__.py
new file mode 100644
index 000000000..dc1377d0b
--- /dev/null
+++ b/scripts/memory_profiling/__init__.py
@@ -0,0 +1,53 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+QEfficient Memory Profiling
+
+A production-ready memory profiling solution specifically designed for QEfficient workflows.
+Provides manual operation marking, comprehensive metrics collection, and professional visualization.
+
+Usage Example:
+
+```python
+from scripts.memory_profiling import QEffMemoryProfiler
+
+profiler = QEffMemoryProfiler(verbose=True)
+profiler.start_monitoring()
+# ... your QEfficient code ...
+profiler.stop_monitoring()
+print(profiler.get_memory_report())
+profiler.generate_memory_graph()
+```
+"""
+
+__version__ = "2.0.0"
+__author__ = "Qualcomm Technologies, Inc."
+
+# Core profiler components
+from .profiler import (
+ MetricsCollector,
+ ProfilerConfig,
+ ProfileSample,
+ QEffMemoryProfiler,
+)
+
+# Visualization component (optional; requires matplotlib)
+try:
+ from .visualizer import QEffMemoryVisualizer
+except ImportError:
+ # Handle case where matplotlib is not available
+ QEffMemoryVisualizer = None
+
+__all__ = [
+ "QEffMemoryProfiler",
+ "ProfilerConfig",
+ "ProfileSample",
+ "MetricsCollector",
+ "QEffMemoryVisualizer",
+ "__version__",
+]
diff --git a/scripts/memory_profiling/memory_profile_llama3.2.png b/scripts/memory_profiling/memory_profile_llama3.2.png
new file mode 100644
index 000000000..e91c1d04a
Binary files /dev/null and b/scripts/memory_profiling/memory_profile_llama3.2.png differ
diff --git a/scripts/memory_profiling/profiler.py b/scripts/memory_profiling/profiler.py
new file mode 100644
index 000000000..cfd53e4d7
--- /dev/null
+++ b/scripts/memory_profiling/profiler.py
@@ -0,0 +1,729 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+QEfficient Memory Profiler - Production-Ready Memory Monitoring
+
+This module provides comprehensive memory profiling capabilities specifically
+designed for QEfficient workflows.
+"""
+
+import os
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Dict, List, Optional, Tuple
+
+import psutil
+
+from QEfficient.utils.logging_utils import logger
+
+
+@dataclass
+class ProfilerConfig:
+ """Configuration for memory profiler."""
+
+ sampling_interval: float = 0.2
+ output_file: Optional[str] = None
+ verbose: bool = False
+ enable_cpu_monitoring: bool = True
+ enable_disk_monitoring: bool = True
+ track_child_processes: bool = True
+ child_scan_interval: float = 1.0
+
+
+@dataclass
+class ProfileSample:
+ """Single profiling sample containing all metrics."""
+
+ timestamp: datetime
+ rss_mb: float
+ vms_mb: float
+ cpu_percent: float = 0.0
+ disk_read_mb: float = 0.0
+ disk_write_mb: float = 0.0
+ disk_read_rate: float = 0.0
+ disk_write_rate: float = 0.0
+
+
+class MetricsCollector:
+ """Handles collection of system metrics with child process support."""
+
+ def __init__(self, config: ProfilerConfig):
+ self.config = config
+ self.process = psutil.Process(os.getpid())
+ self._last_disk_counters = None
+ self._last_disk_time = None
+ self._cpu_initialized = False
+ self._last_cpu_ema = 0.0
+ self._cpu_ema_alpha = 0.3
+
+ # Child process tracking
+ self._track_children = config.track_child_processes
+ self._child_processes: Dict[int, psutil.Process] = {}
+ self._last_child_scan = 0.0
+ self._child_scan_interval = config.child_scan_interval
+ self._child_cpu_cache: Dict[int, float] = {}
+
+ if self._track_children and self.config.verbose:
+ logger.info("Child process tracking enabled")
+
+ def initialize_cpu_monitoring(self) -> None:
+ """Initialize CPU monitoring."""
+ try:
+ self.process.cpu_percent() # First call to establish baseline
+ self._cpu_initialized = True
+
+ # Initialize child process CPU monitoring
+ if self._track_children:
+ self._update_child_processes()
+ for child_proc in self._child_processes.values():
+ try:
+ child_proc.cpu_percent() # Initialize baseline for children
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ continue
+
+ if self.config.verbose:
+ logger.info("CPU measurement initialized")
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"CPU initialization warning: {e}")
+ self._cpu_initialized = False
+
+ def _update_child_processes(self) -> None:
+ """Discover and track child processes (compilation subprocesses)."""
+ current_time = time.time()
+        # Scan every child_scan_interval until children are found, then rescan every 5 seconds
+ scan_interval = 5.0 if self._child_processes else self._child_scan_interval
+ if current_time - self._last_child_scan < scan_interval:
+ return
+
+ try:
+ # Get current children (recursive to catch subprocess chains)
+ children = self.process.children(recursive=True)
+
+ # Add new children
+ new_children_count = 0
+ for child in children:
+ if child.pid not in self._child_processes:
+ try:
+ # Verify child is still running and accessible
+ if child.is_running():
+ self._child_processes[child.pid] = child
+ self._child_cpu_cache[child.pid] = 0.0
+
+ # Initialize CPU monitoring for new child
+ try:
+ child.cpu_percent() # First call to establish baseline
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass # Child may have terminated quickly
+
+ new_children_count += 1
+
+ if self.config.verbose:
+ try:
+ cmd_name = child.name()
+ logger.info(f"Tracking new subprocess: PID {child.pid} ({cmd_name})")
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ logger.info(f"Tracking new subprocess: PID {child.pid}")
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ continue
+
+ # Remove terminated children
+ terminated_pids = []
+ for pid, proc in self._child_processes.items():
+ try:
+ if not proc.is_running():
+ terminated_pids.append(pid)
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ terminated_pids.append(pid)
+
+ for pid in terminated_pids:
+ if pid in self._child_processes:
+ del self._child_processes[pid]
+ if pid in self._child_cpu_cache:
+ del self._child_cpu_cache[pid]
+ if self.config.verbose:
+ logger.info(f"Removed terminated subprocess: PID {pid}")
+
+ if new_children_count > 0 and self.config.verbose:
+ logger.info(f"Now tracking {len(self._child_processes)} child processes")
+
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"Child process scan error: {e}")
+
+ self._last_child_scan = current_time
+
+ def get_memory_usage(self) -> Tuple[float, float]:
+ """Get current memory usage in MB (parent + children)."""
+ try:
+ # Parent process memory
+ mem_info = self.process.memory_info()
+ total_rss = mem_info.rss / 1024 / 1024
+ total_vms = mem_info.vms / 1024 / 1024
+
+ # Add child process memory (if tracking enabled)
+ if self._track_children:
+ child_rss = 0.0
+ child_vms = 0.0
+ active_children = 0
+ stale_children = []
+
+ # Iterate through current child processes
+ for pid, child_proc in self._child_processes.items():
+ try:
+ child_mem = child_proc.memory_info()
+ child_rss += child_mem.rss / 1024 / 1024
+ child_vms += child_mem.vms / 1024 / 1024
+ active_children += 1
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ # Mark child as stale for cleanup
+ stale_children.append(pid)
+ continue
+
+ # Clean up stale children (don't do this during iteration)
+ for pid in stale_children:
+ if pid in self._child_processes:
+ del self._child_processes[pid]
+ if pid in self._child_cpu_cache:
+ del self._child_cpu_cache[pid]
+
+ total_rss += child_rss
+ total_vms += child_vms
+
+ if self.config.verbose and active_children > 0:
+ logger.debug(
+ f"Memory: Parent {mem_info.rss / 1024 / 1024:.1f}MB + "
+ f"Children {child_rss:.1f}MB = Total {total_rss:.1f}MB RSS"
+ )
+
+ return total_rss, total_vms
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"Memory collection error: {e}")
+ return 0.0, 0.0
+
+ def get_cpu_usage(self) -> float:
+ """Get CPU usage with child processes included and smoothing."""
+ if not self.config.enable_cpu_monitoring:
+ return 0.0
+
+ try:
+ import multiprocessing
+
+ num_cores = multiprocessing.cpu_count()
+
+ parent_cpu_raw = 0.0
+ child_cpu_raw_total = 0.0
+
+ # Parent CPU (raw percentage, can be >100% on multi-core)
+ if self._cpu_initialized:
+ parent_cpu_raw = self.process.cpu_percent()
+ if parent_cpu_raw < 0:
+ parent_cpu_raw = 0.0
+
+ # Child CPU (if tracking enabled)
+ if self._track_children:
+ active_children = 0
+
+ for pid, child_proc in list(self._child_processes.items()):
+ try:
+ child_cpu_raw = child_proc.cpu_percent()
+ if child_cpu_raw >= 0:
+ # Cache raw CPU value
+ self._child_cpu_cache[pid] = child_cpu_raw
+ child_cpu_raw_total += child_cpu_raw
+ active_children += 1
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ # Use cached value if available, otherwise skip
+ if pid in self._child_cpu_cache:
+ child_cpu_raw_total += self._child_cpu_cache[pid]
+ continue
+
+ if self.config.verbose and active_children > 0:
+ # Convert to system-wide percentage for logging
+ parent_system_pct = parent_cpu_raw / num_cores
+ child_system_pct = child_cpu_raw_total / num_cores
+ logger.debug(
+ f"CPU: Parent {parent_system_pct:.1f}% + "
+ f"Children {child_system_pct:.1f}% (from {active_children} processes) "
+ f"= {parent_system_pct + child_system_pct:.1f}% system-wide"
+ )
+
+ # Calculate system-wide CPU percentage
+ # psutil.Process.cpu_percent() returns per-process CPU time percentage
+ # To get system-wide percentage: divide by number of cores
+ total_process_cpu = parent_cpu_raw + child_cpu_raw_total
+ system_wide_cpu = total_process_cpu / num_cores
+
+ # Cap at 100% (shouldn't exceed this in normal cases)
+ system_wide_cpu = min(system_wide_cpu, 100.0)
+
+ # Apply exponential moving average smoothing
+ if system_wide_cpu > 0 or self._last_cpu_ema > 0:
+ smoothed_cpu = self._cpu_ema_alpha * system_wide_cpu + (1 - self._cpu_ema_alpha) * self._last_cpu_ema
+ self._last_cpu_ema = smoothed_cpu
+ return smoothed_cpu
+
+ return 0.0
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"CPU collection error: {e}")
+ return self._last_cpu_ema
+
+ def get_disk_io_stats(self) -> Tuple[float, float, float, float]:
+ """Get disk I/O statistics with rate calculation (parent + children)."""
+ if not self.config.enable_disk_monitoring:
+ return 0.0, 0.0, 0.0, 0.0
+
+ try:
+ current_time = time.time()
+
+ # Parent process I/O
+ parent_io = self.process.io_counters()
+
+ # Determine which counters to use
+ use_chars = hasattr(parent_io, "read_chars") and hasattr(parent_io, "write_chars")
+
+ if use_chars:
+ total_read_bytes = parent_io.read_chars
+ total_write_bytes = parent_io.write_chars
+ else:
+ total_read_bytes = parent_io.read_bytes
+ total_write_bytes = parent_io.write_bytes
+
+ # Add child process I/O (if tracking enabled)
+ if self._track_children:
+ child_read_total = 0
+ child_write_total = 0
+ active_io_children = 0
+
+ for pid, child_proc in list(self._child_processes.items()):
+ try:
+ child_io = child_proc.io_counters()
+ if use_chars and hasattr(child_io, "read_chars") and hasattr(child_io, "write_chars"):
+ child_read_total += child_io.read_chars
+ child_write_total += child_io.write_chars
+ else:
+ child_read_total += child_io.read_bytes
+ child_write_total += child_io.write_bytes
+ active_io_children += 1
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ # Child process terminated or inaccessible
+ continue
+
+ total_read_bytes += child_read_total
+ total_write_bytes += child_write_total
+
+ if self.config.verbose and active_io_children > 0:
+ parent_read_mb = (
+ parent_io.read_chars / 1024 / 1024 if use_chars else parent_io.read_bytes / 1024 / 1024
+ )
+ parent_write_mb = (
+ parent_io.write_chars / 1024 / 1024 if use_chars else parent_io.write_bytes / 1024 / 1024
+ )
+ child_read_mb = child_read_total / 1024 / 1024
+ child_write_mb = child_write_total / 1024 / 1024
+ logger.debug(
+ f"Disk I/O: Parent R:{parent_read_mb:.1f}MB W:{parent_write_mb:.1f}MB + "
+ f"Children R:{child_read_mb:.1f}MB W:{child_write_mb:.1f}MB "
+ f"(from {active_io_children} processes)"
+ )
+
+ # Convert to MB
+ read_mb = total_read_bytes / 1024 / 1024
+ write_mb = total_write_bytes / 1024 / 1024
+
+ # Calculate rates
+ read_rate = 0.0
+ write_rate = 0.0
+
+ if self._last_disk_counters is not None and self._last_disk_time is not None:
+ time_delta = current_time - self._last_disk_time
+ if time_delta > 0:
+ # Calculate delta from last measurement
+ if use_chars:
+ last_read = self._last_disk_counters.get("read_chars", 0)
+ last_write = self._last_disk_counters.get("write_chars", 0)
+ else:
+ last_read = self._last_disk_counters.get("read_bytes", 0)
+ last_write = self._last_disk_counters.get("write_bytes", 0)
+
+ read_delta = (total_read_bytes - last_read) / 1024 / 1024 # MB
+ write_delta = (total_write_bytes - last_write) / 1024 / 1024 # MB
+
+ read_rate = read_delta / time_delta # MB/s
+ write_rate = write_delta / time_delta # MB/s
+
+ # Update counters (store as dict to handle both counter types)
+ if use_chars:
+ self._last_disk_counters = {"read_chars": total_read_bytes, "write_chars": total_write_bytes}
+ else:
+ self._last_disk_counters = {"read_bytes": total_read_bytes, "write_bytes": total_write_bytes}
+ self._last_disk_time = current_time
+
+ return read_mb, write_mb, read_rate, write_rate
+
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"Disk I/O collection error: {e}")
+ return 0.0, 0.0, 0.0, 0.0
+
+ def collect_sample(self) -> ProfileSample:
+ """Collect a complete profiling sample."""
+ timestamp = datetime.now()
+ rss_mb, vms_mb = self.get_memory_usage()
+ cpu_percent = self.get_cpu_usage()
+ read_bytes, write_bytes, read_rate, write_rate = self.get_disk_io_stats()
+
+ return ProfileSample(
+ timestamp=timestamp,
+ rss_mb=rss_mb,
+ vms_mb=vms_mb,
+ cpu_percent=cpu_percent,
+ disk_read_mb=read_bytes,
+ disk_write_mb=write_bytes,
+ disk_read_rate=read_rate,
+ disk_write_rate=write_rate,
+ )
+
+
+class QEffMemoryProfiler:
+ """
+ Production-ready memory profiler for QEfficient workflows.
+
+ Features:
+ - Manual operation marking for QEfficient workflows
+ - Production-quality visualization with detailed segment analysis
+ - Precise memory attribution and performance metrics
+ - Professional-grade reporting suitable for debugging and optimization
+ """
+
+ # Segment colors for visualization
+ SEGMENT_COLORS = {
+ "Initialization": "#E8E8E8",
+ "Model Loading": "#FF6B6B",
+ "Export": "#FFEAA7",
+ "Model Export": "#FFEAA7",
+ "Compilation": "#98D8C8",
+ "Model Compilation": "#98D8C8",
+ "Generation": "#F7DC6F",
+ "Text Generation": "#F7DC6F",
+ "Cleanup": "#AED6F1",
+ "Completion": "#D5DBDB",
+ }
+
+ def __init__(
+ self, sampling_interval: float = 0.05, output_file: Optional[str] = None, verbose: bool = False, **kwargs
+ ):
+ """
+ Initialize the QEfficient Memory Profiler.
+
+ Args:
+ sampling_interval: Time between memory samples in seconds
+ output_file: Output file for memory profile graph
+ verbose: Enable verbose output for monitoring operations
+ """
+ # Create configuration
+ self.config = ProfilerConfig(
+ sampling_interval=sampling_interval,
+ output_file=output_file or "qeff_memory_profile.png",
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Initialize components
+ self.metrics_collector = MetricsCollector(self.config)
+
+ # Monitoring state
+ self.monitoring = False
+ self.monitor_thread = None
+
+ # self.samples = deque(maxlen=5000) # Auto-evicts old samples
+ self.samples: List[ProfileSample] = [] # This could slow down for very long runs
+ self.operations: List[Tuple[datetime, str]] = []
+
+ # Peak tracking
+ self.peak_rss = 0.0
+ self.peak_vms = 0.0
+ self.peak_rss_time: Optional[datetime] = None
+ self.peak_vms_time: Optional[datetime] = None
+ self.peak_operation: Optional[str] = None
+
+ # Operation tracking
+ self.current_operation = "Initialization"
+ self.operation_start_time = datetime.now()
+ self.operation_durations: Dict[str, float] = {}
+ self.operation_memory_deltas: Dict[str, float] = {}
+
+ # Legacy property accessors for backward compatibility
+ @property
+ def timestamps(self) -> List[datetime]:
+ """Get timestamps from samples."""
+ return [sample.timestamp for sample in self.samples]
+
+ @property
+ def rss_memory(self) -> List[float]:
+ """Get RSS memory values from samples."""
+ return [sample.rss_mb for sample in self.samples]
+
+ @property
+ def vms_memory(self) -> List[float]:
+ """Get VMS memory values from samples."""
+ return [sample.vms_mb for sample in self.samples]
+
+ @property
+ def cpu_usage(self) -> List[float]:
+ """Get CPU usage values from samples."""
+ return [sample.cpu_percent for sample in self.samples]
+
+ @property
+ def disk_read_bytes(self) -> List[float]:
+ """Get disk read bytes from samples."""
+ return [sample.disk_read_mb for sample in self.samples]
+
+ @property
+ def disk_write_bytes(self) -> List[float]:
+ """Get disk write bytes from samples."""
+ return [sample.disk_write_mb for sample in self.samples]
+
+ @property
+ def disk_read_rate(self) -> List[float]:
+ """Get disk read rates from samples."""
+ return [sample.disk_read_rate for sample in self.samples]
+
+ @property
+ def disk_write_rate(self) -> List[float]:
+ """Get disk write rates from samples."""
+ return [sample.disk_write_rate for sample in self.samples]
+
+ @property
+ def sampling_interval(self) -> float:
+ """Get sampling interval."""
+ return self.config.sampling_interval
+
+ @property
+ def output_file(self) -> str:
+ """Get output file path."""
+ return self.config.output_file
+
+ @property
+ def verbose(self) -> bool:
+ """Get verbose flag."""
+ return self.config.verbose
+
+ def start_monitoring(self) -> None:
+ """Start continuous memory monitoring in background thread."""
+ if self.monitoring:
+ return
+
+ # Initialize CPU measurement
+ self.metrics_collector.initialize_cpu_monitoring()
+
+ self.monitoring = True
+ self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
+ self.monitor_thread.start()
+
+ if self.config.verbose:
+ logger.info(f"QEff Memory monitoring started (sampling every {self.config.sampling_interval}s)")
+
+ def stop_monitoring(self) -> None:
+ """Stop memory monitoring and generate reports."""
+ if not self.monitoring:
+ return
+
+ self.monitoring = False
+ if self.monitor_thread:
+ self.monitor_thread.join(timeout=1.0)
+
+ # Mark completion
+ self.mark_operation("Completion")
+
+ if self.config.verbose:
+ logger.info("QEff Memory monitoring stopped")
+
+ def _monitor_loop(self) -> None:
+ """Background monitoring loop."""
+ while self.monitoring:
+ try:
+ # Update child processes periodically (throttled internally)
+ if self.metrics_collector._track_children:
+ self.metrics_collector._update_child_processes()
+
+ # Collect sample
+ sample = self.metrics_collector.collect_sample()
+ self.samples.append(sample)
+
+ # Update peaks
+ self._update_peaks(sample)
+
+ time.sleep(self.config.sampling_interval)
+
+ except Exception as e:
+ if self.config.verbose:
+ logger.warning(f"Monitoring error: {e}")
+ break
+
+ def _update_peaks(self, sample: ProfileSample) -> None:
+ """Update peak memory tracking."""
+ if sample.rss_mb > self.peak_rss:
+ self.peak_rss = sample.rss_mb
+ self.peak_rss_time = sample.timestamp
+ self.peak_operation = self.current_operation
+
+ if sample.vms_mb > self.peak_vms:
+ self.peak_vms = sample.vms_mb
+ self.peak_vms_time = sample.timestamp
+
+ def mark_operation(self, operation_name: str) -> None:
+ """Mark the start of a new operation."""
+ current_time = datetime.now()
+ current_rss = self.samples[-1].rss_mb if self.samples else 0.0
+
+ # Record previous operation duration and memory delta
+ if self.current_operation != "Initialization" and self.samples:
+ duration = (current_time - self.operation_start_time).total_seconds()
+ self.operation_durations[self.current_operation] = duration
+
+ # Calculate memory delta from start of operation
+ start_idx = max(0, len(self.samples) - max(1, int(duration / self.config.sampling_interval)))
+ start_rss = self.samples[start_idx].rss_mb if start_idx < len(self.samples) else current_rss
+ memory_delta = current_rss - start_rss
+ self.operation_memory_deltas[self.current_operation] = memory_delta
+
+ # Start new operation
+ self.current_operation = operation_name
+ self.operation_start_time = current_time
+ self.operations.append((current_time, operation_name))
+
+ if self.config.verbose:
+ logger.info(f"{operation_name} | Memory: {current_rss:.1f} MB RSS")
+
+ def get_synchronized_data(self) -> Dict[str, List[float]]:
+ """Get synchronized data arrays."""
+ if not self.samples:
+ return {}
+
+ start_time = self.samples[0].timestamp
+ return {
+ "timestamps": [(s.timestamp - start_time).total_seconds() for s in self.samples],
+ "rss_memory": [s.rss_mb for s in self.samples],
+ "vms_memory": [s.vms_mb for s in self.samples],
+ "cpu_usage": [s.cpu_percent for s in self.samples],
+ "disk_read_bytes": [s.disk_read_mb for s in self.samples],
+ "disk_write_bytes": [s.disk_write_mb for s in self.samples],
+ "disk_read_rate": [s.disk_read_rate for s in self.samples],
+ "disk_write_rate": [s.disk_write_rate for s in self.samples],
+ }
+
+ def mark_segment(self, segment_name: str) -> None:
+ """Convenience method for manual segment marking (API mode)."""
+ self.mark_operation(segment_name)
+
+ def stop_and_save(self, filename: Optional[str] = None) -> str:
+ """Stop monitoring and save results (API mode convenience)."""
+ self.stop_monitoring()
+ self.generate_memory_graph(filename)
+ return self.get_memory_report()
+
+ def get_memory_report(self) -> str:
+ """Generate comprehensive memory usage report."""
+ if not self.samples:
+ return "No memory data collected"
+
+ current_sample = self.samples[-1]
+ initial_sample = self.samples[0]
+
+ # Calculate statistics
+ rss_values = [s.rss_mb for s in self.samples]
+ avg_rss = sum(rss_values) / len(rss_values)
+ max_rss = max(rss_values)
+ min_rss = min(rss_values)
+
+ # Auto-scale units
+ rss_scale, rss_unit = (1024, "GB") if max_rss > 2048 else (1, "MB")
+
+ # Calculate disk I/O statistics
+ disk_io_stats = ""
+ if self.samples and len(self.samples) > 1:
+ total_read = current_sample.disk_read_mb - initial_sample.disk_read_mb
+ total_write = current_sample.disk_write_mb - initial_sample.disk_write_mb
+ max_read_rate = max(s.disk_read_rate for s in self.samples)
+ max_write_rate = max(s.disk_write_rate for s in self.samples)
+ avg_read_rate = sum(s.disk_read_rate for s in self.samples) / len(self.samples)
+ avg_write_rate = sum(s.disk_write_rate for s in self.samples) / len(self.samples)
+
+        disk_io_stats = f"""
+Disk I/O Statistics:
+  • Total Read:      {total_read:.2f} MB
+  • Total Write:     {total_write:.2f} MB
+  • Peak Read Rate:  {max_read_rate:.2f} MB/s
+  • Peak Write Rate: {max_write_rate:.2f} MB/s
+  • Avg Read Rate:   {avg_read_rate:.2f} MB/s
+  • Avg Write Rate:  {avg_write_rate:.2f} MB/s"""
+
+ report = f"""
+QEFFICIENT PERFORMANCE MONITORING REPORT
+{"=" * 60}
+Peak Memory Usage:
+  • RSS (Physical): {self.peak_rss / rss_scale:.2f} {rss_unit} at {self.peak_rss_time.strftime("%H:%M:%S") if self.peak_rss_time else "N/A"}
+  • VMS (Virtual):  {self.peak_vms / rss_scale:.2f} {rss_unit} at {self.peak_vms_time.strftime("%H:%M:%S") if self.peak_vms_time else "N/A"}
+  • Peak during:    {self.peak_operation}
+
+Memory Statistics:
+  • Current RSS: {current_sample.rss_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.rss_mb - initial_sample.rss_mb) / rss_scale:+.2f} {rss_unit})
+  • Current VMS: {current_sample.vms_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.vms_mb - initial_sample.vms_mb) / rss_scale:+.2f} {rss_unit})
+  • Average RSS: {avg_rss / rss_scale:.2f} {rss_unit}
+  • Min/Max RSS: {min_rss / rss_scale:.2f} / {max_rss / rss_scale:.2f} {rss_unit}
+  • Memory Range: {(max_rss - min_rss) / rss_scale:.2f} {rss_unit}{disk_io_stats}
+
+Monitoring Info:
+  • Duration: {(current_sample.timestamp - initial_sample.timestamp).total_seconds():.1f} seconds
+  • Data Points: {len(self.samples)}
+  • Operations: {len(self.operations)}
+  • Sampling Rate: {self.config.sampling_interval}s
+
+QEfficient Operations Timeline:"""
+
+ # Add operation timeline
+ if self.operations:
+ start_time = self.samples[0].timestamp
+ for i, (op_time, op_name) in enumerate(self.operations):
+ relative_time = (op_time - start_time).total_seconds()
+ duration = self.operation_durations.get(op_name, 0)
+ memory_delta = self.operation_memory_deltas.get(op_name, 0)
+
+ duration_str = f"({duration:.1f}s)" if duration > 0 else ""
+ memory_str = f"[{memory_delta / rss_scale:+.1f} {rss_unit}]" if abs(memory_delta) > 10 else ""
+
+ report += f"\n {i + 1:2d}. {relative_time:6.1f}s - {op_name} {duration_str} {memory_str}"
+
+ return report
+
+ def generate_memory_graph(self, filename: Optional[str] = None) -> None:
+ """Generate professional memory usage graph with QEfficient operation segments."""
+ if not self.samples:
+ logger.warning("No data to plot")
+ return
+
+ output_file = filename or self.config.output_file
+
+ # Import visualization module
+        from .visualizer import QEffMemoryVisualizer
+
+ visualizer = QEffMemoryVisualizer(self)
+ visualizer.generate_professional_graph(output_file)
+
+ if self.config.verbose:
+ logger.info(f"QEfficient memory profile saved as: {output_file}")
+
+ # Legacy methods for backward compatibility
+ def get_memory_usage(self) -> Tuple[float, float]:
+ """Get current memory usage in MB (legacy method)."""
+ return self.metrics_collector.get_memory_usage()
diff --git a/scripts/memory_profiling/visualizer.py b/scripts/memory_profiling/visualizer.py
new file mode 100644
index 000000000..c16c0c0ef
--- /dev/null
+++ b/scripts/memory_profiling/visualizer.py
@@ -0,0 +1,604 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+QEfficient Memory Visualizer - Production Quality Enhanced Visualization
+
+This module provides production-quality visualization with detailed segment analysis,
+clear operation boundaries, and comprehensive memory metrics.
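+
+Typical usage (a sketch; assumes a QEffMemoryProfiler that has already collected samples):
+
+    from scripts.memory_profiling import QEffMemoryProfiler, QEffMemoryVisualizer
+
+    profiler = QEffMemoryProfiler()
+    profiler.start_monitoring()
+    # ... QEfficient workload ...
+    profiler.stop_monitoring()
+    QEffMemoryVisualizer(profiler).generate_professional_graph("profile.png")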
+"""
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+
+if TYPE_CHECKING:
+ from .profiler import QEffMemoryProfiler
+
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffMemoryVisualizer:
+ """Production-quality memory visualization with enhanced segment analysis."""
+
+ def __init__(self, profiler: "QEffMemoryProfiler"):
+ """Initialize visualizer with profiler data."""
+ self.profiler = profiler
+ self._setup_matplotlib_style()
+
+ def _setup_matplotlib_style(self) -> None:
+ """Configure matplotlib for professional styling."""
+ plt.style.use("default")
+ plt.rcParams.update(
+ {
+ "font.size": 10,
+ "font.family": ["DejaVu Sans", "sans-serif"],
+ "axes.linewidth": 1.2,
+ "figure.facecolor": "white",
+ "axes.facecolor": "white",
+ "grid.alpha": 0.3,
+ "lines.linewidth": 2.0,
+ "axes.spines.top": False,
+ "axes.spines.right": False,
+ "axes.edgecolor": "#333333",
+ "text.color": "#333333",
+ "axes.labelcolor": "#333333",
+ "xtick.color": "#333333",
+ "ytick.color": "#333333",
+ }
+ )
+
+ def generate_professional_graph(self, filename: str) -> None:
+ """Generate enhanced multi-panel memory profile with synchronized visualization."""
+ if not self.profiler.samples:
+ logger.warning("No data to plot")
+ return
+
+ # Get synchronized data
+ sync_data = self.profiler.get_synchronized_data()
+
+ # Create figure with professional layout - Fixed spacing to prevent title overlap
+ fig = plt.figure(figsize=(20, 12), facecolor="white")
+ gs = fig.add_gridspec(
+ 3,
+ 2,
+ height_ratios=[2.5, 1.8, 1.2],
+ width_ratios=[1, 1],
+ hspace=0.35,
+ wspace=0.2,
+ left=0.05,
+ right=0.98,
+ top=0.90,
+ bottom=0.08,
+ )
+
+ # Create subplots
+ ax_memory = fig.add_subplot(gs[0, :]) # Memory usage (full width)
+ ax_cpu = fig.add_subplot(gs[1, :]) # CPU usage (full width)
+ ax_disk = fig.add_subplot(gs[2, 0]) # Disk I/O (left)
+ ax_timing = fig.add_subplot(gs[2, 1]) # Phase Duration (right)
+
+ # Prepare data
+ relative_times = sync_data["timestamps"]
+ max_rss = max(sync_data["rss_memory"]) if sync_data["rss_memory"] else 0
+ use_gb = max_rss > 2048
+ scale = 1024 if use_gb else 1
+ unit = "GB" if use_gb else "MB"
+ rss_scaled = [x / scale for x in sync_data["rss_memory"]]
+
+ # Normalize CPU usage to prevent > 100% values (multi-core issue)
+ normalized_cpu = [min(cpu, 100.0) for cpu in sync_data["cpu_usage"]]
+
+ # Setup plots
+ self._setup_memory_plot(ax_memory, relative_times, rss_scaled, scale, unit)
+ self._setup_cpu_plot(ax_cpu, relative_times, normalized_cpu)
+ self._setup_disk_io_plot(ax_disk, sync_data)
+ self._setup_timing_plot(ax_timing)
+
+ # Add main title with proper spacing
+ fig.suptitle(
+ "QEfficient Enhanced Memory & Performance Analysis - Synchronized View",
+ fontsize=18,
+ fontweight="bold",
+ color="#2E86AB",
+ y=0.95,
+ )
+
+ # Save with high quality
+ plt.savefig(
+ filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none", format="png", pad_inches=0.2
+ )
+ plt.close()
+
+ logger.info(f"Enhanced synchronized memory profile saved: {filename}")
+
+ def _setup_memory_plot(
+ self, ax, relative_times: List[float], rss_scaled: List[float], scale: float, unit: str
+ ) -> None:
+ """Setup the main memory usage plot with enhanced visualization."""
+ if not relative_times or not rss_scaled:
+ ax.text(
+ 0.5,
+ 0.5,
+ "No memory data available",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="#666666",
+ )
+ return
+
+ start_time = self.profiler.samples[0].timestamp
+
+ # Draw segment backgrounds
+ self._draw_segment_backgrounds(ax, relative_times, rss_scaled, start_time)
+
+ # Main memory line
+ ax.plot(
+ relative_times, rss_scaled, color="#2E86AB", linewidth=3.5, label="Memory Usage (RSS)", alpha=0.9, zorder=5
+ )
+ ax.fill_between(relative_times, rss_scaled, alpha=0.15, color="#2E86AB", zorder=1)
+
+ # Add segment boundaries and annotations
+ self._draw_segment_boundaries(ax, start_time, max(rss_scaled))
+ self._mark_peak_memory(ax, start_time, scale, unit)
+
+ # Format axes
+ ax.set_xlabel("Time (seconds)", fontsize=13, fontweight="bold")
+ ax.set_ylabel(f"Memory Usage ({unit})", fontsize=13, fontweight="bold")
+ ax.set_xlim(0, max(relative_times) * 1.02)
+ ax.set_ylim(0, max(rss_scaled) * 1.15)
+ ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC")
+ ax.set_axisbelow(True)
+
+ # Enhanced title
+ total_duration = relative_times[-1] if relative_times else 0
+ peak_memory = max(rss_scaled) if rss_scaled else 0
+ ax.set_title(
+ f"Memory Usage Over Time | Peak: {peak_memory:.1f} {unit} | Duration: {total_duration:.1f}s",
+ fontsize=14,
+ fontweight="bold",
+ color="#2E86AB",
+ pad=15,
+ )
+
+ # Add legend
+ self._add_segment_legend(ax)
+
+ def _setup_cpu_plot(self, ax, relative_times: List[float], cpu_usage: List[float]) -> None:
+ """Setup CPU plot with perfect synchronization to memory plot."""
+ if not relative_times or not cpu_usage or len(cpu_usage) != len(relative_times):
+ ax.text(
+ 0.5,
+ 0.5,
+ "CPU data not available or not synchronized",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="#666666",
+ )
+ ax.set_title("CPU Usage Over Time", fontsize=14, fontweight="bold")
+ if relative_times:
+ ax.set_xlim(0, max(relative_times) * 1.02)
+ return
+
+ start_time = self.profiler.samples[0].timestamp
+
+ # Draw segment backgrounds for consistency
+ self._draw_segment_backgrounds(ax, relative_times, cpu_usage, start_time, max_val=100)
+
+ # Main CPU line
+ ax.plot(relative_times, cpu_usage, color="#FF6B35", linewidth=3, label="CPU Usage", alpha=0.9, zorder=5)
+ ax.fill_between(relative_times, cpu_usage, alpha=0.2, color="#FF6B35", zorder=1)
+
+ # Add segment boundaries
+ self._draw_segment_boundaries(ax, start_time, max(cpu_usage) if cpu_usage else 100)
+
+ # Add average line
+ avg_cpu = sum(cpu_usage) / len(cpu_usage)
+ ax.axhline(
+ y=avg_cpu,
+ color="#E74C3C",
+ linestyle="-",
+ alpha=0.8,
+ linewidth=2.5,
+ label=f"Average: {avg_cpu:.1f}%",
+ zorder=4,
+ )
+
+ # Add performance zones
+ ax.axhspan(0, 25, alpha=0.08, color="#4CAF50", zorder=0)
+ ax.axhspan(25, 50, alpha=0.08, color="#FFC107", zorder=0)
+ ax.axhspan(50, 75, alpha=0.08, color="#FF9800", zorder=0)
+ ax.axhspan(75, 100, alpha=0.08, color="#F44336", zorder=0)
+
+ # Format axes
+ ax.set_ylabel("CPU Usage (%)", fontsize=13, fontweight="bold")
+ ax.set_xlabel("Time (seconds)", fontsize=12, fontweight="bold")
+ ax.set_xlim(0, max(relative_times) * 1.02)
+ ax.set_ylim(0, max(cpu_usage) * 1.1 if cpu_usage else 100)
+ ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC")
+ ax.set_axisbelow(True)
+
+ # Enhanced title
+ max_cpu = max(cpu_usage)
+ ax.set_title(
+ f"CPU Usage Over Time | Peak: {max_cpu:.1f}% | Average: {avg_cpu:.1f}%",
+ fontsize=14,
+ fontweight="bold",
+ color="#FF6B35",
+ pad=15,
+ )
+
+ # Compact legend
+ ax.legend(loc="upper right", fontsize=10, framealpha=0.9)
+
+ def _setup_disk_io_plot(self, ax, sync_data: Dict[str, List[float]]) -> None:
+ """Setup enhanced disk I/O plot showing phase-based analysis."""
+ if not self.profiler.operations or len(self.profiler.operations) < 2:
+ ax.text(
+ 0.5,
+ 0.5,
+ "No operation phases available",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="#666666",
+ )
+ ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold")
+ return
+
+ # Calculate I/O per phase
+ operations, read_totals, write_totals = self._calculate_io_per_phase(sync_data)
+
+ if not operations:
+ ax.text(
+ 0.5,
+ 0.5,
+ "No significant disk I/O detected",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="#666666",
+ )
+ ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold")
+ return
+
+ # Create enhanced bar chart
+ x_pos = np.arange(len(operations))
+ bar_width = 0.35
+
+ bars_read = ax.bar(
+ x_pos - bar_width / 2,
+ read_totals,
+ bar_width,
+ label="Read (MB)",
+ color="#2196F3",
+ alpha=0.8,
+ edgecolor="white",
+ linewidth=1.5,
+ )
+ bars_write = ax.bar(
+ x_pos + bar_width / 2,
+ write_totals,
+ bar_width,
+ label="Write (MB)",
+ color="#FF5722",
+ alpha=0.8,
+ edgecolor="white",
+ linewidth=1.5,
+ )
+
+ # Add value labels
+ self._add_bar_labels(ax, bars_read, bars_write, read_totals, write_totals)
+
+ # Format axes
+ ax.set_ylabel("Total I/O (MB)", fontsize=12, fontweight="bold")
+ ax.set_xlabel("Operation Phase", fontsize=11, fontweight="bold")
+ ax.set_xticks(x_pos)
+ ax.set_xticklabels(operations, rotation=45, ha="right", fontsize=10)
+
+        max_val = max(max(read_totals, default=0), max(write_totals, default=0))
+ ax.set_ylim(0, max_val * 1.25 if max_val > 0 else 1)
+ ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="y")
+ ax.set_title("Disk I/O per Operation Phase", fontsize=14, fontweight="bold", pad=15)
+ ax.legend(loc="upper right", fontsize=10, framealpha=0.9)
+
+ # Summary statistics
+ total_read = sum(read_totals)
+ total_write = sum(write_totals)
+ ax.text(
+ 0.02,
+ 0.98,
+ f"Total I/O: {total_read:.1f} MB read, {total_write:.1f} MB write",
+ transform=ax.transAxes,
+ fontsize=10,
+ va="top",
+ ha="left",
+ bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1),
+ )
+
+ def _setup_timing_plot(self, ax) -> None:
+ """Setup enhanced timing analysis plot."""
+ operations, durations, colors = self._get_timing_data()
+
+ if not operations:
+ ax.text(
+ 0.5,
+ 0.5,
+ "No timing data available",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ fontsize=12,
+ color="#666666",
+ )
+ ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold")
+ return
+
+ # Enhanced horizontal bar chart
+ y_pos = np.arange(len(operations))
+ bars = ax.barh(y_pos, durations, color=colors, alpha=0.8, edgecolor="white", linewidth=1.5, height=0.6)
+
+ # Add duration labels
+ self._add_duration_labels(ax, bars, durations)
+
+ # Format axes
+ ax.set_yticks(y_pos)
+ ax.set_yticklabels(operations, fontsize=11)
+ ax.set_xlabel("Duration (seconds)", fontsize=12, fontweight="bold")
+ ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold", pad=15)
+ ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="x")
+ ax.set_xlim(0, max(durations) * 1.2)
+
+ # Add total duration summary
+ total_duration = sum(durations)
+ ax.text(
+ 0.98,
+ 0.02,
+ f"Total: {total_duration:.1f}s",
+ transform=ax.transAxes,
+ fontsize=10,
+ va="bottom",
+ ha="right",
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1),
+ )
+
+ def _draw_segment_backgrounds(
+ self,
+ ax,
+ relative_times: List[float],
+ data_values: List[float],
+ start_time: datetime,
+ max_val: Optional[float] = None,
+ ) -> None:
+ """Draw colored background segments for each operation."""
+ if len(self.profiler.operations) < 2:
+ return
+
+ max_value = max_val or (max(data_values) * 1.1 if data_values else 100)
+
+ for i in range(len(self.profiler.operations) - 1):
+ op_start_time = (self.profiler.operations[i][0] - start_time).total_seconds()
+ op_end_time = (self.profiler.operations[i + 1][0] - start_time).total_seconds()
+ op_name = self.profiler.operations[i][1]
+
+ color = self.profiler.SEGMENT_COLORS.get(op_name, "#F0F0F0")
+
+ rect = patches.Rectangle(
+ (op_start_time, 0),
+ op_end_time - op_start_time,
+ max_value,
+ linewidth=0,
+ facecolor=color,
+ alpha=0.15,
+ zorder=0,
+ )
+ ax.add_patch(rect)
+
+ def _draw_segment_boundaries(self, ax, start_time: datetime, max_value: float) -> None:
+ """Draw vertical lines at segment boundaries."""
+ for i, (op_time, op_name) in enumerate(self.profiler.operations):
+ if i == 0:
+ continue
+
+ boundary_time = (op_time - start_time).total_seconds()
+ ax.axvline(x=boundary_time, color="#666666", linestyle="--", alpha=0.6, linewidth=2, zorder=3)
+
+ def _mark_peak_memory(self, ax, start_time: datetime, scale: float, unit: str) -> None:
+ """Mark peak memory with enhanced annotation."""
+ if not self.profiler.peak_rss_time:
+ return
+
+ peak_time_rel = (self.profiler.peak_rss_time - start_time).total_seconds()
+ peak_rss_scaled = self.profiler.peak_rss / scale
+
+ # Enhanced peak marker
+ ax.plot(
+ peak_time_rel,
+ peak_rss_scaled,
+ "o",
+ color="#E74C3C",
+ markersize=14,
+ markeredgecolor="white",
+ markeredgewidth=3,
+ zorder=10,
+ label="Peak Memory",
+ )
+
+ # Enhanced annotation
+ peak_text = f"Peak: {peak_rss_scaled:.1f} {unit}\nPhase: {self.profiler.peak_operation}"
+ ax.annotate(
+ peak_text,
+ xy=(peak_time_rel, peak_rss_scaled),
+ xytext=(25, 25),
+ textcoords="offset points",
+ bbox=dict(boxstyle="round,pad=0.6", facecolor="#E74C3C", alpha=0.95, edgecolor="white", linewidth=2),
+ arrowprops=dict(arrowstyle="->", color="#E74C3C", lw=2.5),
+ fontsize=11,
+ fontweight="bold",
+ color="white",
+ ha="left",
+ va="bottom",
+ zorder=15,
+ )
+
+ def _add_segment_legend(self, ax) -> None:
+ """Add enhanced segment legend with better styling."""
+ if not self.profiler.operations:
+ return
+
+ unique_ops = []
+ seen_ops = set()
+ for _, op_name in self.profiler.operations:
+ if op_name not in seen_ops and op_name not in ["Initialization", "Completion"]:
+ unique_ops.append(op_name)
+ seen_ops.add(op_name)
+
+ if not unique_ops:
+ return
+
+ legend_elements = []
+ for op_name in unique_ops:
+ color = self.profiler.SEGMENT_COLORS.get(op_name, "#666666")
+ duration = self.profiler.operation_durations.get(op_name, 0)
+
+ label = f"{op_name} ({duration:.1f}s)" if duration > 0 else op_name
+ legend_elements.append(patches.Patch(color=color, alpha=0.8, label=label))
+
+ legend = ax.legend(
+ handles=legend_elements,
+ loc="upper left",
+ bbox_to_anchor=(1.01, 1.0),
+ fontsize=11,
+ title="QEfficient Phases",
+ title_fontsize=12,
+ framealpha=0.95,
+ edgecolor="#2E86AB",
+ fancybox=True,
+ )
+ legend.get_frame().set_facecolor("#F8F9FA")
+
+ def _calculate_io_per_phase(self, sync_data: Dict[str, List[float]]) -> Tuple[List[str], List[float], List[float]]:
+ """Calculate I/O totals per operation phase."""
+ operations = []
+ read_totals = []
+ write_totals = []
+
+ valid_operations = [
+ (op_time, op_name)
+ for op_time, op_name in self.profiler.operations
+ if op_name not in ["Initialization", "Completion"]
+ ]
+
+ if not valid_operations:
+ return operations, read_totals, write_totals
+
+ relative_times = sync_data["timestamps"]
+ start_time = self.profiler.samples[0].timestamp
+
+ for i, (op_time, op_name) in enumerate(valid_operations):
+ op_start_time = (op_time - start_time).total_seconds()
+
+ if i + 1 < len(valid_operations):
+ op_end_time = (valid_operations[i + 1][0] - start_time).total_seconds()
+ else:
+ op_end_time = max(relative_times) if relative_times else op_start_time + 1
+
+ # Find data indices
+ start_idx = next((j for j, t in enumerate(relative_times) if t >= op_start_time), 0)
+ end_idx = next((j for j, t in enumerate(relative_times) if t >= op_end_time), len(relative_times) - 1)
+
+ if start_idx < len(sync_data["disk_read_bytes"]) and end_idx < len(sync_data["disk_read_bytes"]):
+ read_total = sync_data["disk_read_bytes"][end_idx] - sync_data["disk_read_bytes"][start_idx]
+ write_total = sync_data["disk_write_bytes"][end_idx] - sync_data["disk_write_bytes"][start_idx]
+
+ if read_total > 0.01 or write_total > 0.01:
+ operations.append(op_name)
+ read_totals.append(max(0, read_total))
+ write_totals.append(max(0, write_total))
+
+ return operations, read_totals, write_totals
+
+ def _get_timing_data(self) -> Tuple[List[str], List[float], List[str]]:
+ """Get timing data for operations."""
+ operations = []
+ durations = []
+ colors = []
+
+ for op_time, op_name in self.profiler.operations:
+ if op_name in ["Initialization", "Completion"]:
+ continue
+ duration = self.profiler.operation_durations.get(op_name, 0)
+ if duration > 0:
+ operations.append(op_name)
+ durations.append(duration)
+ colors.append(self.profiler.SEGMENT_COLORS.get(op_name, "#666666"))
+
+ return operations, durations, colors
+
+ def _add_bar_labels(self, ax, bars_read, bars_write, read_totals: List[float], write_totals: List[float]) -> None:
+ """Add value labels on bars."""
+        max_val = max(max(read_totals, default=0), max(write_totals, default=0))
+
+ for i, (read_bar, write_bar, read_val, write_val) in enumerate(
+ zip(bars_read, bars_write, read_totals, write_totals)
+ ):
+ if read_val > 0.01:
+ ax.text(
+ read_bar.get_x() + read_bar.get_width() / 2,
+ read_bar.get_height() + max_val * 0.02,
+ f"{read_val:.1f}",
+ ha="center",
+ va="bottom",
+ fontsize=9,
+ fontweight="bold",
+ color="#2196F3",
+ )
+
+ if write_val > 0.01:
+ ax.text(
+ write_bar.get_x() + write_bar.get_width() / 2,
+ write_bar.get_height() + max_val * 0.02,
+ f"{write_val:.1f}",
+ ha="center",
+ va="bottom",
+ fontsize=9,
+ fontweight="bold",
+ color="#FF5722",
+ )
+
+ def _add_duration_labels(self, ax, bars, durations: List[float]) -> None:
+ """Add duration labels on timing bars."""
+ max_duration = max(durations)
+
+ for i, (bar, duration) in enumerate(zip(bars, durations)):
+ width = bar.get_width()
+ minutes = int(duration // 60)
+ seconds = duration % 60
+
+ if minutes > 0:
+ duration_text = f"{minutes}m {seconds:.1f}s"
+ else:
+ duration_text = f"{seconds:.1f}s"
+
+ ax.text(
+ width + max_duration * 0.02,
+ bar.get_y() + bar.get_height() / 2,
+ duration_text,
+ ha="left",
+ va="center",
+ fontsize=10,
+ fontweight="bold",
+ )
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index e2e78105a..01cadaa5b 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -11,7 +11,7 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
-from QEfficient import QEFFAutoModelForCausalLM, export
+from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
@@ -160,11 +160,8 @@ def replicate_kv_heads(
# Export the modified model
q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False))
- export(
- model_name,
- q_model,
- tokenizer=tokenizer,
- onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
+ q_model.export(
+ export_dir=f"{model_base_name}-{new_kv_heads}kvheads",
full_batch_size=(full_batch_size if full_batch_size else None),
)
diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py
index d1b7a4653..f63b18f1a 100644
--- a/tests/base/test_export_memory_offload.py
+++ b/tests/base/test_export_memory_offload.py
@@ -27,7 +27,7 @@
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py
new file mode 100644
index 000000000..4e407c5aa
--- /dev/null
+++ b/tests/diffusers/diffusers_utils.py
@@ -0,0 +1,174 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+Common utilities for diffusion pipeline testing.
+Provides essential functions for MAD validation, image validation,
+hash verification, and other testing utilities.
+"""
+
+import os
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+
+class DiffusersTestUtils:
+ """Essential utilities for diffusion pipeline testing"""
+
+ @staticmethod
+ def validate_image_generation(
+ image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0
+ ) -> Dict[str, Any]:
+ """
+ Validate generated image properties.
+ Args:
+ image: Generated PIL Image
+ expected_size: Expected (width, height) tuple
+ min_variance: Minimum pixel variance to ensure image is not blank
+
+ Returns:
+ Dict containing validation results
+ Raises:
+ AssertionError: If image validation fails
+ """
+ # Basic image validation
+ assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}"
+ assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}"
+ assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}"
+
+ # Variance check (ensure image is not blank)
+ img_array = np.array(image)
+ image_variance = float(img_array.std())
+ assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})"
+
+ return {
+ "size": image.size,
+ "mode": image.mode,
+ "variance": image_variance,
+ "mean_pixel_value": float(img_array.mean()),
+ "min_pixel": int(img_array.min()),
+ "max_pixel": int(img_array.max()),
+ "valid": True,
+ }
+
+ @staticmethod
+ def check_file_exists(file_path: str, file_type: str = "file") -> bool:
+ """
+ Check if file exists and log result.
+ Args:
+ file_path: Path to check
+ file_type: Description of file type for logging
+ Returns:
+ bool: True if file exists
+ """
+ exists = os.path.exists(file_path)
+        print(f"file exists: {exists}; {file_type}: {file_path}")
+ return exists
+
+ @staticmethod
+ def print_test_header(title: str, config: Dict[str, Any]) -> None:
+ """
+ Print formatted test header with configuration details.
+
+ Args:
+ title: Test title
+ config: Test configuration dictionary
+ """
+ print(f"\n{'=' * 80}")
+ print(f"{title}")
+ print(f"{'=' * 80}")
+
+ if "model_setup" in config:
+ setup = config["model_setup"]
+ for k, v in setup.items():
+ print(f"{k} : {v}")
+
+ if "functional_testing" in config:
+ func = config["functional_testing"]
+ print(f"Test Prompt: {func.get('test_prompt', 'N/A')}")
+ print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}")
+ print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}")
+
+ print(f"{'=' * 80}")
+
+
+class MADValidator:
+    """Specialized class for MAD validation: always enabled, always reports, and always fails when tolerance is exceeded."""
+
+ def __init__(self, tolerances: Dict[str, float] = None):
+ """
+ Initialize MAD validator.
+ MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded.
+
+ Args:
+            tolerances: Dictionary of module_name -> tolerance mappings.
+                Defaults to an empty dict; a per-module tolerance of 1e-2 is used
+                for modules without an explicit entry.
+        """
+        self.tolerances = tolerances if tolerances is not None else {}
+ self.results = {}
+
+ def calculate_mad(
+ self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray]
+ ) -> float:
+ """
+ Calculate Max Absolute Deviation between two tensors.
+
+ Args:
+ tensor1: First tensor (PyTorch or NumPy)
+ tensor2: Second tensor (PyTorch or NumPy)
+
+ Returns:
+ float: Maximum absolute difference between tensors
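+
+        Example (illustrative):
+            >>> MADValidator({}).calculate_mad(np.array([1.0, 2.0]), np.array([1.5, 1.0]))
+            1.0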
+ """
+ if isinstance(tensor1, torch.Tensor):
+ tensor1 = tensor1.detach().numpy()
+ if isinstance(tensor2, torch.Tensor):
+ tensor2 = tensor2.detach().numpy()
+
+ return float(np.max(np.abs(tensor1 - tensor2)))
+
+ def validate_module_mad(
+ self,
+ pytorch_output: Union[torch.Tensor, np.ndarray],
+ qaic_output: Union[torch.Tensor, np.ndarray],
+ module_name: str,
+ step_info: str = "",
+ ) -> bool:
+ """
+ Validate MAD for a specific module.
+ Always validates, always reports, always fails if tolerance exceeded.
+
+ Args:
+ pytorch_output: PyTorch reference output
+ qaic_output: QAIC inference output
+ module_name: Name of the module
+ step_info: Additional step information for logging
+
+ Returns:
+ bool: True if validation passed
+
+ Raises:
+ AssertionError: If MAD exceeds tolerance
+ """
+ mad_value = self.calculate_mad(pytorch_output, qaic_output)
+
+ # Always report MAD value
+ step_str = f" {step_info}" if step_info else ""
+ print(f"{module_name.upper()} MAD{step_str}: {mad_value:.8f}")
+
+ # Always validate - fail if tolerance exceeded
+ tolerance = self.tolerances.get(module_name, 1e-2)
+ if mad_value > tolerance:
+ raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}")
+
+ # Store result
+ if module_name not in self.results:
+ self.results[module_name] = []
+ self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance})
+ return True
diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json
new file mode 100644
index 000000000..9f13daca0
--- /dev/null
+++ b/tests/diffusers/flux_test_config.json
@@ -0,0 +1,123 @@
+{
+ "model_setup": {
+ "height": 256,
+ "width": 256,
+ "num_transformer_layers": 2,
+ "num_single_layers": 2,
+ "use_onnx_subfunctions": false
+ },
+ "mad_validation": {
+ "tolerances": {
+ "clip_text_encoder": 0.1,
+ "t5_text_encoder": 5.5,
+ "transformer": 2.0,
+ "vae_decoder": 1.0
+ }
+ },
+ "pipeline_params": {
+ "test_prompt": "A cat holding a sign that says hello world",
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "validate_gen_img": true,
+ "min_image_variance": 1.0,
+ "custom_config_path": null
+ },
+ "validation_checks": {
+ "image_generation": true,
+ "onnx_export": true,
+ "compilation": true
+ },
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "aic-enable-depth-first": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+
+}
diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py
new file mode 100644
index 000000000..721850257
--- /dev/null
+++ b/tests/diffusers/test_flux.py
@@ -0,0 +1,448 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import pytest
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+from QEfficient import QEffFluxPipeline
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ModulePerf,
+ QEffPipelineOutput,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils._utils import load_json
+from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator
+
+# Test configuration for 256x256 resolution with 2 layers; MAD tolerances are defined in the JSON config.
+CONFIG_PATH = "tests/diffusers/flux_test_config.json"
+INITIAL_TEST_CONFIG = load_json(CONFIG_PATH)
+
+
+def flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height: int = 256,
+ width: int = 256,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ mad_tolerances: Dict[str, float] = None,
+):
+ """
+    Pipeline call function that replicates the flow of QEffFluxPipeline.__call__()
+    (pipeline_flux.py) while adding comprehensive MAD validation at each step.
+
+    It follows the same structure as QEffFluxPipeline.__call__() but inserts MAD
+    validation hooks throughout the process.
+ """
+ # Initialize MAD validator
+ mad_validator = MADValidator(tolerances=mad_tolerances)
+
+ device = "cpu"
+
+ # Step 1: Load configuration, compile models
+ pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(pipeline)
+
+ # Validate all inputs
+ pipeline.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Set pipeline attributes
+ pipeline._guidance_scale = guidance_scale
+ pipeline._interrupt = False
+ batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"]
+
+ # Step 3: Encode prompts with both text encoders
+ # Use pipeline's encode_prompt method
+ (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ # Deactivate text encoder qpc sessions
+ pipeline.text_encoder.qpc_session.deactivate()
+ pipeline.text_encoder_2.qpc_session.deactivate()
+
+ # MAD Validation for Text Encoders
+ print(" Performing MAD validation for text encoders...")
+    mad_validator.validate_module_mad(
+        clip_torch_pooled_prompt_embeds, clip_qaic_pooled_prompt_embeds, module_name="clip_text_encoder"
+    )
+    mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder")
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
+ pipeline._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = pipeline.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = pipeline.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ t5_qaic_prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Initialize transformer inference session
+ if pipeline.transformer.qpc_session is None:
+ pipeline.transformer.qpc_session = QAICInferenceSession(
+ str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids
+ )
+
+ # Calculate compressed latent dimension (cl) for transformer buffer allocation
+ from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension
+
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor)
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, pipeline.transformer.model.config.in_channels).astype(np.float32),
+ }
+ pipeline.transformer.qpc_session.set_buffers(output_buffer)
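+    # The random values above are only placeholders: set_buffers is assumed to
+    # pre-bind the host-side output array that each qpc_session.run() writes into.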
+
+ transformer_perf = []
+ pipeline.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if pipeline._interrupt:
+ continue
+
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds)
+
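+            # The AdaLN modulation vectors below are precomputed on the host and
+            # passed to the QPC as explicit inputs; the compiled transformer
+            # graph is assumed not to recompute them from temb on device.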
+ # Compute AdaLN embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.transformer_blocks)):
+ block = pipeline.transformer.model.transformer_blocks[block_idx]
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)):
+ block = pipeline.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = pipeline.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(),
+ "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # MAD Validation for Transformer - PyTorch reference inference
+ noise_pred_torch = pytorch_pipeline.transformer(
+ hidden_states=latents,
+ encoder_hidden_states=t5_torch_prompt_embeds,
+ pooled_projections=clip_torch_pooled_prompt_embeds,
+                timestep=timestep,
+ img_ids=latent_image_ids,
+ txt_ids=text_ids,
+ return_dict=False,
+ )[0]
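+            # Unlike the QPC, the PyTorch reference above computes its AdaLN
+            # modulation internally from the timestep and pooled projections,
+            # so it only takes the standard Flux transformer inputs.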
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.time()
+ outputs = pipeline.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.time()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Transformer MAD validation
+ mad_validator.validate_module_mad(
+ noise_pred_torch.detach().cpu().numpy(),
+ outputs["output"],
+ "transformer",
+ f"step {i} (t={t.item():.1f})",
+ )
+
+ # Update latents using scheduler
+ latents_dtype = latents.dtype
+ latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images
+ if output_type == "latent":
+ image = latents
+ vae_decode_perf = 0.0 # No VAE decoding for latent output
+ else:
+        # Unpack latents back to the spatial latent layout
+ latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor)
+
+ # Denormalize latents
+ latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor
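+        # The scaling and shift factors are read from the wrapped VAE config,
+        # mirroring the standard Flux latent denormalization before decoding.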
+ # Initialize VAE decoder inference session
+ if pipeline.vae_decode.qpc_session is None:
+ pipeline.vae_decode.qpc_session = QAICInferenceSession(
+ str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)}
+ pipeline.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # MAD Validation for VAE
+ # PyTorch reference inference
+ image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0]
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.time()
+ image = pipeline.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.time()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # VAE MAD validation
+ mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder")
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
+
+
+@pytest.fixture(scope="session")
+def flux_pipeline():
+ """Setup compiled Flux pipeline for testing"""
+ config = INITIAL_TEST_CONFIG["model_setup"]
+
+ pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ use_onnx_subfunctions=config["use_onnx_subfunctions"],
+ )
+
+ # Reduce to 2 layers for testing
+ original_blocks = pipeline.transformer.model.transformer_blocks
+ org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+
+ pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"]
+ pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"]
+ pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
+
+    # PyTorch reference pipeline
+ pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks
+ org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks
+ pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
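+    # Note: the PyTorch reference is trimmed to the same layer counts as the
+    # QEff pipeline so that the MAD comparison is apples-to-apples.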
+ return pipeline, pytorch_pipeline
+
+
+@pytest.mark.diffusion_models
+@pytest.mark.on_qaic
+def test_flux_pipeline(flux_pipeline):
+ """
+    Comprehensive Flux pipeline test that follows the same flow as QEffFluxPipeline.__call__():
+    - 256x256 resolution, 2 transformer layers
+ - MAD validation
+ - Functional image generation test
+ - Export/compilation checks
+ - Returns QEffPipelineOutput with performance metrics
+ """
+ pipeline, pytorch_pipeline = flux_pipeline
+ config = INITIAL_TEST_CONFIG
+
+ # Print test header
+ DiffusersTestUtils.print_test_header(
+ f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers",
+ config,
+ )
+
+ # Test parameters
+ test_prompt = config["pipeline_params"]["test_prompt"]
+ num_inference_steps = config["pipeline_params"]["num_inference_steps"]
+ guidance_scale = config["pipeline_params"]["guidance_scale"]
+ max_sequence_length = config["pipeline_params"]["max_sequence_length"]
+
+ # Generate with MAD validation
+ generator = torch.manual_seed(42)
+ start_time = time.time()
+
+ try:
+ # Run the pipeline with integrated MAD validation (follows exact pipeline flow)
+ result = flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height=config["model_setup"]["height"],
+ width=config["model_setup"]["width"],
+ prompt=test_prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ max_sequence_length=max_sequence_length,
+ custom_config_path=CONFIG_PATH,
+ generator=generator,
+ mad_tolerances=config["mad_validation"]["tolerances"],
+ parallel_compile=True,
+ return_dict=True,
+ )
+
+ execution_time = time.time() - start_time
+
+ # Validate image generation
+ if config["pipeline_params"]["validate_gen_img"]:
+ assert result is not None, "Pipeline returned None"
+ assert hasattr(result, "images"), "Result missing 'images' attribute"
+ assert len(result.images) > 0, "No images generated"
+
+ generated_image = result.images[0]
+ expected_size = (config["model_setup"]["height"], config["model_setup"]["width"])
+ # Validate image properties using utilities
+ image_validation = DiffusersTestUtils.validate_image_generation(
+ generated_image, expected_size, config["pipeline_params"]["min_image_variance"]
+ )
+
+ print("\n IMAGE VALIDATION PASSED")
+ print(f" - Size: {image_validation['size']}")
+ print(f" - Mode: {image_validation['mode']}")
+ print(f" - Variance: {image_validation['variance']:.2f}")
+ print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}")
+ file_path = "test_flux_256x256_2layers.png"
+ # Save test image
+ generated_image.save(file_path)
+
+ if os.path.exists(file_path):
+ print(f"Image saved successfully at: {file_path}")
+ else:
+ print("Image was not saved.")
+
+ if config["validation_checks"]["onnx_export"]:
+ # Check if ONNX files exist (basic check)
+ print("\n ONNX Export Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX")
+
+ if config["validation_checks"]["compilation"]:
+ # Check if QPC files exist (basic check)
+ print("\n Compilation Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "qpc_path") and module_obj.qpc_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC")
+
+ # Print test summary using utilities
+ print(f"\nTotal execution time: {execution_time:.4f}s")
+ except Exception as e:
+ print(f"\nTEST FAILED: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ # This allows running the test file directly for debugging
+ pytest.main([__file__, "-v", "-s", "-m", "flux"])
+# pytest tests/diffusers/test_flux.py -m flux -v -s --tb=short
diff --git a/tests/diffusers/test_wan.py b/tests/diffusers/test_wan.py
new file mode 100644
index 000000000..f11db826b
--- /dev/null
+++ b/tests/diffusers/test_wan.py
@@ -0,0 +1,535 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+"""
+Test for wan pipeline
+# TODO : 1. Add pytest for call method
+ 2. See if we reduce height and width
+ 3. Keep test for Sub fn as default once sdk supports
+"""
+
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import pytest
+import safetensors.torch
+import torch
+from diffusers import WanPipeline
+from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
+from diffusers.utils import export_to_video
+from huggingface_hub import hf_hub_download
+
+from QEfficient import QEffWanPipeline
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ModulePerf,
+ QEffPipelineOutput,
+ calculate_latent_dimensions_with_frames,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils import constants
+from QEfficient.utils._utils import load_json
+from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator
+
+# Test Configuration for 192x320 resolution with 1 layer
+CONFIG_PATH = "tests/diffusers/wan_test_config.json"
+INITIAL_TEST_CONFIG = load_json(CONFIG_PATH)
+
+
+def wan_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height: int = 192,
+ width: int = 320,
+ num_frames: int = 81,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 2,
+ guidance_scale: float = 1.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ use_onnx_subfunctions: bool = False,
+ parallel_compile: bool = True,
+ mad_tolerances: Dict[str, float] = None,
+):
+ """
+ Pipeline call function that replicates the exact flow of pipeline_wan.py.__call__()
+ while adding comprehensive MAD validation for transformer modules only.
+
+ This function follows the EXACT same structure as QEffWanPipeline.__call__()
+ but adds MAD validation hooks for transformer testing.
+ """
+ # Initialize MAD validator
+ mad_validator = MADValidator(tolerances=mad_tolerances)
+
+ device = "cpu"
+
+    # Step 1: Calculate latent dimensions, then compile (export + compile)
+ pipeline.cl, pipeline.latent_height, pipeline.latent_width, pipeline.latent_frames = (
+ calculate_latent_dimensions_with_frames(
+ height,
+ width,
+ num_frames,
+ pipeline.model.vae.config.scale_factor_spatial,
+ pipeline.model.vae.config.scale_factor_temporal,
+ pipeline.patch_height,
+ pipeline.patch_width,
+ )
+ )
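+    # calculate_latent_dimensions_with_frames is assumed to derive the latent
+    # frame count and spatial latent size from the VAE scale factors and patch
+    # sizes; cl (the number of latent patch tokens) sizes the transformer
+    # output buffer allocated below.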
+ pipeline.compile(
+ compile_config=custom_config_path,
+ parallel=parallel_compile,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ )
+
+ set_module_device_ids(pipeline)
+
+ # Step 2: Check inputs
+ pipeline.model.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if num_frames % pipeline.model.vae.config.scale_factor_temporal != 1:
+ num_frames = (
+ num_frames
+ // pipeline.model.vae.config.scale_factor_temporal
+ * pipeline.model.vae.config.scale_factor_temporal
+ + 1
+ )
+ num_frames = max(num_frames, 1)
+
+ if pipeline.model.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ pipeline._guidance_scale = guidance_scale
+ pipeline._guidance_scale_2 = guidance_scale_2
+ pipeline._attention_kwargs = attention_kwargs
+ pipeline._current_timestep = None
+ pipeline._interrupt = False
+
+ # Step 3: Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+    # Step 4: Encode input prompt (using the CPU text encoder for now)
+ prompt_embeds, negative_prompt_embeds = pipeline.model.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Get PyTorch reference prompt embeddings
+ # For standard WAN pipeline, CFG is determined by presence of negative prompts
+ do_classifier_free_guidance = negative_prompt is not None
+ pytorch_prompt_embeds, pytorch_negative_prompt_embeds = pytorch_pipeline.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = pipeline.transformer.model.transformer_high.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ pytorch_prompt_embeds = pytorch_prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ pytorch_negative_prompt_embeds = pytorch_negative_prompt_embeds.to(transformer_dtype)
+
+ # Step 5: Prepare timesteps
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = pipeline.scheduler.timesteps
+
+ # Step 6: Prepare latent variables
+ num_channels_latents = pipeline.transformer.model.config.in_channels
+ latents = pipeline.model.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
+ # Step 7: Setup transformer inference session
+ if pipeline.transformer.qpc_session is None:
+ pipeline.transformer.qpc_session = QAICInferenceSession(
+ str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids
+ )
+
+ output_buffer = {
+ "output": np.random.rand(
+ batch_size,
+ pipeline.cl,
+ constants.WAN_DIT_OUT_CHANNELS,
+ ).astype(np.int32),
+ }
+ pipeline.transformer.qpc_session.set_buffers(output_buffer)
+ transformer_perf = []
+
+ # Step 8: Denoising loop with transformer MAD validation
+ if pipeline.model.config.boundary_ratio is not None:
+ boundary_timestep = pipeline.model.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
+
+ with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if pipeline._interrupt:
+ continue
+
+ pipeline._current_timestep = t
+
+ # Determine which transformer to use (high or low noise)
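+            # Wan 2.2 T2V uses two expert transformers: a high-noise expert for
+            # early timesteps (t >= boundary) and a low-noise expert for late
+            # timesteps. The length of the "tsp" tensor passed below (1 vs 2) is
+            # assumed to select the matching QPC specialization ("model_type" in
+            # the test config).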
+ if boundary_timestep is None or t >= boundary_timestep:
+ # High-noise stage
+ current_model = pipeline.transformer.model.transformer_high
+ pytorch_current_model = pytorch_pipeline.transformer
+ model_type = torch.ones(1, dtype=torch.int64)
+ model_name = "transformer_high"
+ else:
+ # Low-noise stage
+ current_model = pipeline.transformer.model.transformer_low
+ pytorch_current_model = pytorch_pipeline.transformer_2
+ model_type = torch.ones(2, dtype=torch.int64)
+ model_name = "transformer_low"
+
+ latent_model_input = latents.to(transformer_dtype)
+ if pipeline.model.config.expand_timesteps:
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
+
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ p_t, p_h, p_w = current_model.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ # Prepare transformer inputs
+ rotary_emb = current_model.rope(latent_model_input)
+ rotary_emb = torch.cat(rotary_emb, dim=0)
+ ts_seq_len = None
+ timestep = timestep.flatten()
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = current_model.condition_embedder(
+ timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len
+ )
+
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ # Prepare inputs for QAIC inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": encoder_hidden_states.detach().numpy(),
+ "rotary_emb": rotary_emb.detach().numpy(),
+ "temb": temb.detach().numpy(),
+ "timestep_proj": timestep_proj.detach().numpy(),
+ "tsp": model_type.detach().numpy(),
+ }
+
+ # PyTorch reference inference (standard WAN transformer has different signature)
+ noise_pred_torch = pytorch_current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=pytorch_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # QAIC inference
+ with current_model.cache_context("cond"):
+ start_transformer_step_time = time.time()
+ outputs = pipeline.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.time()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
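+            # Un-patchify the QPC output: this mirrors what the PyTorch Wan
+            # transformer does internally after proj_out, converting patch tokens
+            # back to (batch, channels, frames, height, width).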
+ hidden_states = torch.tensor(outputs["output"])
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ # Transformer MAD validation
+ print(f" Performing MAD validation for {model_name} at step {i}...")
+ mad_validator.validate_module_mad(
+ noise_pred_torch.detach().cpu().numpy(),
+ noise_pred.detach().cpu().numpy(),
+ model_name,
+ f"step {i} (t={t.item():.1f})",
+ )
+
+ # Update latents using scheduler
+ latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 9: Decode latents to video (using CPU VAE for now)
+ if not output_type == "latent":
+ latents = latents.to(pipeline.vae_decode.dtype)
+ latents_mean = (
+ torch.tensor(pipeline.vae_decode.config.latents_mean)
+ .view(1, pipeline.vae_decode.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(pipeline.vae_decode.config.latents_std).view(
+ 1, pipeline.vae_decode.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+
+ video = pipeline.model.vae.decode(latents, return_dict=False)[0]
+
+ video = pipeline.model.video_processor.postprocess_video(video.detach())
+ else:
+ video = latents
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=video,
+ )
+
+
+@pytest.fixture(scope="session")
+def wan_pipeline():
+ """Setup compiled WAN pipeline for testing with LoRA adapters and 2 layers total"""
+ config = INITIAL_TEST_CONFIG["model_setup"]
+
+ def load_wan_lora(path: str):
+ return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path))
+
+ # Download and load LoRA adapters
+ high_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors",
+ )
+ low_noise_lora_path = hf_hub_download(
+ repo_id="lightx2v/Wan2.2-Lightning",
+ filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
+ )
+
+ # Load PyTorch reference pipeline
+ pytorch_pipeline = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+
+    # Load the LoRA adapters into the PyTorch reference transformers
+ pytorch_pipeline.transformer.load_lora_adapter(load_wan_lora(high_noise_lora_path), adapter_name="high_noise")
+ pytorch_pipeline.transformer.set_adapters(["high_noise"], weights=[1.0])
+
+ pytorch_pipeline.transformer_2.load_lora_adapter(load_wan_lora(low_noise_lora_path), adapter_name="low_noise")
+ pytorch_pipeline.transformer_2.set_adapters(["low_noise"], weights=[1.0])
+
+    # Reduce the PyTorch reference transformers to the test layer counts (2-layer model)
+ pytorch_pipeline.transformer.config.num_layers = config["num_transformer_layers_high"]
+ pytorch_pipeline.transformer_2.config.num_layers = config["num_transformer_layers_low"]
+ original_blocks = pytorch_pipeline.transformer.blocks
+ org_blocks = pytorch_pipeline.transformer_2.blocks
+ pytorch_pipeline.transformer.blocks = torch.nn.ModuleList(
+ [original_blocks[i] for i in range(0, pytorch_pipeline.transformer.config.num_layers)]
+ )
+ pytorch_pipeline.transformer_2.blocks = torch.nn.ModuleList(
+ [org_blocks[i] for i in range(0, pytorch_pipeline.transformer_2.config.num_layers)]
+ )
+
+ # Load QEff WAN pipeline
+ pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+
+ # Load LoRA adapters into transformers
+ pipeline.transformer.model.transformer_high.load_lora_adapter(
+ load_wan_lora(high_noise_lora_path), adapter_name="high_noise"
+ )
+ pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0])
+ pipeline.transformer.model.transformer_low.load_lora_adapter(
+ load_wan_lora(low_noise_lora_path), adapter_name="low_noise"
+ )
+ pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0])
+
+    # Reduce each transformer to 1 layer (1 high-noise + 1 low-noise) for testing
+ pipeline.transformer.model.transformer_high.config.num_layers = config["num_transformer_layers_high"]
+ pipeline.transformer.model.transformer_low.config.num_layers = config["num_transformer_layers_low"]
+
+ original_blocks_high = pipeline.transformer.model.transformer_high.blocks
+ original_blocks_low = pipeline.transformer.model.transformer_low.blocks
+
+ pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList(
+ [original_blocks_high[i] for i in range(0, config["num_transformer_layers_high"])]
+ )
+ pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList(
+ [original_blocks_low[i] for i in range(0, config["num_transformer_layers_low"])]
+ )
+
+ return pipeline, pytorch_pipeline
+
+
+@pytest.mark.diffusion_models
+@pytest.mark.on_qaic
+@pytest.mark.wan
+def test_wan_pipeline(wan_pipeline):
+ """
+ Comprehensive WAN pipeline test that focuses on transformer validation:
+    - 192x320 resolution, 2 transformer layers total (1 high + 1 low)
+ - MAD validation for transformer modules only
+ - Functional video generation test
+ - Export/compilation checks for transformer
+ - Returns QEffPipelineOutput with performance metrics
+ """
+ pipeline, pytorch_pipeline = wan_pipeline
+ config = INITIAL_TEST_CONFIG
+
+ # Print test header
+ DiffusersTestUtils.print_test_header(
+ f"WAN PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_frames']} Frames, 2 Layers Total",
+ config,
+ )
+
+ # Test parameters
+ test_prompt = config["pipeline_params"]["test_prompt"]
+ num_inference_steps = config["pipeline_params"]["num_inference_steps"]
+ guidance_scale = config["pipeline_params"]["guidance_scale"]
+ guidance_scale_2 = config["pipeline_params"]["guidance_scale_2"]
+ max_sequence_length = config["pipeline_params"]["max_sequence_length"]
+ num_frames = config["model_setup"]["num_frames"]
+
+ # Generate with MAD validation
+ generator = torch.manual_seed(42)
+ start_time = time.time()
+
+ try:
+ # Run the pipeline with integrated MAD validation (focuses on transformer)
+ result = wan_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height=config["model_setup"]["height"],
+ width=config["model_setup"]["width"],
+ num_frames=num_frames,
+ prompt=test_prompt,
+ guidance_scale=guidance_scale,
+ guidance_scale_2=guidance_scale_2,
+ num_inference_steps=num_inference_steps,
+ max_sequence_length=max_sequence_length,
+ custom_config_path=CONFIG_PATH,
+ generator=generator,
+ mad_tolerances=config["mad_validation"]["tolerances"],
+ parallel_compile=True,
+ return_dict=True,
+ )
+
+ execution_time = time.time() - start_time
+
+ # Validate video generation
+ if config["pipeline_params"]["validate_gen_video"]:
+ assert result is not None, "Pipeline returned None"
+ assert hasattr(result, "images"), "Result missing 'images' attribute"
+ assert len(result.images) > 0, "No video frames generated"
+
+ generated_video = result.images[0]
+ assert len(generated_video) == num_frames, f"Expected {num_frames} frames, got {len(generated_video)}"
+
+ # Validate first frame properties
+ first_frame = generated_video[0]
+ expected_size = (config["model_setup"]["width"], config["model_setup"]["height"])
+
+ # Convert numpy array to PIL Image if needed for validation
+ if isinstance(first_frame, np.ndarray):
+ from PIL import Image
+
+ if first_frame.dtype != np.uint8:
+ first_frame = (first_frame * 255).astype(np.uint8)
+ if len(first_frame.shape) == 3 and first_frame.shape[0] == 3:
+ first_frame = first_frame.transpose(1, 2, 0)
+ first_frame = Image.fromarray(first_frame)
+
+ # Validate video frame properties
+ frame_validation = DiffusersTestUtils.validate_image_generation(
+ first_frame, expected_size, config["pipeline_params"]["min_video_variance"]
+ )
+
+ print("\n VIDEO VALIDATION PASSED")
+ print(f" - Frame count: {len(generated_video)}")
+ print(f" - Frame size: {frame_validation['size']}")
+ print(f" - Frame mode: {frame_validation['mode']}")
+ print(f" - Frame variance: {frame_validation['variance']:.2f}")
+ print(f" - Mean pixel value: {frame_validation['mean_pixel_value']:.2f}")
+
+ # Save result as video
+ frames = result.images[0]
+ export_to_video(frames, "test_wan_output_t2v.mp4", fps=16)
+ print("\n VIDEO SAVED: test_wan_output_t2v.mp4")
+ print(result)
+
+ if config["validation_checks"]["onnx_export"]:
+ # Check if transformer ONNX file exists
+ print("\n ONNX Export Validation:")
+ if hasattr(pipeline.transformer, "onnx_path") and pipeline.transformer.onnx_path:
+ DiffusersTestUtils.check_file_exists(str(pipeline.transformer.onnx_path), "transformer ONNX")
+
+ if config["validation_checks"]["compilation"]:
+ # Check if transformer QPC file exists
+ print("\n Compilation Validation:")
+ if hasattr(pipeline.transformer, "qpc_path") and pipeline.transformer.qpc_path:
+ DiffusersTestUtils.check_file_exists(str(pipeline.transformer.qpc_path), "transformer QPC")
+
+ # Print test summary
+ print(f"\nTotal execution time: {execution_time:.4f}s")
+ print(" WAN TRANSFORMER TEST COMPLETED SUCCESSFULLY")
+
+ except Exception as e:
+ print(f"\nTEST FAILED: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ # This allows running the test file directly for debugging
+ pytest.main([__file__, "-v", "-s", "-m", "wan"])
+# pytest tests/diffusers/test_wan.py -m wan -v -s --tb=short
diff --git a/tests/diffusers/wan_test_config.json b/tests/diffusers/wan_test_config.json
new file mode 100644
index 000000000..1ed36294a
--- /dev/null
+++ b/tests/diffusers/wan_test_config.json
@@ -0,0 +1,63 @@
+{
+ "model_setup": {
+ "height": 192,
+ "width": 320,
+ "num_frames": 81,
+ "num_transformer_layers_high": 1,
+ "num_transformer_layers_low": 1,
+ "use_onnx_subfunctions": false
+ },
+ "mad_validation": {
+ "tolerances": {
+ "transformer_high": 0.3,
+ "transformer_low": 0.2
+ }
+ },
+ "pipeline_params": {
+ "test_prompt": "A cat walking in a garden",
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "guidance_scale_2": 1.0,
+ "max_sequence_length": 512,
+ "validate_gen_video": true,
+ "min_video_variance": 1.0
+ },
+ "validation_checks": {
+ "video_generation": true,
+ "onnx_export": true,
+ "compilation": true
+ },
+ "modules": {
+ "transformer": {
+ "specializations": [
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 1
+ },
+ {
+ "batch_size": "1",
+ "num_channels": "16",
+ "steps": "1",
+ "sequence_length": "512",
+ "model_type": 2
+ }
+ ],
+ "compilation": {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts_mos": 1
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py
index 00a4216b7..46b33c60b 100644
--- a/tests/peft/lora/test_lora_model.py
+++ b/tests/peft/lora/test_lora_model.py
@@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
# export
start = perf_counter()
- qeff_model.export(export_dir=tmp_path)
+ onnx_path = qeff_model.export(export_dir=tmp_path)
end = perf_counter()
export_time_0 = end - start
model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash)
@@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
assert export_time_1 < export_time_0
# test compile
- qeff_model.compile(prefill_seq_len=32, ctx_len=64)
+ qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64)
assert Path(qeff_model.qpc_path).is_dir()
assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py
index cc94467db..c3bb2f140 100644
--- a/tests/peft/test_peft_model.py
+++ b/tests/peft/test_peft_model.py
@@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path):
_, lora_model = create_peft_model(base_config, adapter_config)
qeff_model = QEffAutoPeftModelForCausalLM(lora_model)
- qeff_model.export(tmp_path)
+ onnx_path = qeff_model.export(tmp_path)
start = perf_counter()
- qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
+ qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_0 = end - start
@@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con
)
start = perf_counter()
- qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
+ qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
end = perf_counter()
compile_time_1 = end - start
assert compile_time_1 < 0.01 * compile_time_0
diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py
new file mode 100644
index 000000000..6358940df
--- /dev/null
+++ b/tests/transformers/models/test_disagg_mode.py
@@ -0,0 +1,192 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers
+
+model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32
+
+prompt2 = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+"""
+prompt1 = "Once upon a time"
+
+prompts = [prompt1, prompt2]
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_id", [model_id])
+@pytest.mark.parametrize("prompt", prompts)
+def test_disagg_mode_prefill(model_id, prompt):
+ # Run prefill
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ PREFILL_SEQ_LEN = 256
+ CTX_LEN = 256
+ inputs = tokenizer(prompt, return_tensors="np", padding=True)
+ padded_len = inputs["input_ids"].shape[1]
+ num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+    padded_len = num_chunks * PREFILL_SEQ_LEN  # Pad to a multiple of PREFILL_SEQ_LEN
+
+ replace_transformers_quantizers()
+ model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ config = model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()}
+ cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN)
+ ins = tokenizer(prompt, return_tensors="pt")
+ out = model(**ins, past_key_values=cache)
+
+ undo_transformers_quantizers()
+
+ qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ qeff_model.prefill(True)
+ config = qeff_model.model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
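+    # Zero-filled KV placeholders; the alternating cache length is assumed to
+    # mirror gpt-oss's hybrid attention (128-token sliding-window layers
+    # interleaved with full-attention layers).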
+ past_key_values = []
+ for i in range(config.num_hidden_layers):
+ cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+ inputs["past_key_values"] = past_key_values
+
+ qeff_out = qeff_model.model(**inputs)
+
+ # Check our pytorch implementation
+ assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4
+
+ prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ )
+
+ prefill_session = QAICInferenceSession(prefill_qpc_path)
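+    # Logits placeholder shaped (batch, 1, vocab_size); 201088 is assumed to be
+    # the gpt-oss vocabulary size.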
+ logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+ prefill_session.set_buffers({"logits": logits_out_placeholder})
+ inputs.pop("past_key_values")
+ inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+ st = time.time()
+ qpc_out = prefill_session.run(inputs)
+ print(f"time for prefill_run={time.time() - st} sec\n")
+ del prefill_session
+    # Check that the QAIC output is close to the QEff PyTorch output
+ assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2
+
+
+@pytest.mark.skip(reason="no way of currently testing this without the required SDK support")
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_id", [model_id])
+@pytest.mark.parametrize("prompt", prompts)
+def test_disagg_mode_prefill_chunked(model_id, prompt):
+ # Run prefill
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ PREFILL_SEQ_LEN = 128
+ CTX_LEN = 128 * 3
+ inputs = tokenizer(prompt, return_tensors="np", padding=True)
+ padded_len = inputs["input_ids"].shape[1]
+ num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
+    padded_len = num_chunks * PREFILL_SEQ_LEN  # Pad to a multiple of PREFILL_SEQ_LEN
+
+ replace_transformers_quantizers()
+ model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ config = model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()}
+ cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN)
+ ins = tokenizer(prompt, return_tensors="pt")
+ out = model(**ins, past_key_values=cache)
+
+ undo_transformers_quantizers()
+
+ qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+ qeff_model.prefill(True, enable_chunking=True)
+ config = qeff_model.model.config
+ inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+ inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+ inputs.pop("token_type_ids", None)
+ inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+ past_key_values = []
+ for i in range(config.num_hidden_layers):
+ cache_len = CTX_LEN
+ pad_shape = (1, 8, cache_len, 64)
+ past_key = torch.zeros((pad_shape), dtype=torch.float32)
+ past_value = torch.zeros((pad_shape), dtype=torch.float32)
+ pkv = (past_key, past_value)
+ past_key_values.append(pkv)
+ inputs["past_key_values"] = past_key_values
+
+ for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+
+ qeff_out = qeff_model.model(**chunk_inputs)
+ inputs["past_key_values"] = qeff_out["past_key_values"]
+
+ # Check our pytorch implementation
+ assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4
+
+ prefill_qpc_path = qeff_model.compile(
+ prefill_seq_len=PREFILL_SEQ_LEN,
+ ctx_len=CTX_LEN,
+ num_cores=16,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ num_devices=1,
+ mos=1,
+ aic_enable_depth_first=True,
+ num_speculative_tokens=None,
+ prefill_only=True,
+ enable_chunking=True,
+ )
+ prefill_session = QAICInferenceSession(prefill_qpc_path)
+ prefill_session.skip_buffers(
+ [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
+ )
+ logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
+ prefill_session.set_buffers({"logits": logits_out_placeholder})
+ inputs.pop("past_key_values")
+ inputs = {k: v.detach().numpy() for k, v in inputs.items()}
+ st = time.time()
+ for i in range(num_chunks):
+ chunk_inputs = inputs.copy()
+ chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
+ qpc_out = prefill_session.run(chunk_inputs)
+ print(f"time for prefill_run={time.time() - st} sec\n")
+ del prefill_session
+ # Check QAIC output isclose with QEFF pytorch output
+ assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2
diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py
index 9335e1d91..26cb6fda9 100644
--- a/tests/transformers/sampler/test_sampler.py
+++ b/tests/transformers/sampler/test_sampler.py
@@ -5,15 +5,18 @@
#
# -----------------------------------------------------------------------------
-from typing import List
+from typing import List, Optional, Tuple, Union
import numpy as np
import pytest
+from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
-from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import load_hf_tokenizer
from QEfficient.utils.constants import Constants
+from QEfficient.utils.test_utils import InternProcessor
+from tests.transformers.models.image_text_to_text.test_continuous_batching import set_num_layers
sampler_transform_configs = [
pytest.param(
@@ -24,6 +27,20 @@
20, # generation_len
2, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
greedy_sampling_configs = [
@@ -35,6 +52,20 @@
20, # generation_len
4, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
random_sampling_configs = [
@@ -46,23 +77,98 @@
20, # generation_len
4, # full_batch_size
1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 4,
+ ["Can you describe the image in detail."] * 4,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 4, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
+ ),
+]
+guided_decoding_configs = [
+ pytest.param(
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model
+ Constants.INPUT_STR * 4, # prompts
+ 32, # prefill_seq_len
+ 64, # ctx_len
+ 20, # generation_len
+ 4, # full_batch_size
+ 1, # spec_length
+ False, # is_vlm
+ ),
+ pytest.param(
+ "OpenGVLab/InternVL2_5-1B", # model
+ (
+ ["https://picsum.photos/id/237/536/354"] * 2,
+ ["Can you describe the image in detail."] * 2,
+ ), # images and prompts
+ 128, # prefill_seq_len
+ 4096, # ctx_len
+ 20, # generation_len
+ 2, # full_batch_size
+ None, # spec_length
+ True, # is_vlm
),
]
+def prepare_model_setup(
+ model: str, is_vlm: bool, num_hidden_layers: int, prompts: Union[List, Tuple], spec_length: Optional[int]
+):
+ additional_configs = {}
+ additional_params = {}
+ if is_vlm:
+ config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+ config = set_num_layers(config, n_layer=num_hidden_layers)
+ additional_configs["config"] = config
+ additional_configs["kv_offload"] = True
+ assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided."
+ additional_params["images"] = prompts[0]
+ prompts = prompts[1]
+
+ if "InternVL" in model:
+ additional_configs["trust_remote_code"] = True
+ model_hf = AutoModelForCausalLM.from_pretrained(
+ model,
+ config=config,
+ trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False)
+ additional_params["processor"] = InternProcessor(model_hf, tokenizer)
+ qeff_class = QEFFAutoModelForCausalLM
+ else:
+ additional_params["processor"] = AutoProcessor.from_pretrained(model)
+ qeff_class = QEFFAutoModelForImageTextToText
+ else:
+ if num_hidden_layers != -1:
+ additional_configs["num_hidden_layers"] = num_hidden_layers
+ spec_length = (spec_length or 1) - 1
+ qeff_class = QEFFAutoModelForCausalLM
+ return additional_configs, additional_params, prompts, spec_length, qeff_class
+
+
@pytest.mark.on_qaic
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
sampler_transform_configs,
)
def test_sampler_transform(
model: str,
- prompts: List[str],
+    prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the
@@ -70,48 +176,78 @@ def test_sampler_transform(
next tokens and/or probability distributions.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = 2
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=2,
qaic_config={
"include_sampler": True,
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_w_sampler_w_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 512,
+ "include_guided_decoding": True,
+ },
+ **additional_configs,
+ )
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=2,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
+ )
+ model_w_sampler_qpc_path = model_w_sampler.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
)
- model_w_sampler_qpc_path: str = model_w_sampler.compile(
+ model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding.compile(
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
- model_wo_sampler_qpc_path: str = model_wo_sampler.compile(
+ model_wo_sampler_qpc_path = model_wo_sampler.compile(
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
+ if is_vlm:
+ model_w_sampler_qpc_path = model_w_sampler_qpc_path[1]
+ model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1]
+ model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1]
# Init qaic session
model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path)
+ model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path)
model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path)
# Skip inputs/outputs buffers
@@ -119,6 +255,12 @@ def test_sampler_transform(
model_w_sampler_session.skip_buffers(
set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")])
)
+ model_w_sampler_w_guided_decoding_session.skip_buffers(
+ set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")])
+ )
+ model_w_sampler_w_guided_decoding_session.skip_buffers(
+ set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")])
+ )
model_wo_sampler_session.skip_buffers(
set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")])
)
@@ -132,47 +274,58 @@ def test_sampler_transform(
assert input_name in model_w_sampler_session.input_names, (
f"Sampler input {input_name} not found in QPC compiled with On Device Sampler"
)
+ assert input_name in model_w_sampler_w_guided_decoding_session.input_names, (
+ f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding"
+ )
assert input_name not in model_wo_sampler_session.input_names, (
f"Sampler input {input_name} found in QPC compiled without On Device Sampler"
)
+ assert "token_bitmasks" in model_w_sampler_w_guided_decoding_session.input_names, (
+ "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding"
+ )
@pytest.mark.on_qaic
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
greedy_sampling_configs,
)
def test_greedy_sampling(
model: str,
- prompts: List[str],
+    prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
- Test greedy sampling with QPC compiled with and without On Device Sampling.
+ Test greedy sampling with QPCs compiled with and without On Device Sampling.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = 4
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=4,
qaic_config={
"include_sampler": True,
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
- num_hidden_layers=4,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
)
model_w_sampler.compile(
prefill_seq_len=prefill_seq_len,
@@ -180,7 +333,7 @@ def test_greedy_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -190,7 +343,7 @@ def test_greedy_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -211,8 +364,9 @@ def test_greedy_sampling(
"top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
"top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32),
},
+ **additional_params,
)
model_wo_sampler_exec_info = model_wo_sampler.generate(
tokenizer=tokenizer,
@@ -221,6 +375,7 @@ def test_greedy_sampling(
include_sampler=False,
return_pdfs=False,
sampling_params=None,
+ **additional_params,
)
# Compare generated texts and ids
@@ -233,25 +388,29 @@ def test_greedy_sampling(
@pytest.mark.on_qaic
-@pytest.mark.skip
@pytest.mark.parametrize(
- "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
random_sampling_configs,
)
def test_random_sampling(
model: str,
- prompts: List[str],
+    prompts: Union[List[str], Tuple[List[str], List[str]]],
prefill_seq_len: int,
ctx_len: int,
generation_len: int,
full_batch_size: int,
- spec_length: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
):
"""
- Test random sampling with QPC compiled with and without On Device Sampling.
+ Test random sampling with QPCs compiled with and without On Device Sampling.
"""
# Export and compile QEfficient models
- model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ num_hidden_layers = -1
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
qaic_config={
@@ -259,14 +418,16 @@ def test_random_sampling(
"return_pdfs": False,
"max_top_k_ids": 512,
},
+ **additional_configs,
)
- model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained(
+ model_wo_sampler = qeff_class.from_pretrained(
model,
continuous_batching=True,
qaic_config={
"include_sampler": False,
"return_pdfs": False,
},
+ **additional_configs,
)
model_w_sampler.compile(
prefill_seq_len=prefill_seq_len,
@@ -274,7 +435,7 @@ def test_random_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
@@ -284,13 +445,14 @@ def test_random_sampling(
full_batch_size=full_batch_size,
num_devices=1,
num_cores=16,
- num_speculative_tokens=spec_length - 1,
+ num_speculative_tokens=spec_length,
mxint8_kv_cache=True,
mxfp6_matmul=True,
)
# Generate texts from prompts
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
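+    # Seed NumPy so the uniform draw used for random_numbers below is reproducible and the
+    # golden texts/ids stay stable across runs.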
+ np.random.seed(0)
model_w_sampler_exec_info = model_w_sampler.generate(
tokenizer=tokenizer,
prompts=prompts,
@@ -301,12 +463,15 @@ def test_random_sampling(
"repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
# "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
+ "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
"top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
"min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
- "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype(
+ np.float32
+ ),
},
+ **additional_params,
)
model_wo_sampler_exec_info = model_wo_sampler.generate(
tokenizer=tokenizer,
@@ -315,63 +480,120 @@ def test_random_sampling(
include_sampler=False,
return_pdfs=False,
sampling_params=None,
+ **additional_params,
)
# Compare generated texts
- golden_texts = {
- "w_sampler": "Raymond and my favorite color, alongside reds or purples (I canβt have them both",
- "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
- }
- golden_ids = {
- "w_sampler": [
- [
- 21380,
- 322,
- 590,
- 25448,
- 2927,
- 29892,
- 19963,
- 2654,
- 29879,
- 470,
- 3708,
- 2701,
- 313,
- 29902,
- 508,
- 30010,
- 29873,
- 505,
- 963,
- 1716,
- ]
- ],
- "wo_sampler": [
- [
- 2259,
- 7075,
- 322,
- 306,
- 626,
- 263,
- 7047,
- 22055,
- 29889,
- 306,
- 505,
- 1063,
- 1985,
- 297,
- 278,
- 13661,
- 363,
- 278,
- 4940,
- 29871,
- ]
- ],
- }
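+    # Golden references are keyed by model so the same test covers both the text-only and VLM configs.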
+ if model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+ golden_texts = {
+ "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over",
+ "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
+ }
+ golden_ids = {
+ "w_sampler": [
+ [
+ 319,
+ 3615,
+ 322,
+ 306,
+ 626,
+ 263,
+ 3005,
+ 295,
+ 749,
+ 9227,
+ 1058,
+ 12355,
+ 267,
+ 304,
+ 26987,
+ 278,
+ 3186,
+ 29889,
+ 2973,
+ 975,
+ ]
+ ],
+ "wo_sampler": [
+ [
+ 2259,
+ 7075,
+ 322,
+ 306,
+ 626,
+ 263,
+ 7047,
+ 22055,
+ 29889,
+ 306,
+ 505,
+ 1063,
+ 1985,
+ 297,
+ 278,
+ 13661,
+ 363,
+ 278,
+ 4940,
+ 29871,
+ ]
+ ],
+ }
+ elif model == "OpenGVLab/InternVL2_5-1B":
+ golden_texts = {
+ "w_sampler": "The description of this picture would be as follows:\n\nAn adorable black puppy is sitting on a wooden surface",
+ "wo_sampler": "The image features a black puppy sitting on a wooden surface. The puppy has a shiny, glossy coat",
+ }
+ golden_ids = {
+ "w_sampler": [
+ [
+ 785,
+ 4008,
+ 315,
+ 419,
+ 6802,
+ 1035,
+ 387,
+ 438,
+ 11017,
+ 1447,
+ 2082,
+ 40608,
+ 3691,
+ 41189,
+ 374,
+ 11699,
+ 389,
+ 264,
+ 22360,
+ 7329,
+ ]
+ ],
+ "wo_sampler": [
+ [
+ 785,
+ 2168,
+ 4419,
+ 264,
+ 3691,
+ 41189,
+ 11699,
+ 389,
+ 264,
+ 22360,
+ 7329,
+ 13,
+ 576,
+ 41189,
+ 702,
+ 264,
+ 41199,
+ 11,
+ 73056,
+ 22875,
+ ]
+ ],
+ }
for i in range(full_batch_size):
assert (
tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"]
@@ -385,3 +607,118 @@ def test_random_sampling(
assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), (
"Without sampler generated ids do not match"
)
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize(
+ "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm",
+ guided_decoding_configs,
+)
+def test_guided_decoding(
+ model: str,
+ prompts: Union[List[str], tuple[List[str], List[str]]],
+ prefill_seq_len: int,
+ ctx_len: int,
+ generation_len: int,
+ full_batch_size: int,
+ spec_length: Optional[int],
+ is_vlm: bool,
+):
+ """
+ Test QPCs compiled with and without guided decoding.
+ """
+ # Export and compile QEfficient models
+ num_hidden_layers = 2
+ additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup(
+ model, is_vlm, num_hidden_layers, prompts, spec_length
+ )
+ model_w_sampler_w_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 1024,
+ "include_guided_decoding": True,
+ },
+ **additional_configs,
+ )
+ model_w_sampler_wo_guided_decoding = qeff_class.from_pretrained(
+ model,
+ continuous_batching=True,
+ qaic_config={
+ "include_sampler": True,
+ "return_pdfs": False,
+ "max_top_k_ids": 1024,
+ },
+ **additional_configs,
+ )
+ model_w_sampler_w_guided_decoding.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
+ )
+ model_w_sampler_wo_guided_decoding.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ full_batch_size=full_batch_size,
+ num_devices=1,
+ num_cores=16,
+ num_speculative_tokens=spec_length,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
+ )
+
+ # Generate texts from prompts
+ tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
+ np.random.seed(0)
+ sampling_params = {
+ "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
+ "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+ "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32),
+ }
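+    # The token bitmask spans the full vocabulary; for VLMs the text vocab size lives on the
+    # language_model sub-config.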
+ if is_vlm:
+ vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size
+ else:
+ vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size
+ model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate(
+ tokenizer=tokenizer,
+ prompts=prompts,
+ generation_len=generation_len,
+ include_sampler=True,
+ return_pdfs=False,
+ include_guided_decoding=True,
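+        # Extend the shared sampling params with a random vocabulary-wide token bitmask to
+        # exercise the guided-decoding path.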
+ sampling_params={
+ **sampling_params,
+ **{
+ "token_bitmasks": np.tile(
+ np.random.choice([True, False], size=(vocab_size,)),
+ (full_batch_size, 1),
+ )
+ },
+ },
+ **additional_params,
+ )
+ model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate(
+ tokenizer=tokenizer,
+ prompts=prompts,
+ generation_len=generation_len,
+ include_sampler=True,
+ return_pdfs=False,
+ sampling_params=sampling_params,
+ **additional_params,
+ )
+ assert (
+ model_w_sampler_w_guided_decoding_exec_info.generated_ids
+ != model_w_sampler_wo_guided_decoding_exec_info.generated_ids
+ ).any(), "Sampler outputs with and without guided decoding should not match"
diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py
index 0810ac6ba..72477d56a 100644
--- a/tests/transformers/test_causal_lm.py
+++ b/tests/transformers/test_causal_lm.py
@@ -14,10 +14,11 @@
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import constants, get_padding_shape_from_config
from QEfficient.utils.hash_utils import hash_dict_params
-configs = [
+test_configs = [
# name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
("gpt2", 256, 2, 4, 128, 512, 127, {}),
("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
@@ -36,30 +37,43 @@
("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}),
]
-configs = [
- AutoConfig.for_model(
- model_name,
- max_position_embeddings=max_position_embeddings,
- num_hidden_layers=num_hidden_layers,
- num_attention_heads=num_attention_heads,
- hidden_size=hidden_size,
- intermediate_size=intermediate_size,
- vocab_size=vocab_size,
- **additional_params,
- )
- for (
- model_name,
- max_position_embeddings,
- num_hidden_layers,
- num_attention_heads,
- hidden_size,
- intermediate_size,
- vocab_size,
- additional_params,
- ) in configs
+test_prefill_only_specialized_models_configs = [
+ ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}),
]
+
+
+def get_auto_config_from_test_config(configs):
+ auto_configs = [
+ AutoConfig.for_model(
+ model_name,
+ max_position_embeddings=max_position_embeddings,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ vocab_size=vocab_size,
+ **additional_params,
+ )
+ for (
+ model_name,
+ max_position_embeddings,
+ num_hidden_layers,
+ num_attention_heads,
+ hidden_size,
+ intermediate_size,
+ vocab_size,
+ additional_params,
+ ) in configs
+ ]
+ return auto_configs
+
+
+configs = get_auto_config_from_test_config(test_configs)
config_ids = [x.model_type for x in configs]
+prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs)
+prefill_only_config_ids = [x.model_type for x in prefill_only_configs]
+
model_kwargs = {"attn_implementation": "eager"}
@@ -144,20 +158,21 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path):
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("config", configs, ids=config_ids)
-def test_causal_lm_hash_creation(config, cb, tmp_path):
+def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path):
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
qeff_model = QEFFAutoModelForCausalLM(model, cb)
- qeff_model.export(tmp_path)
+ qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc)
hash_params = {}
hash_params["config"] = qeff_model.model.config.to_diff_dict()
hash_params["peft_config"] = None
hash_params["applied_transform_names"] = qeff_model._transform_names()
hash_params["qeff_auto_class"] = qeff_model.__class__.__name__
+ hash_params["max_seq_len_cached"] = None
hash_params["qaic_config"] = None
# Create parameters separately for hash creation
-
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
@@ -190,12 +205,12 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):
)
output_names = []
output_names.append("logits")
-
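+    # Subfunction export renames retained-state outputs to *_InternalRetainedState.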
+ onnx_out_name_suffix = "InternalRetainedState" if subfunc else "RetainedState"
for i in range(qeff_model.num_layers):
pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size"
for kv in ["key", "value"]:
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
- output_names.append(f"past_{kv}.{i}_RetainedState")
+ output_names.append(f"past_{kv}.{i}_{onnx_out_name_suffix}")
if qeff_model.continuous_batching:
dynamic_axes["batch_index"] = {0: "batch_size"}
@@ -204,14 +219,35 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):
export_params["output_names"] = output_names
export_params["dynamic_axes"] = dynamic_axes
hash_params["export_params"] = export_params
+ if subfunc:
+ hash_params["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
+
manual_hash = hash_dict_params(hash_params)
assert manual_hash == qeff_model.export_hash
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids)
+def test_prefill_only_specialized_models(config, cb, tmp_path):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ qeff_model = QEFFAutoModelForCausalLM(model, cb)
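+    # prefill_only export is expected to fail for continuous batching and, without an explicit
+    # prefill_seq_len, for the regular path as well.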
+ if cb:
+ with pytest.raises(NotImplementedError):
+ qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False)
+ else:
+ with pytest.raises(ValueError):
+ qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False)
+ qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False)
+ first_export_hash = qeff_model.export_hash
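+        # A regular (non-prefill-only) export of the same model must produce a different export hash.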
+ qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False)
+ second_export_hash = qeff_model.export_hash
+ assert first_export_hash != second_export_hash
+
+
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py
index 59281b73b..bc53cb539 100644
--- a/tests/transformers/test_speech_seq2seq.py
+++ b/tests/transformers/test_speech_seq2seq.py
@@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path):
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py
index 36cfc0ce5..47e49cf2c 100644
--- a/tests/transformers/test_subfunction.py
+++ b/tests/transformers/test_subfunction.py
@@ -4,7 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
+from collections import Counter
+import onnx
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -44,24 +46,74 @@
config_ids = [x.model_type for x in configs]
+def has_gpt2block_function(onnx_path):
+ """Check if ONNX model contains QEffGPT2Block function definition."""
+ model = onnx.load(onnx_path, load_external_data=False)
+ function_names = [f.name for f in model.functions]
+ gpt2block_functions = [name for name in function_names if "QEffGPT2Block" in name]
+ return len(gpt2block_functions) > 0, gpt2block_functions
+
+
+def get_gpt2block_call_count(onnx_path):
+ """Get count of QEffGPT2Block function calls in the ONNX model graph."""
+ model = onnx.load(onnx_path, load_external_data=False)
+ calls = Counter([n.op_type for n in model.graph.node])
+ gpt2block_calls = {k: v for k, v in calls.items() if "QEffGPT2Block" in k}
+ return gpt2block_calls
+
+
+@pytest.mark.on_qaic
@pytest.mark.parametrize("config", configs, ids=config_ids)
def test_subfunction_vs_nonsubfunction(config, tmp_path):
tokenizer = AutoTokenizer.from_pretrained(config.model_type)
model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False)
- # model_0_0 = QEFFAutoModelForCausalLM.from_pretrained(config.model_type)
+ # Export with subfunctions enabled
with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False)
- hash_0_0 = model_0_0.export_hash
+ # Export without subfunctions
without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False)
- hash_0_1 = model_0_0.export_hash
- assert hash_0_0 != hash_0_1
+ # Verify that the model with subfunctions has QEffGPT2Block function definition
+ has_gpt2block, gpt2block_names = has_gpt2block_function(with_sub_func_onnx)
+ assert has_gpt2block, (
+ "Model exported with use_onnx_subfunctions=True should contain QEffGPT2Block function definition"
+ )
+ print(f"\nGpt2Block functions found: {gpt2block_names}")
+
+ # Verify that the model without subfunctions has no QEffGPT2Block function definition
+ has_gpt2block_without, _ = has_gpt2block_function(without_sub_func_onnx)
+ assert not has_gpt2block_without, (
+ "Model exported with use_onnx_subfunctions=False should not contain QEffGPT2Block function definition"
+ )
+
+ # Get QEffGPT2Block call counts
+ gpt2block_calls_with_sub = get_gpt2block_call_count(with_sub_func_onnx)
+ gpt2block_calls_without_sub = get_gpt2block_call_count(without_sub_func_onnx)
+ print(f"\nGpt2Block call counts with subfunctions: {gpt2block_calls_with_sub}")
+ print(f"QEffGPT2Block call counts without subfunctions: {gpt2block_calls_without_sub}")
+
+ # Verify that QEffGPT2Block function calls exist in the subfunction model
+ assert len(gpt2block_calls_with_sub) > 0, (
+ "Expected to find QEffGPT2Block function calls in graph when use_onnx_subfunctions=True"
+ )
+
+ # Verify that QEffGPT2Block function calls do NOT exist in the non-subfunction model
+ assert len(gpt2block_calls_without_sub) == 0, (
+ "Expected NO QEffGPT2Block function calls in graph when use_onnx_subfunctions=False"
+ )
+
+ # Compile and test generation to ensure functional equivalence
compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+
model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params)
generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params)
generation_01 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer)
- assert generation_00.generated_texts == generation_01.generated_texts
+
+ # Verify that both models produce the same output
+ assert generation_00.generated_texts == generation_01.generated_texts, (
+ "Models with and without subfunctions should produce identical outputs"
+ )
diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py
index fefa73973..b7a5495c6 100644
--- a/tests/utils/test_hash_utils.py
+++ b/tests/utils/test_hash_utils.py
@@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value):
def test_json_serializable():
# Test with a set
- assert json_serializable({1, 2, 3}) == [1, 2, 3]
+ assert json_serializable({1, 2, 3}) == ["1", "2", "3"]
# Test with an unsupported type
with pytest.raises(TypeError):
json_serializable({1, 2, 3, {4, 5}})