diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py
new file mode 100644
index 000000000..868214455
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -0,0 +1,200 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+from diffusers.models.autoencoders.autoencoder_kl_wan import (
+    WanDecoder3d,
+    WanEncoder3d,
+    WanResample,
+    WanResidualBlock,
+    WanUpsample,
+)
+
+CACHE_T = 2
+
+modes = []
+
+# Use max(0, x.shape[2] - CACHE_T) as the slice start instead of -CACHE_T: x.shape[2] is either 1 or 4
+# and CACHE_T = 2, so the start index never goes negative.
+
+
+class QEffWanResample(WanResample):
+    def __qeff_init__(self):
+        # Changed upsampling mode from "nearest-exact" to "nearest" for ONNX compatibility.
+        # Since the scale factor is an integer, both modes behave the same.
+        if self.mode in ("upsample2d", "upsample3d"):
+            self.resample[0] = WanUpsample(scale_factor=(2.0, 2.0), mode="nearest")
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        b, c, t, h, w = x.size()
+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        # cache last frame of last two chunk
+                        cache_x = torch.cat(
+                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                        )
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
+                    x = x.reshape(b, 2, c, t, h, w)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                    x = x.reshape(b, c, t * 2, h, w)
+        t = x.shape[2]
+        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+        modes.append(self.mode)
+        x = self.resample(x)
+        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = x.clone()
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -1:, :, :].clone()
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+        return x
+
+
+class QEffWanResidualBlock(WanResidualBlock):
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        # Apply shortcut connection
+        h = self.conv_shortcut(x)
+
+        # First normalization and activation
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        # Second normalization and activation
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+
+        # Dropout
+        x = self.dropout(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv2(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv2(x)
+
+        # Add residual connection
+        return x + h
+
+
+class QEffWanEncoder3d(WanEncoder3d):
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
+
+        ## downsamples
+        for layer in self.down_blocks:
+            if feat_cache is not None:
+                x = layer(x, feat_cache, feat_idx)
+            else:
+                x = layer(x)
+
+        ## middle
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        ## head
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
+        return x
+
+
+class QEffWanDecoder3d(WanDecoder3d):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        ## conv1
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
+
+        ## middle
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        ## upsamples
+        for up_block in self.up_blocks:
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+
+        ## head
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
+        return x
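Note (not part of the patch): the cached convolutions above all slice with max(0, x.shape[2] - CACHE_T) rather than a negative start index. A minimal standalone sketch, assuming only torch, that checks the slice keeps the last min(T, CACHE_T) frames for the frame counts mentioned in the comment:

import torch

CACHE_T = 2

for t in (1, 4):
    x = torch.randn(1, 8, t, 16, 16)  # (batch, channels, frames, height, width)
    cache_x = x[:, :, max(0, t - CACHE_T) :, :, :]
    # For t=1 the start index clamps to 0 (one frame kept); for t=4 the last two frames are kept.
    assert cache_x.shape[2] == min(t, CACHE_T)
    print(t, tuple(cache_x.shape))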
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
index 4fb5c3f12..fa637b2e9 100644
--- a/QEfficient/diffusers/models/pytorch_transforms.py
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -5,6 +5,12 @@
 #
 # -----------------------------------------------------------------------------
 
+from diffusers.models.autoencoders.autoencoder_kl_wan import (
+    WanDecoder3d,
+    WanEncoder3d,
+    WanResample,
+    WanResidualBlock,
+)
 from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
 from diffusers.models.transformers.transformer_flux import (
     FluxAttention,
@@ -18,6 +24,12 @@ from QEfficient.base.pytorch_transforms import ModuleMappingTransform
 from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.autoencoders.autoencoder_kl_wan import (
+    QEffWanDecoder3d,
+    QEffWanEncoder3d,
+    QEffWanResample,
+    QEffWanResidualBlock,
+)
 from QEfficient.diffusers.models.normalization import (
     QEffAdaLayerNormContinuous,
     QEffAdaLayerNormZero,
@@ -54,6 +66,10 @@ class AttentionTransform(ModuleMappingTransform):
             WanAttnProcessor: QEffWanAttnProcessor,
             WanAttention: QEffWanAttention,
             WanTransformer3DModel: QEffWanTransformer3DModel,
+            WanDecoder3d: QEffWanDecoder3d,
+            WanEncoder3d: QEffWanEncoder3d,
+            WanResidualBlock: QEffWanResidualBlock,
+            WanResample: QEffWanResample,
         }
 
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
index 19e7701d4..8a9930556 100644
--- a/QEfficient/diffusers/pipelines/pipeline_module.py
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -229,7 +229,7 @@ class QEffVAE(QEFFBaseModel):
         _onnx_transforms (List): ONNX transformations applied after export
     """
 
-    _pytorch_transforms = [CustomOpsTransform]
+    _pytorch_transforms = [CustomOpsTransform, AttentionTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
     @property
@@ -287,6 +287,37 @@ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tu
 
         return example_inputs, dynamic_axes, output_names
 
+    def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
+        """
+        Generate ONNX export configuration for the video VAE decoder.
+
+        Takes no arguments: the example latent has a fixed shape of
+        (batch, 16, 21, 12, 16); batch, frame, and spatial dimensions
+        are exported as dynamic axes.
+
+        Returns:
+            Tuple containing:
+                - example_inputs (Dict): Sample inputs for ONNX export
+                - dynamic_axes (Dict): Specification of dynamic dimensions
+                - output_names (List[str]): Names of model outputs
+        """
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+        # VAE decoder takes latent representation as input
+        example_inputs = {
+            "latent_sample": torch.randn(bs, 16, 21, 12, 16),
+            "return_dict": False,
+        }
+
+        output_names = ["sample"]
+
+        # All dimensions except the channel dimension can be dynamic
+        dynamic_axes = {
+            "latent_sample": {0: "batch_size", 2: "num_frames", 3: "latent_height", 4: "latent_width"},
+        }
+
+        return example_inputs, dynamic_axes, output_names
+
     def export(
         self,
         inputs: Dict,
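Note (not part of the patch): a minimal, self-contained sketch of how the example inputs and dynamic axes returned by get_video_onnx_params map onto a plain torch.onnx.export call. ToyDecoder is a hypothetical stand-in for the wrapped VAE decoder; the real pipeline routes export through QEffVAE.export rather than calling torch.onnx.export directly.

import torch


class ToyDecoder(torch.nn.Module):
    # Hypothetical stand-in: mimics vae.decode(latent_sample) returning a tuple.
    def forward(self, latent_sample, return_dict=False):
        return (latent_sample * 2.0,)


example_inputs = {
    "latent_sample": torch.randn(1, 16, 21, 12, 16),  # (batch, z_dim, frames, latent_h, latent_w)
    "return_dict": False,
}
dynamic_axes = {
    "latent_sample": {0: "batch_size", 2: "num_frames", 3: "latent_height", 4: "latent_width"},
}

torch.onnx.export(
    ToyDecoder(),
    (example_inputs["latent_sample"], example_inputs["return_dict"]),
    "toy_vae_decoder.onnx",
    input_names=["latent_sample"],
    output_names=["sample"],
    dynamic_axes=dynamic_axes,
)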
diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
index edae438ae..96ceb3cd5 100644
--- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
+++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
@@ -11,7 +11,7 @@
 for high-performance text-to-video generation on Qualcomm AI hardware.
 The pipeline supports WAN 2.2 architectures with unified transformer.
 
-TODO: 1. Update Vae, umt5 to Qaic; present running on cpu
+TODO: 1. Update umt5 to QAIC; currently running on CPU
 """
 
 import os
@@ -22,7 +22,7 @@ import torch
 from diffusers import WanPipeline
 
-from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer
+from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer
 from QEfficient.diffusers.pipelines.pipeline_utils import (
     ONNX_SUBFUNCTION_MODULE,
     ModulePerf,
@@ -106,16 +106,22 @@ def __init__(self, model, **kwargs):
         self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper)
 
         # VAE decoder for latent-to-video conversion
-        self.vae_decode = model.vae
-
+        self.vae_decoder = QEffVAE(model.vae, "decoder")
         # Store all modules in a dictionary for easy iteration during export/compile
-        # TODO: add text encoder, vae decoder on QAIC
-        self.modules = {"transformer": self.transformer}
+        # TODO: add text encoder on QAIC
+        self.modules = {"transformer": self.transformer, "vae_decoder": self.vae_decoder}
 
         # Copy tokenizers and scheduler from the original model
         self.tokenizer = model.tokenizer
         self.text_encoder.tokenizer = model.tokenizer
         self.scheduler = model.scheduler
+
+        self.vae_decoder.model.forward = lambda latent_sample, return_dict: self.vae_decoder.model.decode(
+            latent_sample, return_dict
+        )
+
+        self.vae_decoder.get_onnx_params = self.vae_decoder.get_video_onnx_params
+        self.vae_decoder.model.config["_use_default_values"].sort()
 
         # Extract patch dimensions from transformer configuration
         _, self.patch_height, self.patch_width = self.transformer.model.config.patch_size
@@ -336,7 +342,14 @@ def compile(
                     "latent_width": latent_width,  # Latent space width
                     "num_frames": latent_frames,  # Latent frames
                 },
-            ]
+            ],
+            "vae_decoder": [
+                {
+                    "num_frames": latent_frames,
+                    "latent_height": latent_height,
+                    "latent_width": latent_width,
+                }
+            ],
         }
 
         # Use generic utility functions for compilation
@@ -722,31 +735,45 @@ def __call__(
 
         # Step 9: Decode latents to video
        if not output_type == "latent":
             # Prepare latents for VAE decoding
-            latents = latents.to(self.vae_decode.dtype)
+            latents = latents.to(self.vae_decoder.model.dtype)
 
             # Apply VAE normalization (denormalization)
             latents_mean = (
-                torch.tensor(self.vae_decode.config.latents_mean)
-                .view(1, self.vae_decode.config.z_dim, 1, 1, 1)
+                torch.tensor(self.vae_decoder.model.config.latents_mean)
+                .view(1, self.vae_decoder.model.config.z_dim, 1, 1, 1)
                 .to(latents.device, latents.dtype)
             )
-            latents_std = 1.0 / torch.tensor(self.vae_decode.config.latents_std).view(
-                1, self.vae_decode.config.z_dim, 1, 1, 1
+            latents_std = 1.0 / torch.tensor(self.vae_decoder.model.config.latents_std).view(
+                1, self.vae_decoder.model.config.z_dim, 1, 1, 1
             ).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
 
-            # TODO: Enable VAE on QAIC
-            # VAE Decode latents to video using CPU (temporary)
-            video = self.model.vae.decode(latents, return_dict=False)[0]  # CPU fallback
+            # Initialize VAE decoder inference session
+            if self.vae_decoder.qpc_session is None:
+                self.vae_decoder.qpc_session = QAICInferenceSession(
+                    str(self.vae_decoder.qpc_path), device_ids=self.vae_decoder.device_ids
+                )
+
+            # Allocate output buffer for VAE decoder
+            output_buffer = {"sample": np.random.rand(batch_size, 3, num_frames, height, width).astype(np.int32)}
+
+            inputs = {"latent_sample": latents.numpy()}
+
+            start_decode_time = time.perf_counter()
+            video = self.vae_decoder.qpc_session.run(inputs)  # decode on the QAIC device
+            end_decode_time = time.perf_counter()
+            vae_decoder_perf = end_decode_time - start_decode_time
 
             # Post-process video for output
-            video = self.model.video_processor.postprocess_video(video.detach())
+            video_tensor = torch.from_numpy(video["sample"])
+            video = self.model.video_processor.postprocess_video(video_tensor)
         else:
             video = latents
 
         # Step 10: Collect performance metrics
         perf_data = {
             "transformer": transformer_perf,  # Unified transformer (QAIC)
+            "vae_decoder": vae_decoder_perf,
         }
 
         # Build performance metrics for output
diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json
index 7e752ba14..c32054db1 100644
--- a/examples/diffusers/wan/wan_config.json
+++ b/examples/diffusers/wan/wan_config.json
@@ -32,6 +32,36 @@
         "execute": {
             "device_ids": null
         }
-    }
+    },
+    "vae_decoder":
+    {
+        "specializations": [
+            {
+                "batch_size": 1,
+                "num_channels": 16,
+                "num_frames": 21,
+                "latent_height": 60,
+                "latent_width": 104
+            }
+        ],
+        "compilation":
+        {
+            "onnx_path": null,
+            "compile_dir": null,
+            "mdp_ts_num_devices": 8,
+            "mxfp6_matmul": false,
+            "convert_to_fp16": true,
+            "aic_num_cores": 16,
+            "aic-enable-depth-first": true,
+            "compile_only": true,
+            "mos": 1,
+            "mdts_mos": 1
+        },
+        "execute":
+        {
+            "device_ids": null
+        }
+    }
 }
\ No newline at end of file
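Note (not part of the patch): the vae_decoder specialization above is expressed in latent space. Assuming the usual Wan VAE compression factors (8x spatial, 4x temporal; an assumption here, not something the config states), the corresponding pixel-space video can be derived as follows:

# Sketch: map the latent-space specialization to pixel space.
# The 8x/4x factors are assumed Wan VAE compression ratios, not read from the config.
SPATIAL_FACTOR = 8
TEMPORAL_FACTOR = 4

spec = {"num_frames": 21, "latent_height": 60, "latent_width": 104}

video_frames = (spec["num_frames"] - 1) * TEMPORAL_FACTOR + 1  # 81 frames
video_height = spec["latent_height"] * SPATIAL_FACTOR          # 480 px
video_width = spec["latent_width"] * SPATIAL_FACTOR            # 832 px

print(video_frames, video_height, video_width)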