diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index be4b86321..013b4b537 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -48,6 +48,12 @@ def check_qaic_sdk():
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
+
+ # Imports for the diffusers
+ from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline
+ from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import (
+ QEFFStableDiffusion3Pipeline,
+ )
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -67,6 +73,8 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
+ "QEFFStableDiffusionPipeline",
+ "QEFFStableDiffusion3Pipeline",
]
else:
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index d9d6823ae..ad4740af0 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -22,7 +22,7 @@
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
+from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json
from QEfficient.utils.cache import QEFF_HOME, to_hashable
logger = logging.getLogger(__name__)
@@ -172,6 +172,7 @@ def _export(
try:
export_kwargs = {} if export_kwargs is None else export_kwargs
+ print("Export_kwargs:", export_kwargs)
torch.onnx.export(
self.model,
(example_inputs,),
@@ -179,7 +180,8 @@ def _export(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
- opset_version=constants.ONNX_EXPORT_OPSET,
+ opset_version=17,
+ # verbose=True,
**export_kwargs,
)
logger.info("Pytorch export successful")
@@ -187,6 +189,7 @@ def _export(
model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
+ "temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
@@ -213,7 +216,7 @@ def _export(
self.onnx_path = onnx_path
return onnx_path
- @dump_qconfig
+ # @dump_qconfig
def _compile(
self,
onnx_path: Optional[str] = None,
@@ -352,6 +355,7 @@ def _compile(
command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
+ print(command)
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 61b5c00f6..b579aa359 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -8,6 +8,8 @@
from typing import Optional, Tuple
import numpy as np
+import onnx
+import onnxslim
from onnx import ModelProto, external_data_helper, numpy_helper
@@ -99,3 +101,34 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed
+
+
+class OnnxSlimTransform(OnnxTransform):
+    """
+    Applies onnx-slim transformations on the given ONNX graph.
+    """
+
+    @classmethod
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        onnx_base_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        Slim ``model`` with onnx-slim and save the slimmed graph to ``temp_onnx_path``.
+
+        :param model: ONNX graph to slim.
+        :param onnx_base_dir: Accepted but not used by this transform.
+        :param enable_onnx_slim_transform: Keyword flag, defaults to True. If False, the
+            model is returned untouched and nothing is written.
+        :param temp_onnx_path: Path to save the slimmed ONNX model. Required when the
+            transform is enabled.
+        :returns: ``(model, transformed)`` where ``transformed`` is True only when slimming ran.
+        :raises RuntimeError: If the transform is enabled and ``temp_onnx_path`` is missing.
+        """
+        # Honour the documented switch; default True keeps the previous always-on behaviour.
+        if not kwargs.get("enable_onnx_slim_transform", True):
+            return model, False
+
+        temp_onnx_path = kwargs.get("temp_onnx_path")
+        if not temp_onnx_path:
+            raise RuntimeError("temp_onnx_path is required for onnx-slim transform.")
+
+        slimmed_model = onnxslim.slim(model)
+        onnx.save(slimmed_model, temp_onnx_path)
+        return slimmed_model, True
diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md
new file mode 100644
index 000000000..088108461
--- /dev/null
+++ b/QEfficient/diffusers/README.md
@@ -0,0 +1,110 @@
+
+
+
+
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+
+
+
+### 🎨 **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+

+
+**Generated with**: `stabilityai/stable-diffusion-3.5-large` • `"A girl laughing"` • 28 steps • 2.0 guidance scale • ⚡
+
+
+
+
+
+
+
+[](https://github.com/huggingface/diffusers)
+
+
+---
+
+## ✨ Overview
+
+QEfficient Diffusers brings state-of-the-art diffusion models for text-to-image generation to Qualcomm Cloud AI 100 hardware. Built on top of the popular HuggingFace Diffusers library, its optimized pipelines provide seamless inference on these devices.
+
+## 🛠️ Installation
+
+### Prerequisites
+
+Ensure you have Python 3.8+ and the required dependencies:
+
+```bash
+# Create Python virtual environment (Recommended Python 3.10)
+sudo apt install python3.10-venv
+python3.10 -m venv qeff_env
+source qeff_env/bin/activate
+pip install -U pip
+```
+
+### Install QEfficient
+
+```bash
+# Install from GitHub (includes diffusers support)
+pip install git+https://github.com/quic/efficient-transformers
+
+# Or build from source
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install build wheel
+python -m build --wheel --outdir dist
+pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
+```
+
+### Install Diffusers Dependencies
+
+```bash
+# Install diffusers optional dependencies
+pip install "QEfficient[diffusers]"
+```
+
+---
+
+## 🎯 Supported Models
+
+### Stable Diffusion 3.x Series
+- ✅ [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)
+- ✅ [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo)
+---
+
+
+## 📚 Examples
+
+Check out our examples in the [`examples/diffusers/`](../../examples/diffusers/) directory.
+
+---
+
+## 🤝 Contributing
+
+We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.
+
+### Development Setup
+
+```bash
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install -e ".[diffusers,test]"
+```
+
+---
+
+## 🙏 Acknowledgments
+
+- **HuggingFace Diffusers**: For the excellent foundation library
+- **Stability AI**: For the amazing Stable Diffusion models
+---
+
+## 📞 Support
+
+- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
+- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)
+
+---
+
diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/QEfficient/diffusers/models/attention.py b/QEfficient/diffusers/models/attention.py
new file mode 100644
index 000000000..3c9cc268d
--- /dev/null
+++ b/QEfficient/diffusers/models/attention.py
@@ -0,0 +1,75 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import torch
+from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
+
+
+class QEffJointTransformerBlock(JointTransformerBlock):
+    """
+    ``JointTransformerBlock`` subclass that overrides ``forward`` and passes an explicit
+    ``block_size`` keyword to the feed-forward modules when chunked feed-forward is off.
+    """
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+    ):
+        """
+        Run joint attention and feed-forward over the sample and context streams.
+
+        Args:
+            hidden_states: Sample-stream input.
+            encoder_hidden_states: Context-stream input.
+            temb: Embedding passed as ``emb`` to the ``norm1`` / ``norm1_context`` layers.
+
+        Returns:
+            Tuple ``(encoder_hidden_states, hidden_states)``. ``encoder_hidden_states`` is
+            ``None`` when ``self.context_pre_only`` is set.
+        """
+        # norm1 returns two extra values (normed input and gate for attn2) in dual-attention mode.
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        # With context_pre_only the context norm yields only the normalized tensor (no gates/shifts).
+        if self.context_pre_only:
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+        else:
+            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                encoder_hidden_states, emb=temb
+            )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        # Second, sample-only attention branch (no encoder_hidden_states is passed to attn2).
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            # NOTE(review): `block_size=4096` is forwarded to `self.ff`; how the feed-forward
+            # module consumes it is not visible here - confirm it is not silently ignored.
+            ff_output = self.ff(norm_hidden_states, block_size=4096)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                # NOTE(review): 333 matches the example context length used elsewhere in this
+                # change (presumably the encoder sequence length) - confirm before changing.
+                context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
diff --git a/QEfficient/diffusers/models/attention_processor.py b/QEfficient/diffusers/models/attention_processor.py
new file mode 100644
index 000000000..01954e55e
--- /dev/null
+++ b/QEfficient/diffusers/models/attention_processor.py
@@ -0,0 +1,155 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from typing import Optional
+
+import torch
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+
+
+class QEffAttention(Attention):
+    """``Attention`` subclass that installs ``QEffJointAttnProcessor2_0`` and computes scores with ``baddbmm``."""
+
+    def __qeff_init__(self):
+        """Install the QEff joint attention processor with a query block size of 64."""
+        self.processor = QEffJointAttnProcessor2_0()
+        self.processor.query_block_size = 64
+
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Return ``softmax(attention_mask + scale * query @ key)`` cast back to the input dtype.
+
+        ``key`` is multiplied as given (no transpose happens here), so its last dimension
+        becomes the last dimension of the scores.
+        """
+        out_dtype = query.dtype
+        if self.upcast_attention:
+            query, key = query.float(), key.float()
+
+        if attention_mask is not None:
+            bias, beta = attention_mask, 1
+        else:
+            # With beta=0 the contents of this buffer are ignored; only its shape matters.
+            bias = torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device)
+            beta = 0
+
+        scores = torch.baddbmm(bias, query, key, beta=beta, alpha=self.scale)
+        del bias
+
+        if self.upcast_softmax:
+            scores = scores.float()
+
+        return scores.softmax(dim=-1).to(out_dtype)
+
+
+class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
+    """
+    Joint attention processor that computes attention through ``attn.get_attention_scores``
+    and ``torch.bmm``, processing the query in blocks of ``query_block_size`` rows.
+    """
+
+    # Class-level default so a processor constructed without ``QEffAttention.__qeff_init__``
+    # (which assigns the same value per instance) does not fail with AttributeError.
+    query_block_size: int = 64
+
+    def __call__(
+        self,
+        attn: QEffAttention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        """
+        Compute joint attention over the sample tokens and, if given, the context tokens.
+
+        Args:
+            attn: Attention module providing the projections, norms and ``get_attention_scores``.
+            hidden_states: Sample-stream input.
+            encoder_hidden_states: Optional context-stream input; its projections are
+                concatenated after the sample tokens along the sequence dimension.
+            attention_mask: Passed unchanged to ``attn.get_attention_scores``.
+
+        Returns:
+            ``hidden_states`` alone, or ``(hidden_states, encoder_hidden_states)`` when
+            ``encoder_hidden_states`` was provided.
+        """
+        residual = hidden_states
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        def split_heads(tensor):
+            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
+            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        query, key, value = split_heads(query), split_heads(key), split_heads(value)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            context_query = split_heads(attn.add_q_proj(encoder_hidden_states))
+            context_key = split_heads(attn.add_k_proj(encoder_hidden_states))
+            context_value = split_heads(attn.add_v_proj(encoder_hidden_states))
+
+            if attn.norm_added_q is not None:
+                context_query = attn.norm_added_q(context_query)
+            if attn.norm_added_k is not None:
+                context_key = attn.norm_added_k(context_key)
+
+            query = torch.cat([query, context_query], dim=2)
+            key = torch.cat([key, context_key], dim=2)
+            value = torch.cat([value, context_value], dim=2)
+
+        # Fold batch and heads into one leading dimension for bmm/baddbmm.
+        query = query.reshape(-1, query.shape[-2], query.shape[-1])
+        key = key.reshape(-1, key.shape[-2], key.shape[-1])
+        value = value.reshape(-1, value.shape[-2], value.shape[-1])
+
+        # pre-transpose the key
+        key = key.transpose(-1, -2)
+        if query.size(-2) != value.size(-2):  # cross-attention, use regular attention
+            # QKV done in single block
+            attention_probs = attn.get_attention_scores(query, key, attention_mask)
+            hidden_states = torch.bmm(attention_probs, value)
+        else:  # self-attention, use blocked attention
+            # Each query block attends to the full key/value; the block outputs are
+            # concatenated once at the end instead of growing a tensor per iteration.
+            block = self.query_block_size
+            block_outputs = [
+                torch.bmm(attn.get_attention_scores(query[:, start : start + block, :], key, attention_mask), value)
+                for start in range(0, query.size(-2), block)
+            ]
+            hidden_states = torch.cat(block_outputs, dim=-2)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs back into sample and context parts.
+            sample_len = residual.shape[1]
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, :sample_len],
+                hidden_states[:, sample_len:],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        return hidden_states
diff --git a/QEfficient/diffusers/models/autoencoders/__init__.py b/QEfficient/diffusers/models/autoencoders/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py
new file mode 100644
index 000000000..c652f07d2
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py
@@ -0,0 +1,31 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import torch
+from diffusers import AutoencoderKL
+
+
+class QEffAutoencoderKL(AutoencoderKL):
+    def encode(self, x: torch.Tensor, return_dict: bool = True):
+        """
+        Encode a batch of images and return the raw output of ``self._encode``.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Accepted for signature compatibility; this implementation does not read it.
+
+        Returns:
+            The tensor produced by ``self._encode`` (per-sample results concatenated when
+            slicing is enabled). It is returned directly, not wrapped in an output object
+            or tuple.
+        """
+        # With slicing enabled, encode one sample at a time and concatenate the results.
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+        return h
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..cceb116a1
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,53 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+from typing import Tuple
+
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+from diffusers.models.normalization import RMSNorm
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.attention import QEffJointTransformerBlock
+from QEfficient.diffusers.models.attention_processor import (
+ QEffAttention,
+ QEffJointAttnProcessor2_0,
+)
+from QEfficient.diffusers.models.transformer_sd3 import QEffSD3Transformer2DModel
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+    """Maps diffusers ``RMSNorm`` modules to ``CustomRMSNormAIC``."""
+
+    _module_mapping = {RMSNorm: CustomRMSNormAIC}
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        """Delegate to ``ModuleMappingTransform.apply`` and return its ``(model, transformed)`` pair."""
+        result_model, was_transformed = super().apply(model)
+        return result_model, was_transformed
+
+
+class AttentionTransform(ModuleMappingTransform):
+    """Maps diffusers attention classes to their QEff counterparts."""
+
+    _module_mapping = {
+        Attention: QEffAttention,
+        JointAttnProcessor2_0: QEffJointAttnProcessor2_0,
+        JointTransformerBlock: QEffJointTransformerBlock,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        """Delegate to ``ModuleMappingTransform.apply`` and return its ``(model, transformed)`` pair."""
+        result_model, was_transformed = super().apply(model)
+        return result_model, was_transformed
+
+
+class OnnxFunctionTransform(ModuleMappingTransform):
+    """Maps ``SD3Transformer2DModel`` to ``QEffSD3Transformer2DModel``."""
+
+    _module_mapping = {SD3Transformer2DModel: QEffSD3Transformer2DModel}
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        """Delegate to ``ModuleMappingTransform.apply`` and return its ``(model, transformed)`` pair."""
+        result_model, was_transformed = super().apply(model)
+        return result_model, was_transformed
diff --git a/QEfficient/diffusers/models/transformer_sd3.py b/QEfficient/diffusers/models/transformer_sd3.py
new file mode 100644
index 000000000..1faf38ef8
--- /dev/null
+++ b/QEfficient/diffusers/models/transformer_sd3.py
@@ -0,0 +1,30 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import torch.nn as nn
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+
+from .attention import QEffJointTransformerBlock
+
+
+class QEffSD3Transformer2DModel(SD3Transformer2DModel):
+    """SD3 transformer whose ``__qeff_init__`` repopulates ``transformer_blocks`` with QEff blocks."""
+
+    def __qeff_init__(self):
+        """
+        Replace ``self.transformer_blocks`` with ``config.num_layers`` new
+        ``QEffJointTransformerBlock`` instances and record the block class used.
+
+        NOTE(review): the blocks are constructed here from config values only; nothing in
+        this method copies parameters from the previous ``transformer_blocks`` - confirm
+        the caller loads or restores weights afterwards.
+        """
+        self.transformer_blocks = nn.ModuleList()
+        self._block_classes = set()
+
+        for layer_idx in range(self.config.num_layers):
+            # Only the last layer is built with context_pre_only=True.
+            new_block = QEffJointTransformerBlock(
+                dim=self.inner_dim,
+                num_attention_heads=self.config.num_attention_heads,
+                attention_head_dim=self.config.attention_head_dim,
+                context_pre_only=(layer_idx == self.config.num_layers - 1),
+                qk_norm=self.config.qk_norm,
+                use_dual_attention=layer_idx in self.dual_attention_layers,
+            )
+            self.transformer_blocks.append(new_block)
+            self._block_classes.add(QEffJointTransformerBlock)
diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..24a6a9e5d
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,453 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import copy
+import hashlib
+
+import torch
+import torch.nn as nn
+
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform, CustomOpsTransform, OnnxFunctionTransform
+from QEfficient.transformers.models.pytorch_transforms import (
+ T5ModelTransform,
+)
+from QEfficient.utils import constants
+from QEfficient.utils.cache import to_hashable
+
+
+class QEffTextEncoder(QEFFBaseModel):
+    """
+    QEffTextEncoder is a wrapper class for text encoder models that provides ONNX export and compilation capabilities.
+
+    This class extends QEFFBaseModel to handle text encoder models (like T5EncoderModel) with specific
+    transformations and optimizations for efficient inference on Qualcomm AI hardware.
+    """
+
+    _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module):
+        """Initialise the base class with ``model`` and keep a deep copy of it as ``self.model``."""
+        super().__init__(model)
+        self.model = copy.deepcopy(model)
+
+    def get_onnx_config(self):
+        """
+        Build the ONNX export configuration.
+
+        ``self.tokenizer`` is not assigned by this class; it must be set on the instance
+        before this method is called.
+
+        Returns:
+            Tuple ``(example_inputs, dynamic_axes, output_names)``.
+        """
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        seq_len = self.tokenizer.model_max_length
+
+        example_inputs = {
+            "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
+        }
+
+        dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+        # T5 encoders expose only the last hidden state; other encoders also get a pooled output.
+        if self.model.__class__.__name__ == "T5EncoderModel":
+            output_names = ["last_hidden_state"]
+        else:
+            output_names = ["pooler_output", "last_hidden_state"]
+            example_inputs["output_hidden_states"] = (True,)
+
+        return example_inputs, dynamic_axes, output_names
+
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
+        """Export the wrapped model through ``self._export`` and return its result."""
+        return self._export(inputs, output_names, dynamic_axes, export_dir)
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ):
+        """Return a single specialization with the given ``batch_size`` and ``seq_len``."""
+        return [
+            {"batch_size": batch_size, "seq_len": seq_len},
+        ]
+
+    def compile(
+        self,
+        compile_dir,
+        compile_only,
+        specializations,
+        convert_to_fp16,
+        mxfp6_matmul,
+        mdp_ts_num_devices,
+        aic_num_cores,
+        custom_io,
+        **compiler_options,
+    ) -> str:
+        """Forward all arguments by keyword to ``self._compile`` and return its result."""
+        return self._compile(
+            compile_dir=compile_dir,
+            compile_only=compile_only,
+            specializations=specializations,
+            convert_to_fp16=convert_to_fp16,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=mdp_ts_num_devices,
+            aic_num_cores=aic_num_cores,
+            custom_io=custom_io,
+            **compiler_options,
+        )
+
+    @property
+    def model_hash(self) -> str:
+        """First 16 hex characters of a SHA-256 over the model config diff and the transform names."""
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(self.model.config.to_diff_dict()))
+        mhash.update(to_hashable(self._transform_names()))
+        return mhash.hexdigest()[:16]
+
+    @property
+    def model_name(self) -> str:
+        """Class name of the wrapped model with a leading ``QEff``/``QEFF`` prefix removed."""
+        mname = self.model.__class__.__name__
+        if mname.startswith(("QEff", "QEFF")):
+            mname = mname[4:]
+        return mname
+
+    @property
+    def get_model_config(self) -> dict:
+        """Attributes of the wrapped model's config as a dict."""
+        # The wrapped model is the text encoder itself, so its config is read directly
+        # rather than through a nested ``model.vision_model`` path.
+        return self.model.config.__dict__
+
+
+class QEffUNet(QEFFBaseModel):
+    """
+    QEffUNet is a wrapper class for UNet models that provides ONNX export and compilation capabilities.
+
+    This class extends QEFFBaseModel to handle UNet models with specific transformations and optimizations
+    for efficient inference on Qualcomm AI hardware. It is commonly used in diffusion models for image
+    generation tasks.
+    """
+
+    _pytorch_transforms = [CustomOpsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module):
+        """Wrap ``model.unet``; ``model`` is the object that owns the UNet."""
+        super().__init__(model.unet)
+        self.model = model.unet
+
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
+        """Export the wrapped model through ``self._export`` and return its result."""
+        return self._export(inputs, output_names, dynamic_axes, export_dir)
+
+    def compile(
+        self,
+        compile_dir,
+        compile_only,
+        specializations,
+        convert_to_fp16,
+        mxfp6_matmul,
+        mdp_ts_num_devices,
+        aic_num_cores,
+        custom_io,
+        **compiler_options,
+    ) -> str:
+        """Forward all arguments by keyword to ``self._compile`` and return its result."""
+        return self._compile(
+            compile_dir=compile_dir,
+            compile_only=compile_only,
+            specializations=specializations,
+            convert_to_fp16=convert_to_fp16,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=mdp_ts_num_devices,
+            aic_num_cores=aic_num_cores,
+            custom_io=custom_io,
+            **compiler_options,
+        )
+
+    @property
+    def model_hash(self) -> str:
+        """First 16 hex characters of a SHA-256 over the model config and the transform names."""
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(self._transform_names()))
+        return mhash.hexdigest()[:16]
+
+    @property
+    def model_name(self) -> str:
+        """Class name of the wrapped model with a leading ``QEff``/``QEFF`` prefix removed."""
+        mname = self.model.__class__.__name__
+        if mname.startswith(("QEff", "QEFF")):
+            mname = mname[4:]
+        return mname
+
+    @property
+    def get_model_config(self) -> dict:
+        """Config of the wrapped UNet as a plain dict (same conversion ``model_hash`` uses)."""
+        # The wrapped model is the UNet itself; there is no nested ``model.vision_model``.
+        return dict(self.model.config)
+
+
+class QEffVAE(QEFFBaseModel):
+    """
+    QEffVAE is a wrapper class for Variational Autoencoder (VAE) models that provides ONNX export and compilation capabilities.
+
+    This class extends QEFFBaseModel to handle VAE models with specific transformations and optimizations
+    for efficient inference on Qualcomm AI hardware. VAE models are commonly used in diffusion pipelines
+    for encoding images to latent space and decoding latent representations back to images.
+    """
+
+    _pytorch_transforms = [CustomOpsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module, type: str):
+        """
+        Wrap a deep copy of ``model.vae``.
+
+        Args:
+            model: Object that owns the VAE as ``model.vae``.
+            type: Label stored on the instance and mixed into ``model_hash``.
+        """
+        super().__init__(model.vae)
+        self.model = copy.deepcopy(model.vae)
+        self.type = type
+
+    def get_onnx_config(self):
+        """
+        Build the ONNX export configuration for VAE decode.
+
+        Returns:
+            Tuple ``(example_inputs, dynamic_axes, output_names)``.
+        """
+        # VAE decode
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        example_inputs = {
+            "latent_sample": torch.randn(bs, 16, 64, 64),
+            "return_dict": False,
+        }
+
+        output_names = ["sample"]
+
+        dynamic_axes = {
+            "latent_sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"},
+        }
+        return example_inputs, dynamic_axes, output_names
+
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
+        """Export the wrapped model through ``self._export`` and return its result."""
+        return self._export(inputs, output_names, dynamic_axes, export_dir)
+
+    def get_specializations(
+        self,
+        batch_size: int,
+    ):
+        """Return a single specialization: 16 channels at 128x128 for the given ``batch_size``."""
+        specializations = [
+            {
+                "batch_size": batch_size,
+                "channels": 16,
+                "height": 128,
+                "width": 128,
+            }
+        ]
+        return specializations
+
+    def compile(
+        self,
+        compile_dir,
+        compile_only,
+        specializations,
+        convert_to_fp16,
+        mxfp6_matmul,
+        mdp_ts_num_devices,
+        aic_num_cores,
+        custom_io,
+        **compiler_options,
+    ) -> str:
+        """Forward all arguments by keyword to ``self._compile`` and return its result."""
+        return self._compile(
+            compile_dir=compile_dir,
+            compile_only=compile_only,
+            specializations=specializations,
+            convert_to_fp16=convert_to_fp16,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=mdp_ts_num_devices,
+            aic_num_cores=aic_num_cores,
+            custom_io=custom_io,
+            **compiler_options,
+        )
+
+    @property
+    def model_hash(self) -> str:
+        """First 16 hex characters of a SHA-256 over the model config, transform names and ``self.type``."""
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(self._transform_names()))
+        mhash.update(to_hashable(self.type))
+        return mhash.hexdigest()[:16]
+
+    @property
+    def model_name(self) -> str:
+        """Class name of the wrapped model with a leading ``QEff``/``QEFF`` prefix removed."""
+        mname = self.model.__class__.__name__
+        if mname.startswith(("QEff", "QEFF")):
+            mname = mname[4:]
+        return mname
+
+    @property
+    def get_model_config(self) -> dict:
+        """Config of the wrapped VAE as a plain dict (same conversion ``model_hash`` uses)."""
+        # The wrapped model is the VAE itself; there is no nested ``model.vision_model``.
+        return dict(self.model.config)
+
+
+class QEffSafetyChecker(QEFFBaseModel):
+    """
+    QEffSafetyChecker is a wrapper class for safety checker models that provides ONNX export and compilation capabilities.
+
+    This class extends QEFFBaseModel to handle safety checker models with specific transformations and optimizations
+    for efficient inference on Qualcomm AI hardware. Safety checker models are commonly used in diffusion pipelines
+    to filter out potentially harmful or inappropriate generated content.
+    """
+
+    _pytorch_transforms = [CustomOpsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module):
+        """Wrap ``model.safety_checker``; ``model`` is the object that owns the safety checker."""
+        # Initialise the base class with the module this wrapper actually exposes, so the
+        # base-class setup operates on the safety checker rather than on the VAE.
+        super().__init__(model.safety_checker)
+        self.model = model.safety_checker
+
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
+        """Export the wrapped model through ``self._export`` and return its result."""
+        return self._export(inputs, output_names, dynamic_axes, export_dir)
+
+    def compile(
+        self,
+        compile_dir,
+        compile_only,
+        specializations,
+        convert_to_fp16,
+        mxfp6_matmul,
+        mdp_ts_num_devices,
+        aic_num_cores,
+        custom_io,
+        **compiler_options,
+    ) -> str:
+        """Forward all arguments by keyword to ``self._compile`` and return its result."""
+        return self._compile(
+            compile_dir=compile_dir,
+            compile_only=compile_only,
+            specializations=specializations,
+            convert_to_fp16=convert_to_fp16,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=mdp_ts_num_devices,
+            aic_num_cores=aic_num_cores,
+            custom_io=custom_io,
+            **compiler_options,
+        )
+
+    @property
+    def model_hash(self) -> str:
+        """First 16 hex characters of a SHA-256 over the model config diff and the transform names."""
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(self.model.config.to_diff_dict()))
+        mhash.update(to_hashable(self._transform_names()))
+        return mhash.hexdigest()[:16]
+
+    @property
+    def model_name(self) -> str:
+        """Class name of the wrapped model with a leading ``QEff``/``QEFF`` prefix removed."""
+        mname = self.model.__class__.__name__
+        if mname.startswith(("QEff", "QEFF")):
+            mname = mname[4:]
+        return mname
+
+    @property
+    def get_model_config(self) -> dict:
+        """Attributes of the wrapped safety checker's config as a dict."""
+        # The wrapped model is the safety checker itself, so its config is read directly.
+        return self.model.config.__dict__
+
+
+class QEffSD3Transformer2DBaseModel(QEFFBaseModel):
+    """
+    QEffSD3Transformer2DBaseModel is a wrapper class for Stable Diffusion 3 Transformer2D models that provides ONNX export and compilation capabilities.
+
+    This class extends QEFFBaseModel to handle SD3 Transformer2D models with specific transformations and optimizations
+    for efficient inference on Qualcomm AI hardware. It is designed for the newer Stable Diffusion 3 architecture
+    that uses transformer-based diffusion models instead of traditional UNet architectures.
+    """
+
+    _pytorch_transforms = [AttentionTransform, CustomOpsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module, use_onnx_function):
+        """
+        Wrap ``model`` and optionally apply ``OnnxFunctionTransform`` to it.
+
+        Args:
+            model: The SD3 transformer module.
+            use_onnx_function: If truthy, ``OnnxFunctionTransform`` is applied to ``model``
+                and recorded in this instance's transform list.
+        """
+        super().__init__(model)
+        self.use_onnx_function = bool(use_onnx_function)
+        if self.use_onnx_function:
+            # Build a per-instance list instead of appending to the class attribute:
+            # appending there grows the shared list on every construction and leaks the
+            # transform into instances created with use_onnx_function=False.
+            self._pytorch_transforms = [*type(self)._pytorch_transforms, OnnxFunctionTransform]
+            model, _ = OnnxFunctionTransform.apply(model)
+        self.model = model
+
+    def get_onnx_config(self):
+        """
+        Build the ONNX export configuration.
+
+        Example inputs use a leading dimension of 2 and a context length of 333.
+
+        Returns:
+            Tuple ``(example_inputs, dynamic_axes, output_names)``.
+        """
+        example_inputs = {
+            "hidden_states": torch.randn(
+                2,
+                self.model.config.in_channels,
+                self.model.config.sample_size,
+                self.model.config.sample_size,
+            ),
+            "encoder_hidden_states": torch.randn(2, 333, self.model.config.joint_attention_dim),
+            "pooled_projections": torch.randn(2, self.model.config.pooled_projection_dim),
+            "timestep": torch.randint(0, 20, (2,), dtype=torch.int64),
+        }
+
+        output_names = ["output"]
+
+        dynamic_axes = {
+            "hidden_states": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
+            "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+            "pooled_projections": {0: "batch_size"},
+            "timestep": {0: "steps"},
+            "output": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
+        }
+        return example_inputs, dynamic_axes, output_names
+
+    def export(
+        self,
+        inputs,
+        output_names,
+        dynamic_axes,
+        export_dir=None,
+        export_kwargs=None,
+    ):
+        """Export the wrapped model through ``self._export`` and return its result."""
+        return self._export(
+            example_inputs=inputs,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            export_dir=export_dir,
+            export_kwargs=export_kwargs,
+        )
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ):
+        """
+        Return a single specialization for the compiler.
+
+        The compiled batch dimension is ``2 * batch_size``; latent height and width come
+        from ``config.sample_size`` and ``steps`` is fixed at 1.
+        """
+        return [
+            {
+                "batch_size": 2 * batch_size,
+                "latent_channels": 16,
+                "latent_height": self.model.config.sample_size,
+                "latent_width": self.model.config.sample_size,
+                "seq_len": seq_len,
+                "steps": 1,
+            }
+        ]
+
+    def compile(
+        self,
+        compile_dir,
+        compile_only,
+        specializations,
+        convert_to_fp16,
+        mxfp6_matmul,
+        mdp_ts_num_devices,
+        aic_num_cores,
+        custom_io,
+        **compiler_options,
+    ) -> str:
+        """Forward all arguments by keyword to ``self._compile`` and return its result."""
+        return self._compile(
+            compile_dir=compile_dir,
+            compile_only=compile_only,
+            specializations=specializations,
+            convert_to_fp16=convert_to_fp16,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=mdp_ts_num_devices,
+            aic_num_cores=aic_num_cores,
+            custom_io=custom_io,
+            **compiler_options,
+        )
+
+    @property
+    def model_hash(self) -> str:
+        """
+        First 16 hex characters of a SHA-256 over the model config and the transform names,
+        plus a marker when ``use_onnx_function`` is enabled.
+        """
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(dict(self.model.config)))
+        mhash.update(to_hashable(self._transform_names()))
+        if self.use_onnx_function:
+            # Keeps exports made with and without ONNX functions on distinct hashes even if
+            # `_transform_names()` only reports the class-level transform list.
+            mhash.update(to_hashable({"use_onnx_function": True}))
+        return mhash.hexdigest()[:16]
+
+    @property
+    def model_name(self) -> str:
+        """Class name of the wrapped model with a leading ``QEff``/``QEFF`` prefix removed."""
+        mname = self.model.__class__.__name__
+        if mname.startswith(("QEff", "QEFF")):
+            mname = mname[4:]
+        return mname
diff --git a/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py b/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 000000000..7f14f47d7
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,481 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.image_processor import VaeImageProcessor
+
+from QEfficient.diffusers.pipelines.pipeline_utils import QEffSafetyChecker, QEffTextEncoder, QEffUNet, QEffVAE
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils import constants
+
+
+class QEFFStableDiffusionPipeline(StableDiffusionPipeline):
+ _hf_auto_class = StableDiffusionPipeline
+
+ def __init__(self, model, *args, **kwargs):
+     """
+     Wrap the components of a loaded ``StableDiffusionPipeline`` in QEfficient classes.
+
+     Args:
+         :model: The loaded Hugging Face pipeline whose tokenizer, scheduler, feature
+             extractor, text encoder, UNet, VAE and safety checker are reused.
+         :*args: Optional; the first positional value is taken as the checkpoint
+             name/path when the keyword below is not given.
+         :**kwargs: ``pretrained_model_name_or_path`` (optional) is recorded on the instance.
+     """
+     # NOTE: the parent constructor is deliberately not invoked (it was commented out in
+     # the original); attributes are assigned directly instead.
+     self.tokenizer = model.tokenizer
+     self.scheduler = model.scheduler
+     self.feature_extractor = model.feature_extractor
+
+     self.text_encoder = QEffTextEncoder(model)
+     self.unet = QEffUNet(model)
+
+     # VAE encoder: route ``forward`` to ``encode`` so that export traces the encoder path.
+     # NOTE(review): if both QEffVAE wrappers share one underlying module, the decoder
+     # assignment below would overwrite this one -- confirm QEffVAE keeps separate copies.
+     self.vae_encoder = QEffVAE(model, "encoder")
+     self.vae_encoder.model.forward = lambda sample, return_dict: self.vae_encoder.model.encode(sample, return_dict)
+
+     # VAE decoder: route ``forward`` to ``decode``.
+     self.vae_decoder = QEffVAE(model, "decoder")
+     self.vae_decoder.model.forward = lambda latent_sample, return_dict: self.vae_decoder.model.decode(
+         latent_sample, return_dict
+     )
+
+     # Safety checker: use its ONNX-friendly forward for export.
+     self.safety_checker = QEffSafetyChecker(model)
+     self.safety_checker.model.forward = model.safety_checker.forward_onnx
+
+     # Accept the checkpoint identifier either by keyword or as the first extra positional.
+     self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", args[0] if args else None)
+
+     # Derive the scale factor from the loaded VAE. ``self.vae`` is never assigned on this
+     # object, so reading it always fell through to the default of 8.
+     vae = getattr(model, "vae", None)
+     self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
+     self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+     """
+     Load a Hugging Face ``StableDiffusionPipeline`` and wrap it in this class.
+
+     Args:
+         :pretrained_model_name_or_path: Checkpoint name or local path.
+         :**kwargs: Forwarded to ``StableDiffusionPipeline.from_pretrained``;
+             ``attn_implementation`` is always set to ``"eager"``. ``torch_dtype`` is
+             fixed to ``torch.float32`` and must not be supplied by the caller.
+
+     Returns:
+         A new instance of this class built from the loaded pipeline (moved to CPU).
+     """
+     kwargs["attn_implementation"] = "eager"
+     model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float32, **kwargs)
+     model.to("cpu")
+     # Pass the identifier by keyword: ``__init__`` looks it up by name, so a positional
+     # argument would be silently dropped.
+     return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
+
+ def export(self, export_dir: Optional[str] = None) -> None:
+     """
+     Export every pipeline component to ``ONNX`` through each wrapper's ``export`` method.
+
+     The text encoder, UNet, VAE encoder, VAE decoder and safety checker are exported in
+     that order. The value returned by each ``export`` call is stored on the instance as
+     ``text_encoder_onnx_path``, ``unet_onnx_path``, ``vae_encoder_onnx_path``,
+     ``vae_decoder_onnx_path`` and ``safety_checker_onnx_path``.
+
+     ``Optional`` Args:
+         :export_dir (str, optional): The directory path to store the ONNX graphs; it is
+             forwarded to every component export.
+
+     Returns:
+         None. Read the ``*_onnx_path`` attributes for the results.
+     """
+     bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+     seq_len = self.tokenizer.model_max_length
+
+     # ---- Text encoder ----
+     # Only ``input_ids`` is exported, so only its axes are declared dynamic.
+     example_inputs = {
+         "input_ids": torch.zeros((bs, seq_len), dtype=torch.int32),
+     }
+     dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+     output_names = ["last_hidden_state", "pooler_output"]
+
+     self.text_encoder_onnx_path = self.text_encoder.export(
+         example_inputs,
+         output_names,
+         dynamic_axes,
+         export_dir=export_dir,
+     )
+
+     print("###################### Text Encoder Exported #####################")
+
+     # ---- UNet ----
+     # Read ``in_channels`` from the config, as ``__call__`` does, instead of the
+     # deprecated direct attribute on the module.
+     unet_config = self.unet.model.config
+     unet_example_input = {
+         "sample": torch.randn(bs, unet_config.in_channels, unet_config.sample_size, unet_config.sample_size),
+         "timestep": torch.tensor([1]),
+         "encoder_hidden_states": torch.randn(bs, seq_len, unet_config.cross_attention_dim),
+         "return_dict": False,
+     }
+     output_names = ["out_sample"]
+     dynamic_axes = {
+         "sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"},
+         "timestep": {0: "batch_size"},
+         "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+     }
+
+     self.unet_onnx_path = self.unet.export(
+         unet_example_input,
+         output_names,
+         dynamic_axes,
+         export_dir=export_dir,
+     )
+
+     print("###################### UNet Exported #####################")
+
+     # ---- VAE encoder ----
+     # NOTE(review): the example image size (3x512x512) is hard-coded.
+     vae_encoder_input = {
+         "sample": torch.randn(bs, 3, 512, 512),
+         "return_dict": False,
+     }
+     output_names = ["latent_sample"]
+     dynamic_axes = {
+         "sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"},
+     }
+
+     self.vae_encoder_onnx_path = self.vae_encoder.export(
+         vae_encoder_input,
+         output_names,
+         dynamic_axes,
+         export_dir=export_dir,
+     )
+
+     print("###################### VAE Encoder Exported #####################")
+
+     # ---- VAE decoder ----
+     # NOTE(review): the example latent size (4x64x64) is hard-coded.
+     vae_decoder_input = {
+         "latent_sample": torch.randn(bs, 4, 64, 64),
+         "return_dict": False,
+     }
+     output_names = ["sample"]
+     dynamic_axes = {
+         "latent_sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"},
+     }
+
+     self.vae_decoder_onnx_path = self.vae_decoder.export(
+         vae_decoder_input,
+         output_names,
+         dynamic_axes,
+         export_dir=export_dir,
+     )
+
+     print("###################### VAE Decoder Exported #####################")
+
+     # ---- Safety checker ----
+     safety_checker_input = {"clip_input": torch.randn(bs, 3, 224, 224), "images": torch.randn(bs, 3, 512, 512)}
+     output_names = ["out_images", "has_nsfw_concepts"]
+     dynamic_axes = {
+         "clip_input": {0: "batch_size", 1: "channels", 2: "clip_height", 3: "clip_width"},
+         "images": {0: "batch_size", 1: "channels", 2: "height", 3: "width"},
+     }
+
+     self.safety_checker_onnx_path = self.safety_checker.export(
+         safety_checker_input,
+         output_names,
+         dynamic_axes,
+         export_dir=export_dir,
+     )
+
+     print("###################### Safety Checker Exported #####################")
+
+ def compile(
+     self,
+     onnx_path: Optional[str] = None,
+     compile_dir: Optional[str] = None,
+     *,
+     seq_len: Union[int, List[int]] = 32,
+     batch_size: int = 1,
+     num_devices: int = 1,
+     num_cores: int = 16,  # FIXME: Make this mandatory arg
+     mxfp6_matmul: bool = False,
+     **compiler_options,
+ ) -> None:
+     """
+     Compile the text encoder, UNet and safety checker; the VAE halves are not compiled.
+
+     Args:
+         :onnx_path (str, optional): Forwarded unchanged to every component's ``_compile``.
+             NOTE(review): one shared path for several components is unlikely to be what
+             a caller wants -- leave it as None unless ``_compile`` is known to cope.
+         :compile_dir (str, optional): Forwarded unchanged to every component's ``_compile``.
+         :seq_len: Currently ignored; the tokenizer's ``model_max_length`` is used instead.
+         :batch_size (int): Batch size written into every specialization.
+         :num_devices (int): Passed as ``mdp_ts_num_devices`` (text encoder and UNet only).
+         :num_cores (int): Passed as ``aic_num_cores`` (text encoder and UNet only).
+         :mxfp6_matmul (bool): Passed to the text encoder and UNet compilation.
+         :**compiler_options: Extra options for the text encoder and UNet compilation.
+
+     Returns:
+         None. Results are stored in ``text_encoder_compile_path``, ``compiled_unet_path``
+         and ``compiled_safety_checker_path``.
+     """
+     # The caller-supplied ``seq_len`` is overridden, exactly as before.
+     seq_len = self.tokenizer.model_max_length
+
+     # ---- Text encoder ----
+     specializations = [
+         {"batch_size": batch_size, "seq_len": seq_len},
+     ]
+
+     self.text_encoder_compile_path = self.text_encoder._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices,
+         aic_num_cores=num_cores,
+         **compiler_options,
+     )
+
+     print("###################### Text Encoder Compiled #####################")
+
+     # ---- UNet ----
+     # Take the channel count from the config so it matches the exported example input.
+     unet_config = self.unet.model.config
+     specializations = [
+         {
+             "batch_size": batch_size,
+             "channels": unet_config.in_channels,
+             "height": unet_config.sample_size,
+             "width": unet_config.sample_size,
+             "seq_len": seq_len,
+         }
+     ]
+
+     self.compiled_unet_path = self.unet._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices,
+         aic_num_cores=num_cores,
+         **compiler_options,
+     )
+
+     print("###################### Unet Compiled #####################")
+
+     # ---- VAE encoder / decoder ----
+     # TODO: add compilation support for both VAE halves (specializations would be
+     # batch_size / channels / height / width). For now the decoder runs on the host and
+     # the encoder is unused, so report that instead of claiming they were compiled.
+     print("###################### VAE Encoder Compilation Skipped #####################")
+     print("###################### VAE Decoder Compilation Skipped (runs on host) #####################")
+
+     # ---- Safety checker ----
+     # NOTE(review): image and CLIP input sizes are hard-coded, and the device, core,
+     # mxfp6 and extra compiler options are not forwarded here.
+     safety_check_specializations = [
+         {
+             "batch_size": batch_size,
+             "channels": 3,
+             "height": 512,
+             "width": 512,
+             "clip_height": 224,
+             "clip_width": 224,
+         }
+     ]
+
+     self.compiled_safety_checker_path = self.safety_checker._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=safety_check_specializations,
+         convert_to_fp16=True,
+     )
+
+     print("###################### Safety Checker Compiled #####################")
+
+ # def generate()
+
+ @property
+ def model_name(self) -> str:
+     """Placeholder: no pipeline-level name is implemented yet, so this evaluates to ``None``."""
+     return None
+
+ @property
+ def model_hash(self) -> str:
+     """Placeholder: no pipeline-level hash is implemented yet, so this evaluates to ``None``."""
+     return None
+
+ def __call__(
+     self,
+     prompt: Union[str, List[str]] = None,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     device_ids: Optional[List[int]] = None,
+     num_inference_steps: int = 50,
+     timesteps: List[int] = None,
+     sigmas: List[float] = None,
+     guidance_scale: float = 7.5,
+     negative_prompt: Optional[Union[str, List[str]]] = None,
+     num_images_per_prompt: Optional[int] = 1,
+     eta: float = 0.0,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.Tensor] = None,
+     prompt_embeds: Optional[torch.Tensor] = None,
+     negative_prompt_embeds: Optional[torch.Tensor] = None,
+     output_type: Optional[str] = "pil",
+     **kwargs,
+ ):
+     """
+     Generate images for ``prompt`` using the compiled text encoder, UNet and safety checker.
+
+     Args:
+         :prompt: Prompt text(s) to tokenize and encode.
+         :height / width: Currently ignored; both are derived from the UNet sample size
+             and the VAE scale factor.
+         :device_ids (List[int], optional): Devices for the text-encoder session;
+             defaults to ``[0]``.
+         :num_inference_steps / timesteps / sigmas: Forwarded to ``retrieve_timesteps``.
+         :generator / latents: Forwarded to ``prepare_latents``.
+         :output_type: Forwarded to the image post-processor.
+         ``guidance_scale``, ``negative_prompt``, ``num_images_per_prompt``, ``eta``,
+         ``prompt_embeds``, ``negative_prompt_embeds`` and ``**kwargs`` are accepted but
+         not used by this implementation.
+
+     Returns:
+         ``StableDiffusionPipelineOutput`` with the images and the NSFW flags.
+     """
+     # Avoid a shared mutable default; ``[0]`` remains the effective default.
+     device_ids = [0] if device_ids is None else device_ids
+
+     # ---- Text encoder (CLIP) ----
+     if self.text_encoder.qpc_session is None:
+         self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder_compile_path), device_ids)
+
+     # Batch size for which the QPC was compiled, read from its first binding.
+     bs, _ = self.text_encoder.qpc_session.bindings[0].dims
+
+     # NOTE(review): 77 and 768 are hard-coded here rather than read from the
+     # tokenizer / text-encoder configuration.
+     text_inputs = self.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="np",
+     )
+     # Placeholder arrays that only describe the output buffers to the session.
+     text_encoder_output = {
+         "last_hidden_state": np.random.rand(bs, 77, 768).astype(np.float32),
+         "pooler_output": np.random.rand(bs, 768).astype(np.float32),
+     }
+     self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+     # NOTE(review): the block below is an accuracy check that runs ONNX Runtime and the
+     # PyTorch model on every call; consider moving it out of the inference path.
+     import onnxruntime as ort
+
+     ort_session = ort.InferenceSession(str(self.text_encoder.onnx_path))
+
+     onnx_inputs = {k: v for k, v in text_inputs.items() if k in [i.name for i in ort_session.get_inputs()]}
+
+     onnx_inputs["input_ids"] = onnx_inputs["input_ids"].astype(np.int32)
+
+     ort_outputs = ort_session.run(None, onnx_inputs)
+     text_inputs_pt = {k: torch.from_numpy(v) for k, v in onnx_inputs.items()}
+
+     pt_output = self.text_encoder.model(**text_inputs_pt)
+     mad = torch.mean(torch.abs(pt_output[0] - torch.tensor(ort_outputs[0])))
+     print("CLIP: MAD onnx vs pytorch", mad)
+
+     self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+     ai100_output = self.text_encoder.qpc_session.run(onnx_inputs)
+     mad_ai100_onnnx = np.mean(np.abs(ai100_output["last_hidden_state"] - ort_outputs[0]))
+
+     print("CLIP: MAD ai100 vs onnx", mad_ai100_onnnx)
+
+     ai100_output = ai100_output["last_hidden_state"]
+
+     # ---- Timesteps ----
+     from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+     timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas)
+     timesteps = timesteps.numpy()
+
+     # ---- Latents ----
+     # Height and width always follow the UNet sample size (caller values are overridden).
+     width = height = self.unet.model.config.sample_size
+     height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
+
+     num_channels_latents = self.unet.model.config.in_channels
+     # ``prepare_latents`` expects ``device`` between ``dtype`` and ``generator``. Without
+     # it, ``generator`` was consumed as the device and ``latents`` as the generator.
+     latents = self.prepare_latents(
+         bs,
+         num_channels_latents,
+         height,
+         width,
+         torch.float32,
+         torch.device("cpu"),
+         generator,
+         latents,
+     )
+
+     # ---- UNet ----
+     # NOTE(review): the device id is hard-coded to [1] rather than taken from ``device_ids``.
+     self.unet_qpc_session = QAICInferenceSession(str(self.compiled_unet_path), [1])
+
+     unet_output = {"out_sample": np.random.rand(bs, 4, 64, 64).astype(np.float32)}
+     self.unet_qpc_session.set_buffers(unet_output)
+
+     # Denoising loop
+     for t in timesteps:
+         latent_input = self.scheduler.scale_model_input(latents, t)
+         noise_pred = self.unet_qpc_session.run(
+             {"encoder_hidden_states": ai100_output, "timestep": t, "sample": latent_input.numpy()}
+         )
+         # Hand the scheduler a tensor so that ``prev_sample`` stays a tensor for the
+         # next iteration and for the VAE decoder.
+         model_output = torch.tensor(noise_pred["out_sample"])
+         latents = self.scheduler.step(model_output, t, latents).prev_sample
+
+     # ---- VAE decode (host) ----
+     # TODO: Add QPC for VAE decode
+     image = self.vae_decoder.model(latents / self.vae_decoder.model.config.scaling_factor, return_dict=False)[0]
+
+     # ---- Safety check ----
+     if torch.is_tensor(image):
+         feature_extractor_input = self.image_processor.postprocess(image.detach(), output_type="pil")
+     else:
+         feature_extractor_input = self.image_processor.numpy_to_pil(image)
+
+     safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt")
+
+     # NOTE(review): the device id is hard-coded to [2] rather than taken from ``device_ids``.
+     self.safety_checker_session = QAICInferenceSession(str(self.compiled_safety_checker_path), [2])
+
+     safety_checker_output = {
+         "out_images": np.random.rand(1, 3, 512, 512).astype(np.float32),
+         "has_nsfw_concepts": np.bool_(1),
+     }
+     self.safety_checker_session.set_buffers(safety_checker_output)
+
+     checker_output = self.safety_checker_session.run(
+         {"clip_input": safety_checker_input["pixel_values"].numpy(), "images": image.detach().numpy()}
+     )
+
+     # The session always returns an array here, so no ``None`` case has to be handled.
+     has_nsfw_concept = checker_output["has_nsfw_concepts"].astype("bool")
+     do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+     image = self.image_processor.postprocess(image.detach(), output_type=output_type, do_denormalize=do_denormalize)
+
+     from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+
+     return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py b/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py b/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py
new file mode 100644
index 000000000..9d9a440ac
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py
@@ -0,0 +1,929 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from typing import Any, Callable, Dict, List, Optional, Union
+from venv import logger
+
+import numpy as np
+import torch
+from diffusers import StableDiffusion3Pipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+
+from QEfficient.diffusers.pipelines.pipeline_utils import QEffSD3Transformer2DBaseModel, QEffTextEncoder, QEffVAE
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils import constants
+
+
+class QEFFStableDiffusion3Pipeline(StableDiffusion3Pipeline):
+ _hf_auto_class = StableDiffusion3Pipeline
+ """
+ A QEfficient-optimized Stable Diffusion 3 pipeline, inheriting from `diffusers.StableDiffusion3Pipeline`.
+
+ This class integrates QEfficient components (e.g., optimized models for text encoder,
+ transformer, and VAE) to enhance performance, particularly for deployment on Qualcomm AI hardware.
+ It provides methods for text-to-image generation leveraging these optimized components.
+ """
+
+ def __init__(self, model, use_onnx_function, *args, **kwargs):
+     """
+     Wrap the components of a loaded ``StableDiffusion3Pipeline`` in QEfficient classes.
+
+     Args:
+         :model: The loaded Hugging Face pipeline supplying the three text encoders,
+             the transformer, the VAE, the tokenizers and the scheduler.
+         :use_onnx_function: Stored on the instance and forwarded to the transformer wrapper.
+         :*args / **kwargs: Accepted but not used.
+     """
+     self.use_onnx_function = use_onnx_function
+
+     # QEfficient wrappers around the individual sub-models.
+     self.text_encoder = QEffTextEncoder(model.text_encoder)
+     self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+     self.text_encoder_3 = QEffTextEncoder(model.text_encoder_3)
+     self.transformer = QEffSD3Transformer2DBaseModel(model.transformer, self.use_onnx_function)
+     self.vae_decode = QEffVAE(model, "decoder")
+
+     # Tokenizers: each text-encoder wrapper carries the tokenizer it is paired with.
+     self.tokenizer = model.tokenizer
+     self.text_encoder.tokenizer = model.tokenizer
+     self.text_encoder_2.tokenizer = model.tokenizer_2
+     self.text_encoder_3.tokenizer = model.tokenizer_3
+     self.tokenizer_max_length = model.tokenizer_max_length
+     self.scheduler = model.scheduler
+
+     # Route ``forward`` to ``decode`` on the VAE wrapper's model.
+     self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
+         latent_sample, return_dict
+     )
+
+     # Scale factor from the VAE block count, or 8 when the pipeline has no VAE.
+     if getattr(model, "vae", None):
+         self.vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1)
+     else:
+         self.vae_scale_factor = 8
+     # The image processor is configured from the value held by ``model`` itself.
+     self.image_processor = VaeImageProcessor(vae_scale_factor=model.vae_scale_factor)
+
+     if hasattr(model, "tokenizer") and model.tokenizer is not None:
+         self.t_max_length = model.tokenizer.model_max_length
+     else:
+         self.t_max_length = 77
+
+     # Transformer-derived defaults, with fallbacks when no transformer is present.
+     if hasattr(model, "transformer") and model.transformer is not None:
+         self.default_sample_size = model.transformer.config.sample_size
+     else:
+         self.default_sample_size = 128
+     if hasattr(model, "transformer") and model.transformer is not None:
+         self.patch_size = model.transformer.config.patch_size
+     else:
+         self.patch_size = 2
+
+ @classmethod
+ def from_pretrained(
+     cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], use_onnx_function=False, **kwargs
+ ):
+     """
+     Instantiate a QEFFStableDiffusion3Pipeline from pretrained Diffusers models.
+
+     Args:
+         pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+             The path to the pretrained model or its name.
+         use_onnx_function (*optional*, defaults to `False`):
+             Passed to the constructor of this class.
+         **kwargs (additional keyword arguments):
+             Forwarded to `StableDiffusion3Pipeline.from_pretrained`; the model is always
+             loaded with `torch_dtype=torch.float32`.
+
+     Returns:
+         A new instance wrapping the loaded pipeline, which is moved to CPU first.
+     """
+     hf_pipeline = cls._hf_auto_class.from_pretrained(
+         pretrained_model_name_or_path,
+         torch_dtype=torch.float32,
+         **kwargs,
+     )
+     hf_pipeline.to("cpu")
+     init_kwargs = {
+         "model": hf_pipeline,
+         "use_onnx_function": use_onnx_function,
+         "pretrained_model_name_or_path": pretrained_model_name_or_path,
+     }
+     return cls(**init_kwargs)
+
+ def export(self, export_dir: Optional[str] = None) -> None:
+     """
+     Export the three text encoders, the transformer and the VAE decoder to ``ONNX``.
+
+     Example inputs, dynamic axes and output names come from each wrapper's
+     ``get_onnx_config``. Before the third text encoder is exported, its per-block output
+     projection weights are rescaled in place with ``constants.WO_SFS``; this is done only
+     once per wrapper so that repeated ``export`` calls do not compound the scaling.
+
+     ``Optional`` Args:
+         :export_dir (str, optional): The directory path to store the ONNX graphs.
+
+     Returns:
+         None. Only the VAE decoder result is stored, as ``self.vae_decoder_onnx_path``.
+     """
+     # ---- text_encoder ----
+     example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = (
+         self.text_encoder.get_onnx_config()
+     )
+     # Expose 13 hidden states as named outputs (the count is hard-coded).
+     output_names_text_encoder.extend("hidden_states_" + str(i) for i in range(0, 13))
+
+     self.text_encoder.export(
+         inputs=example_inputs_text_encoder,
+         output_names=output_names_text_encoder,
+         dynamic_axes=dynamic_axes_text_encoder,
+         export_dir=export_dir,
+     )
+
+     # ---- text_encoder_2 ----
+     example_inputs_text_encoder_2, dynamic_axes_text_encoder_2, output_names_text_encoder_2 = (
+         self.text_encoder_2.get_onnx_config()
+     )
+     # Expose 33 hidden states as named outputs (the count is hard-coded).
+     output_names_text_encoder_2.extend("hidden_states_" + str(i) for i in range(0, 33))
+
+     self.text_encoder_2.export(
+         inputs=example_inputs_text_encoder_2,
+         output_names=output_names_text_encoder_2,
+         dynamic_axes=dynamic_axes_text_encoder_2,
+         export_dir=export_dir,
+     )
+
+     # ---- text_encoder_3 (T5) ----
+     example_inputs_text_encoder_3, dynamic_axes_text_encoder_3, output_names_text_encoder_3 = (
+         self.text_encoder_3.get_onnx_config()
+     )
+
+     # The rescaling mutates the weights in place, so apply it at most once.
+     if not getattr(self.text_encoder_3, "_wo_scaling_applied", False):
+         with torch.no_grad():
+             prev_sf = 1
+             for i, block in enumerate(self.text_encoder_3.model.encoder.block):
+                 wosf = constants.WO_SFS[i]
+                 block.layer[0].SelfAttention.o.weight *= 1 / wosf
+                 block.layer[0].scaling_factor *= prev_sf / wosf
+                 block.layer[1].DenseReluDense.wo.weight *= 1 / wosf
+                 prev_sf = wosf
+         self.text_encoder_3._wo_scaling_applied = True
+
+     self.text_encoder_3.export(
+         inputs=example_inputs_text_encoder_3,
+         output_names=output_names_text_encoder_3,
+         dynamic_axes=dynamic_axes_text_encoder_3,
+         export_dir=export_dir,
+     )
+
+     # ---- transformer ----
+     example_inputs_transformer, dynamic_axes_transformer, output_names_transformer = (
+         self.transformer.get_onnx_config()
+     )
+     export_kwargs = {}
+     if self.use_onnx_function:
+         # Export the transformer's block classes as ONNX functions.
+         export_kwargs = {
+             "export_modules_as_functions": self.transformer.model._block_classes,
+             "do_constant_folding": True,
+         }
+
+     self.transformer.export(
+         inputs=example_inputs_transformer,
+         output_names=output_names_transformer,
+         dynamic_axes=dynamic_axes_transformer,
+         export_dir=export_dir,
+         export_kwargs=export_kwargs,
+     )
+
+     # ---- vae decoder ----
+     example_inputs_vae, dynamic_axes_vae, output_names_vae = self.vae_decode.get_onnx_config()
+
+     self.vae_decoder_onnx_path = self.vae_decode.export(
+         example_inputs_vae,
+         output_names_vae,
+         dynamic_axes_vae,
+         export_dir=export_dir,
+     )
+
+ def compile(
+     self,
+     onnx_path: Optional[str] = None,
+     compile_dir: Optional[str] = None,
+     *,
+     seq_len: Union[int, List[int]] = 32,
+     batch_size: int = 1,
+     num_devices_text_encoder: int = 1,
+     num_devices_transformer: int = 4,
+     num_devices_vae_decoder: int = 1,
+     num_cores: int = 16,  # FIXME: Make this mandatory arg
+     mxfp6_matmul: bool = False,
+     **compiler_options,
+ ) -> None:
+     """
+     Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware.
+
+     If any component has no ONNX path yet, ``self.export()`` is run first. The three text
+     encoders, the transformer and the VAE decoder are then compiled in that order.
+
+     Args:
+         onnx_path (`str`, *optional*):
+             Forwarded unchanged to every component's `_compile`. NOTE(review): sharing
+             one path across components is unlikely to be intended; prefer None.
+         compile_dir (`str`, *optional*):
+             Forwarded unchanged to every component's `_compile`.
+         seq_len (`Union[int, List[int]]`, *optional*, defaults to 32):
+             Currently unused; sequence lengths are fixed inside this method.
+         batch_size (`int`, *optional*, defaults to 1):
+             The batch size to use for compilation.
+         num_devices_text_encoder (`int`, *optional*, defaults to 1):
+             Passed as `mdp_ts_num_devices` for all three text encoders.
+         num_devices_transformer (`int`, *optional*, defaults to 4):
+             Passed as `mdp_ts_num_devices` for the transformer.
+         num_devices_vae_decoder (`int`, *optional*, defaults to 1):
+             Passed as `mdp_ts_num_devices` for the VAE decoder.
+         num_cores (`int`, *optional*, defaults to 16):
+             Passed as `aic_num_cores` for the text encoders and the transformer.
+         mxfp6_matmul (`bool`, *optional*, defaults to `False`):
+             Passed to the text encoder and transformer compilation.
+         **compiler_options:
+             Additional compiler options for the text encoders and the transformer. For
+             the transformer, `mos=1` and `ols=2` are always enforced on top of them.
+
+     Returns:
+         None. Results are stored in `text_encoder_compile_path`,
+         `text_encoder_2_compile_path`, `text_encoder_3_compile_path`,
+         `trasformers_compile_path` and `vae_decoder_compile_path`.
+     """
+     components = [
+         self.text_encoder,
+         self.text_encoder_2,
+         self.text_encoder_3,
+         self.transformer,
+         self.vae_decode,
+     ]
+     if any(component.onnx_path is None for component in components):
+         self.export()
+
+     # ---- text_encoder ----
+     specializations_text_encoder = self.text_encoder.get_specializations(
+         batch_size, self.tokenizer.model_max_length
+     )
+
+     self.text_encoder_compile_path = self.text_encoder._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations_text_encoder,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices_text_encoder,
+         aic_num_cores=num_cores,
+         **compiler_options,
+     )
+
+     # ---- text_encoder_2 ----
+     specializations_text_encoder_2 = self.text_encoder_2.get_specializations(
+         batch_size, self.tokenizer.model_max_length
+     )
+
+     self.text_encoder_2_compile_path = self.text_encoder_2._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations_text_encoder_2,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices_text_encoder,
+         aic_num_cores=num_cores,
+         **compiler_options,
+     )
+
+     # ---- text_encoder_3 ----
+     # 256 matches the default ``max_sequence_length`` of ``_get_t5_prompt_embeds``.
+     specializations_text_encoder_3 = self.text_encoder_3.get_specializations(batch_size, 256)
+
+     self.text_encoder_3_compile_path = self.text_encoder_3._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations_text_encoder_3,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices_text_encoder,
+         aic_num_cores=num_cores,
+         **compiler_options,
+     )
+
+     # ---- transformer ----
+     # NOTE(review): 333 is presumably 77 (CLIP) + 256 (T5) tokens -- confirm.
+     specializations_transformer = self.transformer.get_specializations(batch_size, 333)
+
+     # Keep the caller's options and enforce the transformer-specific ones on top,
+     # instead of discarding everything the caller passed.
+     transformer_compiler_options = {**compiler_options, "mos": 1, "ols": 2}
+     # The attribute name (including its spelling) is kept for existing readers.
+     self.trasformers_compile_path = self.transformer._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations_transformer,
+         convert_to_fp16=True,
+         mxfp6_matmul=mxfp6_matmul,
+         mdp_ts_num_devices=num_devices_transformer,
+         aic_num_cores=num_cores,
+         **transformer_compiler_options,
+     )
+
+     # ---- vae decoder ----
+     specializations_vae = self.vae_decode.get_specializations(batch_size)
+
+     self.vae_decoder_compile_path = self.vae_decode._compile(
+         onnx_path,
+         compile_dir,
+         compile_only=True,
+         specializations=specializations_vae,
+         convert_to_fp16=True,
+         mdp_ts_num_devices=num_devices_vae_decoder,
+     )
+
+ def _get_clip_prompt_embeds(
+     self,
+     text_encoder,
+     tokenizer,
+     clip_index: bool,
+     prompt: Union[str, List[str]],
+     num_images_per_prompt: Optional[int] = 1,
+     clip_skip: Optional[int] = None,
+     device_ids: List[int] = None,
+ ):
+     """
+     Get CLIP prompt embeddings for a given text encoder and tokenizer.
+
+     Args:
+         text_encoder: The QEffTextEncoder instance to use for encoding.
+         tokenizer: The tokenizer to use for text preprocessing and for the truncation check.
+         clip_index: Falsy for the first CLIP model (768-wide, 13 hidden states), truthy for
+             the second (1280-wide, 33 hidden states).
+         prompt (Union[str, List[str]]): The input prompt(s) to encode.
+         num_images_per_prompt (Optional[int], defaults to 1): Number of images to generate per prompt.
+         clip_skip (Optional[int], optional): Number of layers to skip from the end when extracting hidden states.
+         device_ids (List[int], optional): Device IDs used when the inference session is created.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+             - prompt_embd_text_encoder: The prompt embeddings from the text encoder.
+             - pooled_prompt_embeds_text_encoder: The pooled prompt embeddings, one row per
+               generated image.
+     """
+     prompt = [prompt] if isinstance(prompt, str) else prompt
+     batch_size = len(prompt)
+
+     # Number of hidden-state outputs and embedding width of the selected CLIP model.
+     hidden_state_range = 33 if clip_index else 13
+     embed_dim = 1280 if clip_index else 768
+
+     text_inputs = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=self.tokenizer_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+
+     text_input_ids = text_inputs.input_ids
+     # Use the tokenizer that produced ``text_input_ids``; comparing against ids from a
+     # different tokenizer would make the truncation check meaningless for the second encoder.
+     untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+     if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+         removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+         logger.warning(
+             "The following part of your input was truncated because CLIP can only handle sequences up to"
+             f" {self.tokenizer_max_length} tokens: {removed_text}"
+         )
+
+     if text_encoder.qpc_session is None:
+         text_encoder.qpc_session = QAICInferenceSession(text_encoder.qpc_path, device_ids=device_ids)
+
+     # Placeholder arrays describing the output buffers to the session.
+     # NOTE(review): they are created as int32 although the outputs look like floating
+     # point embeddings -- confirm the session only uses their size and shape.
+     text_encoder_output = {
+         "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32),
+         "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.int32),
+     }
+
+     for i in range(0, hidden_state_range):
+         text_encoder_output[f"hidden_states_{i}"] = np.random.rand(
+             batch_size, self.tokenizer_max_length, embed_dim
+         ).astype(np.int32)
+     text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+     aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+     aic_embeddings = text_encoder.qpc_session.run(aic_text_input)
+     aic_text_encoder_emb = aic_embeddings["pooler_output"]
+
+     pooled_prompt_embeds = torch.tensor(aic_text_encoder_emb)
+     # Pick the hidden state by position from the end of the session's output mapping:
+     # the second-to-last entry by default, earlier ones when ``clip_skip`` is set.
+     output_keys = list(aic_embeddings.keys())
+     if clip_skip is None:
+         prompt_embd_text_encoder = torch.tensor(aic_embeddings[output_keys[-2]])
+     else:
+         prompt_embd_text_encoder = torch.tensor(aic_embeddings[output_keys[-(clip_skip + 2)]])
+     _, seq_len, _ = prompt_embd_text_encoder.shape
+
+     # duplicate text embeddings for each generation per prompt, using mps friendly method
+     prompt_embd_text_encoder = prompt_embd_text_encoder.repeat(1, num_images_per_prompt, 1)
+     prompt_embd_text_encoder = prompt_embd_text_encoder.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+     # Reshape the *repeated* pooled embeddings. Reshaping the un-repeated tensor split
+     # the embedding dimension whenever ``num_images_per_prompt`` was greater than 1.
+     pooled_prompt_embeds_text_encoder = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+     pooled_prompt_embeds_text_encoder = pooled_prompt_embeds_text_encoder.view(
+         batch_size * num_images_per_prompt, -1
+     )
+
+     return prompt_embd_text_encoder, pooled_prompt_embeds_text_encoder
+
+ def _get_t5_prompt_embeds(
+     self,
+     prompt: Union[str, List[str]] = None,
+     num_images_per_prompt: int = 1,
+     max_sequence_length: int = 256,
+     device: Optional[torch.device] = None,
+     dtype: Optional[torch.dtype] = None,
+ ):
+     """
+     Get T5 prompt embeddings for the given prompt(s).
+
+     Args:
+         prompt (Union[str, List[str]], optional): The input prompt(s) to encode.
+         num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt.
+         max_sequence_length (int, defaults to 256): Maximum sequence length for tokenization.
+         device (Optional[torch.device], optional): Accepted but not used.
+         dtype (Optional[torch.dtype], optional): Accepted but not used.
+
+     Returns:
+         torch.Tensor: The T5 prompt embeddings with shape (batch_size * num_images_per_prompt, seq_len, hidden_size).
+     """
+     prompt = [prompt] if isinstance(prompt, str) else prompt
+     batch_size = len(prompt)
+
+     t5_tokenizer = self.text_encoder_3.tokenizer
+     text_inputs = t5_tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=max_sequence_length,
+         truncation=True,
+         add_special_tokens=True,
+         return_tensors="pt",
+     )
+     text_input_ids = text_inputs.input_ids
+     untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+     if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+         # Report what was cut at ``max_sequence_length`` (the limit applied above), not
+         # at the shorter CLIP tokenizer length.
+         removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+         logger.warning(
+             "The following part of your input was truncated because `max_sequence_length` is set to "
+             f" {max_sequence_length} tokens: {removed_text}"
+         )
+     if self.text_encoder_3.qpc_session is None:
+         # NOTE(review): no device ids are passed here, unlike the CLIP encoders.
+         self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))
+
+     aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+     prompt_embeds = torch.tensor(self.text_encoder_3.qpc_session.run(aic_text_input)["last_hidden_state"])
+
+     _, seq_len, _ = prompt_embeds.shape
+
+     # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+     return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        prompt_3: Union[str, List[str]],
+        device_ids: List[int] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
+    ):
+        """
+        Turn positive (and optionally negative) prompts into the embeddings used for generation.
+
+        Each prompt set is passed through the two CLIP text encoders and the T5 encoder. Embeddings that
+        the caller already supplies (`prompt_embeds` / `negative_prompt_embeds`) are returned untouched;
+        only the missing ones are computed.
+
+        Args:
+            prompt (Union[str, List[str]]): Primary prompt(s), fed to the first CLIP encoder.
+            prompt_2 (Union[str, List[str]]): Prompt(s) for the second CLIP encoder; falls back to `prompt`.
+            prompt_3 (Union[str, List[str]]): Prompt(s) for the T5 encoder; falls back to `prompt`.
+            device_ids (List[int], optional): Forwarded to the CLIP embedding helper.
+            num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt.
+            do_classifier_free_guidance (bool, defaults to True): When True and no negative embeddings
+                are supplied, negative embeddings are computed as well.
+            negative_prompt (Optional[Union[str, List[str]]], optional): Negative prompt(s); defaults to "".
+            negative_prompt_2 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the second
+                CLIP encoder; falls back to `negative_prompt`.
+            negative_prompt_3 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the T5
+                encoder; falls back to `negative_prompt`.
+            prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed prompt embeddings.
+            negative_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed negative embeddings.
+            pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed pooled embeddings.
+            negative_pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed negative
+                pooled embeddings.
+            clip_skip (Optional[int], optional): Forwarded to the CLIP embedding helper.
+            max_sequence_length (int, defaults to 256): Maximum sequence length for T5 tokenization.
+
+        Returns:
+            Tuple of (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
+            negative_pooled_prompt_embeds). The negative entries are whatever the caller passed in
+            (possibly None) when they are not computed here.
+
+        Raises:
+            TypeError: If `negative_prompt` and `prompt` end up with different types.
+            ValueError: If `negative_prompt` does not match the batch size of `prompt`.
+        """
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # Secondary / tertiary prompts default to the primary one.
+            prompt_2 = prompt_2 or prompt
+            if isinstance(prompt_2, str):
+                prompt_2 = [prompt_2]
+
+            prompt_3 = prompt_3 or prompt
+            if isinstance(prompt_3, str):
+                prompt_3 = [prompt_3]
+
+            prompt_embeds, pooled_prompt_embeds = self._encode_prompt_triplet(
+                prompt,
+                prompt_2,
+                prompt_3,
+                num_images_per_prompt=num_images_per_prompt,
+                clip_skip=clip_skip,
+                device_ids=device_ids,
+                max_sequence_length=max_sequence_length,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+            negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+            # normalize str to list (one copy per batch entry)
+            if isinstance(negative_prompt, str):
+                negative_prompt = batch_size * [negative_prompt]
+            if isinstance(negative_prompt_2, str):
+                negative_prompt_2 = batch_size * [negative_prompt_2]
+            if isinstance(negative_prompt_3, str):
+                negative_prompt_3 = batch_size * [negative_prompt_3]
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt_triplet(
+                negative_prompt,
+                negative_prompt_2,
+                negative_prompt_3,
+                num_images_per_prompt=num_images_per_prompt,
+                clip_skip=clip_skip,
+                device_ids=device_ids,
+                max_sequence_length=max_sequence_length,
+            )
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    def _encode_prompt_triplet(
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        num_images_per_prompt,
+        clip_skip,
+        device_ids,
+        max_sequence_length,
+    ):
+        """
+        Encode one prompt set with both CLIP encoders and the T5 encoder and merge the results.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: (combined embeddings, pooled CLIP embeddings). The two CLIP
+            outputs are concatenated on the last dimension, zero-padded on that dimension up to the T5
+            width, then concatenated with the T5 output along dim -2.
+        """
+        clip_hidden, clip_pooled = [], []
+        clip_inputs = ((self.text_encoder, prompt), (self.text_encoder_2, prompt_2))
+        for clip_index, (encoder, clip_prompt) in enumerate(clip_inputs):
+            hidden, pooled = self._get_clip_prompt_embeds(
+                encoder,
+                encoder.tokenizer,
+                clip_index=clip_index,
+                prompt=clip_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                clip_skip=clip_skip,
+                device_ids=device_ids,
+            )
+            clip_hidden.append(hidden)
+            clip_pooled.append(pooled)
+
+        clip_embeds = torch.cat(clip_hidden, dim=-1)
+        pooled_embeds = torch.cat(clip_pooled, dim=-1)
+
+        t5_embeds = self._get_t5_prompt_embeds(
+            prompt=prompt_3,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+        )
+
+        # Right-pad the CLIP features so their last dimension matches the T5 output.
+        width_gap = t5_embeds.shape[-1] - clip_embeds.shape[-1]
+        clip_embeds = torch.nn.functional.pad(clip_embeds, (0, width_gap))
+
+        return torch.cat([clip_embeds, t5_embeds], dim=-2), pooled_embeds
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        timesteps: List[int] = None,
+        guidance_scale: float = 7.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
+    ):
+        """
+        Generate images from text prompts using the QEfficient-optimized Stable Diffusion 3 pipeline.
+
+        This method performs text-to-image generation by encoding the input prompts through multiple
+        text encoders, running the diffusion process with the transformer model, and decoding the
+        final latents to images using the VAE decoder. All components are optimized for Qualcomm AI hardware.
+
+        Args:
+            prompt (Union[str, List[str]], optional): The primary text prompt(s) to guide image generation.
+            prompt_2 (Optional[Union[str, List[str]]], optional): Secondary prompt(s) for the second CLIP encoder.
+                If None, defaults to `prompt`.
+            prompt_3 (Optional[Union[str, List[str]]], optional): Tertiary prompt(s) for the T5 encoder.
+                If None, defaults to `prompt`.
+            height (Optional[int], optional): Height of the generated image in pixels. If None, uses default
+                sample size scaled by VAE scale factor.
+            width (Optional[int], optional): Width of the generated image in pixels. If None, uses default
+                sample size scaled by VAE scale factor.
+            num_inference_steps (int, defaults to 28): Number of denoising steps during generation.
+            timesteps (List[int], optional): Custom timesteps to use for denoising. If provided, overrides
+                `num_inference_steps`.
+            guidance_scale (float, defaults to 7.0): Guidance scale for classifier-free guidance. Higher values
+                result in images more closely linked to the prompt at the expense of lower image quality.
+            negative_prompt (Optional[Union[str, List[str]]], optional): Negative prompt(s) to guide what not
+                to include in image generation.
+            negative_prompt_2 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the second
+                CLIP encoder.
+            negative_prompt_3 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the T5 encoder.
+            num_images_per_prompt (Optional[int], defaults to 1): Number of images to generate per prompt.
+            generator (Optional[Union[torch.Generator, List[torch.Generator]]], optional): Random number
+                generator(s) for reproducible generation.
+            latents (Optional[torch.FloatTensor], optional): Pre-generated noisy latents sampled from a Gaussian
+                distribution to be used as inputs for image generation.
+            prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated text embeddings. Can be used
+                to easily tweak text inputs (prompt weighting).
+            negative_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated negative text embeddings.
+            pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated pooled text embeddings.
+            negative_pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated negative pooled
+                text embeddings.
+            output_type (Optional[str], defaults to "pil"): Output format of the generated images. Choose between
+                "pil", "np", "pt", or "latent".
+            return_dict (bool, defaults to True): Whether to return a `StableDiffusion3PipelineOutput` instead
+                of a plain tuple.
+            joint_attention_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments to pass
+                to the joint attention layers.
+            clip_skip (Optional[int], optional): Number of layers to skip from the end when extracting CLIP
+                hidden states. Forwarded to `encode_prompt`.
+            callback_on_step_end (Optional[Callable[[int, int, Dict], None]], optional): Callback function
+                called at the end of each denoising step as `callback(self, step_index, timestep, tensors)`;
+                the mapping it returns may override `latents` and the prompt embeddings.
+            callback_on_step_end_tensor_inputs (List[str], defaults to ["latents"]): List of tensor inputs
+                to pass to the callback function.
+            max_sequence_length (int, defaults to 256): Maximum sequence length for T5 text encoder tokenization.
+
+        Returns:
+            Union[StableDiffusion3PipelineOutput, Tuple]: If `return_dict` is True, returns a
+            `StableDiffusion3PipelineOutput` object containing the generated images. Otherwise,
+            returns a tuple with the generated images.
+
+        Examples:
+            ```python
+            # Basic text-to-image generation
+            from QEfficient import QEFFStableDiffusion3Pipeline
+
+            pipeline = QEFFStableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large")
+            pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
+
+            # NOTE: guidance_scale <=1 is not supported
+            image = pipeline("A girl laughing", num_inference_steps=28, guidance_scale=2.0).images[0]
+            image.save("girl_laughing.png")
+            ```
+        """
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+        # Device handed to the scheduler and to latent preparation.
+        device = "cpu"
+
+        # Validate inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            prompt_3,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._interrupt = False
+
+        # Encode prompts. `clip_skip` must be forwarded explicitly, otherwise the caller's value is ignored.
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            prompt_3=prompt_3,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            clip_skip=clip_skip,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+        )
+
+        # Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if self.do_classifier_free_guidance:
+            # Negative (unconditional) half first, conditional half second; the loop below relies on
+            # this order when it splits the prediction with `chunk(2)`.
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+        # Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # Prepare latent variables
+        num_channels_latents = self.transformer.model.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        ###### AIC related changes of transformers ######
+        if self.transformer.qpc_session is None:
+            self.transformer.qpc_session = QAICInferenceSession(str(self.transformer.qpc_path))
+
+        # All-zero placeholder registered for the "output" binding. np.zeros avoids drawing from (and
+        # advancing) NumPy's global RNG just to build an array that is truncated to zeros anyway.
+        # NOTE(review): presumably only the buffer's shape/size matters to `set_buffers` — confirm.
+        output_buffer = {
+            "output": np.zeros(
+                (2 * batch_size, num_channels_latents, self.default_sample_size, self.default_sample_size),
+                dtype=np.int32,
+            ),
+        }
+
+        self.transformer.qpc_session.set_buffers(output_buffer)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = np.array([t], dtype=np.int64)
+
+                noise_pred = self.transformer.qpc_session.run(
+                    {
+                        "encoder_hidden_states": prompt_embeds.detach().numpy(),
+                        "pooled_projections": pooled_prompt_embeds.numpy(),
+                        "timestep": timestep,
+                        "hidden_states": latent_model_input.numpy(),
+                    }
+                )
+
+                noise_pred = torch.tensor(noise_pred["output"])
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    # The callback may replace any of these tensors; missing keys keep the current value.
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+
+                # advance the progress bar on the last step and on every completed scheduler-order group
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if output_type == "latent":
+            image = latents
+
+        else:
+            # Undo the latent scaling/shift from the VAE config before decoding.
+            latents = (
+                latents / self.vae_decode.model.config.scaling_factor
+            ) + self.vae_decode.model.config.shift_factor
+
+            vae_session = QAICInferenceSession(str(self.vae_decoder_compile_path))
+
+            # All-zero placeholder for the "sample" binding (see the transformer placeholder above in
+            # this method for the rationale).
+            output_buffer = {
+                "sample": np.zeros(
+                    (
+                        batch_size,
+                        3,
+                        self.vae_decode.model.config.sample_size,
+                        self.vae_decode.model.config.sample_size,
+                    ),
+                    dtype=np.int32,
+                )
+            }
+
+            vae_session.set_buffers(output_buffer)
+            inputs = {"latent_sample": latents.numpy()}
+            image = vae_session.run(inputs)
+            image = self.image_processor.postprocess(torch.tensor(image["sample"]), output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusion3PipelineOutput(images=image)
diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py
index 8519d824c..8fe1c0868 100644
--- a/QEfficient/generation/cloud_infer.py
+++ b/QEfficient/generation/cloud_infer.py
@@ -84,7 +84,7 @@ def __init__(
self.binding_index_map = {binding.name: binding.index for binding in self.bindings}
# Create and load Program
prog_properties = qaicrt.QAicProgramProperties()
- prog_properties.SubmitRetryTimeoutMs = 60_000
+ prog_properties.SubmitRetryTimeoutMs = 60_00000
if device_ids and len(device_ids) > 1:
prog_properties.devMapping = ":".join(map(str, device_ids))
self.program = qaicrt.Program(self.context, None, qpc, prog_properties)
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index ca74c0ddd..6719396c0 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -142,6 +142,13 @@
Starcoder2ForCausalLM,
Starcoder2Model,
)
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerCrossAttention,
+ T5LayerFF,
+ T5LayerNorm,
+ T5LayerSelfAttention,
+)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
@@ -309,6 +316,13 @@
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
+from QEfficient.transformers.models.t5.modeling_t5 import (
+ QEffT5Attention,
+ QEffT5LayerCrossAttention,
+ QEffT5LayerFF,
+ QEffT5LayerNorm,
+ QEffT5LayerSelfAttention,
+)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
@@ -617,6 +631,22 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
+class T5ModelTransform(ModuleMappingTransform):
+    """Module-mapping transform pairing Hugging Face T5 layer classes with their QEff replacements."""
+
+    # Hugging Face T5 module class -> QEfficient replacement class.
+    # The mapping is consumed by `ModuleMappingTransform.apply`, which is defined outside this class.
+    _module_mapping = {
+        T5LayerFF: QEffT5LayerFF,
+        T5LayerSelfAttention: QEffT5LayerSelfAttention,
+        T5LayerCrossAttention: QEffT5LayerCrossAttention,
+        T5Attention: QEffT5Attention,
+        T5LayerNorm: QEffT5LayerNorm,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        """Run the parent transform on `model` and return its `(model, transformed)` pair unchanged."""
+        updated_model, was_transformed = super().apply(model)
+        return updated_model, was_transformed
+
+
class PoolingTransform:
"""
Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py
new file mode 100644
index 000000000..9ba5869d7
--- /dev/null
+++ b/QEfficient/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,217 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerCrossAttention,
+ T5LayerFF,
+ T5LayerNorm,
+ T5LayerSelfAttention,
+)
+
+
+class QEffT5LayerNorm(T5LayerNorm):
+    """T5 layer norm variant: scale-only (no mean subtraction, no bias), i.e. RMS-style normalization."""
+
+    def forward(self, hidden_states):
+        """
+        Normalize `hidden_states` over the last dimension and multiply by the learned weight.
+
+        See https://arxiv.org/abs/1910.07467 (Root Mean Square Layer Normalization).
+        """
+        # Pre-scale by 1/sqrt(hidden_size) before squaring: the sum of the squared, pre-scaled values
+        # equals the mean of the squared raw values, so `mean_square` below plays the role of the
+        # variance term without a separate mean reduction.
+        hidden_size = hidden_states.shape[-1]
+        inv_sqrt_size = torch.rsqrt(torch.tensor(hidden_size, dtype=torch.float32))
+        prescaled = hidden_states * inv_sqrt_size
+        mean_square = prescaled.pow(2).sum(-1, keepdim=True)
+
+        normalized = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            normalized = normalized.to(self.weight.dtype)
+
+        return self.weight * normalized
+
+
+class QEffT5LayerFF(T5LayerFF):
+    """T5 feed-forward sub-layer: layer norm -> DenseReluDense -> dropout -> residual add."""
+
+    def forward(self, hidden_states):
+        """Return `hidden_states * 1.0` plus the dropped-out feed-forward output."""
+        ff_output = self.DenseReluDense(self.layer_norm(hidden_states))
+        # NOTE(review): the explicit `* 1.0` on the residual is kept exactly as written; presumably it is
+        # intentional for export — confirm before simplifying.
+        return hidden_states * 1.0 + self.dropout(ff_output)
+
+
+class QEffT5Attention(T5Attention):
+ # T5 attention with a modified position-bias slice when a cache object is supplied (see `forward`).
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+
+ Args:
+ hidden_states: Query-side input; its first two dimensions are read as (batch_size, seq_length).
+ mask: Optional additive mask; it is sliced to the key length and added to the position bias.
+ key_value_states: When given, keys/values are projected from it (cross-attention) instead of
+ from `hidden_states`.
+ position_bias: Optional precomputed bias; when None it is built here (zeros, or `compute_bias`).
+ past_key_value: Optional cache object exposing `is_updated`, `self_attention_cache` and
+ `cross_attention_cache`.
+ layer_head_mask: Optional multiplier applied to the attention weights.
+ query_length: Optional query length used when computing the bias; when None,
+ `cache_position[-1] + 1` is used, so `cache_position` must then be provided.
+ use_cache: Accepted for interface compatibility; not read in this method.
+ output_attentions: When True, the attention weights are appended to the returned tuple.
+ cache_position: Positions forwarded to the cache update (self-attention only) and to `compute_bias`.
+
+ Returns:
+ Tuple `(attn_output, past_key_value, position_bias)`, plus `attn_weights` when
+ `output_attentions` is True.
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = key_value_states is not None
+
+ # Project queries and split into heads: (batch_size, n_heads, seq_length, key_value_proj_dim)
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self.k(current_states)
+ value_states = self.v(current_states)
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ if position_bias is None:
+ key_length = key_states.shape[-2]
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
+ )
+ # The upstream implementation slices the last `seq_length` query rows here
+ # (`position_bias[:, :, -seq_length:, :]`); this version differs as described below.
+ if past_key_value is not None: # slice only when a cache object was passed
+ # Keep just the bias row for the final query position.
+ position_bias = position_bias[:, :, -1:, :]
+
+ if mask is not None:
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
+ position_bias = position_bias + causal_mask
+
+ if self.pruned_heads:
+ # NOTE: `mask` is rebound here to a per-head keep mask (0 for pruned heads); the caller's
+ # attention mask has already been folded into `position_bias` above.
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+
+ # (batch_size, n_heads, seq_length, key_length)
+ # Softmax is computed on a float copy of the scores and cast back to the scores' dtype.
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ # Merge heads back to (batch_size, seq_length, inner_dim) and apply the output projection.
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+ attn_output = self.o(attn_output)
+
+ # The unmasked `position_bias` (not `position_bias_masked`) is what gets returned.
+ outputs = (attn_output, past_key_value, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class QEffT5LayerSelfAttention(T5LayerSelfAttention):
+    """T5 self-attention sub-layer: layer norm -> SelfAttention -> dropout -> residual add."""
+
+    def __qeff_init__(self):
+        """Attach a `scaling_factor` attribute initialised to 1.0."""
+        # NOTE(review): `forward` below multiplies by the literal 1.0 and never reads this attribute;
+        # confirm where `scaling_factor` is consumed.
+        self.scaling_factor = 1.0
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        layer_head_mask=None,
+        past_key_value=None,
+        use_cache=False,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        """
+        Apply self-attention to the normalized input and add the result to the residual stream.
+
+        Returns:
+            Tuple whose first element is the updated hidden states, followed by every element after
+            the first of the tuple returned by `self.SelfAttention`.
+        """
+        attn_outputs = self.SelfAttention(
+            self.layer_norm(hidden_states),
+            mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        # Residual connection; the explicit `* 1.0` is kept exactly as written.
+        residual_sum = hidden_states * 1.0 + self.dropout(attn_outputs[0])
+        return (residual_sum,) + attn_outputs[1:]
+
+
+class QEffT5LayerCrossAttention(T5LayerCrossAttention):
+    """T5 cross-attention sub-layer: layer norm -> EncDecAttention -> dropout -> residual add."""
+
+    def forward(
+        self,
+        hidden_states,
+        key_value_states,
+        attention_mask=None,
+        position_bias=None,
+        layer_head_mask=None,
+        past_key_value=None,
+        use_cache=False,
+        query_length=None,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        """
+        Attend from the normalized `hidden_states` to `key_value_states` and add the result to the
+        residual stream.
+
+        Returns:
+            Tuple whose first element is the updated hidden states, followed by every element after
+            the first of the tuple returned by `self.EncDecAttention`.
+        """
+        attn_outputs = self.EncDecAttention(
+            self.layer_norm(hidden_states),
+            mask=attention_mask,
+            key_value_states=key_value_states,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            query_length=query_length,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        # Residual connection; the explicit `* 1.0` is kept exactly as written.
+        residual_sum = hidden_states * 1.0 + self.dropout(attn_outputs[0])
+        return (residual_sum,) + attn_outputs[1:]
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 50f36ea32..e458fe5b2 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -68,7 +68,7 @@ def get_models_dir():
ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80
ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
@@ -103,6 +103,35 @@ def get_models_dir():
GEMMA3_MAX_POSITION_EMBEDDINGS = 32768
+# WO_SFS: weight output scaling factors (used to normalize T5 encoder output weights before export)
+WO_SFS = [ # 24 entries; NOTE(review): presumably one per T5 encoder block, in layer order - confirm against the consumer
+ 61,
+ 203,
+ 398,
+ 615,
+ 845,
+ 1190,
+ 1402,
+ 2242,
+ 1875,
+ 2393,
+ 3845,
+ 3213,
+ 3922,
+ 4429,
+ 5020,
+ 5623,
+ 6439,
+ 6206,
+ 5165,
+ 4593,
+ 2802,
+ 2618,
+ 1891,
+ 1419,
+]
+
+
class Constants:
# Export Constants.
SEQ_LEN = 32
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..f3ad34a7a
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py b/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py
new file mode 100644
index 000000000..ca326a601
--- /dev/null
+++ b/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py
@@ -0,0 +1,17 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from QEfficient import QEFFStableDiffusion3Pipeline
+
+pipeline = QEFFStableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large-turbo", use_onnx_function=False
+)
+pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
+
+# NOTE: guidance_scale <= 1 is not supported; pass a value greater than 1 (2.0 is used below). Only the first returned image is kept.
+image = pipeline("A girl laughing", num_inference_steps=28, guidance_scale=2.0).images[0]
+image.save("girl_laughing_turbo.png") # relative path: the file is written to the current working directory
diff --git a/pyproject.toml b/pyproject.toml
index 479736c22..bf439548a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,18 +39,18 @@ dependencies = [
"fire",
"py7zr",
"torchmetrics==1.7.0",
- "torch==2.4.1; platform_machine=='aarch64'",
+ "torch==2.7.1; platform_machine=='aarch64'",
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
"torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
- "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
- "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
+ "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
+ "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
]
[project.optional-dependencies]
test = ["pytest","pytest-mock"]
docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"]
quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"]
-
+diffusers = ["diffusers==0.31.0"]
[build-system]
requires = ["setuptools>=62.0.0"]
build-backend = "setuptools.build_meta"
@@ -71,4 +71,4 @@ target-version = "py310"
[tool.pytest.ini_options]
addopts = "-W ignore -s -v"
junit_logging = "all"
-doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
+doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
\ No newline at end of file