200 changes: 200 additions & 0 deletions QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -0,0 +1,200 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
from diffusers.models.autoencoders.autoencoder_kl_wan import (
WanDecoder3d,
WanEncoder3d,
WanResample,
WanResidualBlock,
WanUpsample,
)

CACHE_T = 2  # number of trailing frames kept per layer in the rolling feature cache
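# feat_cache protocol (shared by the classes below, following the upstream
# diffusers WAN VAE): feat_cache[idx] holds the trailing frames produced at
# layer idx during the previous chunk, and feat_idx[0] is a running pointer
# that assigns each causal conv its own cache slot.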

modes = []  # records each resample mode seen during forward passes (debug/trace aid)

# The slice start uses max(0, x.shape[2] - CACHE_T) instead of a negative index
# such as -CACHE_T: x.shape[2] is either 1 or 4 while CACHE_T = 2, and the
# max() clamp guarantees the start index never goes negative.
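# Worked example of the clamp for the two shapes that occur here (CACHE_T = 2):
#   x.shape[2] == 1 -> start = max(0, 1 - 2) = 0 -> slice keeps the single frame
#   x.shape[2] == 4 -> start = max(0, 4 - 2) = 2 -> slice keeps the last 2 frames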


class QEffWanResample(WanResample):
def __qeff_init__(self):
# Changed upsampling mode from "nearest-exact" to "nearest" for ONNX compatibility.
# Since the scale factor is an integer, both modes behave the same.
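# e.g. at an integer scale factor of 2, both modes map output index i to input
# index i // 2, so the swap does not change results.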
if self.mode in ("upsample2d", "upsample3d"):
self.resample[0] = WanUpsample(scale_factor=(2.0, 2.0), mode="nearest")

def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# prepend the last cached frame so the cache spans the previous two chunks
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1

# time_conv doubled the channel dim; split it into the two temporal phases and
# interleave them along the time axis (t -> 2t)
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
modes.append(self.mode)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x


class QEffWanResidualBlock(WanResidualBlock):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)

# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)

if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)

# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)

# Dropout
x = self.dropout(x)

if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)

# Add residual connection
return x + h


class QEffWanEncoder3d(WanEncoder3d):
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache spans the previous two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)

## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)

## middle
x = self.mid_block(x, feat_cache, feat_idx)

## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache spans the previous two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x


class QEffWanDecoder3d(WanDecoder3d):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache spans the previous two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)

## middle
x = self.mid_block(x, feat_cache, feat_idx)

## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)

## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache spans the previous two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
16 changes: 16 additions & 0 deletions QEfficient/diffusers/models/pytorch_transforms.py
@@ -5,6 +5,12 @@
#
# -----------------------------------------------------------------------------

from diffusers.models.autoencoders.autoencoder_kl_wan import (
WanDecoder3d,
WanEncoder3d,
WanResample,
WanResidualBlock,
)
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from diffusers.models.transformers.transformer_flux import (
FluxAttention,
@@ -18,6 +24,12 @@

from QEfficient.base.pytorch_transforms import ModuleMappingTransform
from QEfficient.customop.rms_norm import CustomRMSNormAIC
from QEfficient.diffusers.models.autoencoders.autoencoder_kl_wan import (
QEffWanDecoder3d,
QEffWanEncoder3d,
QEffWanResample,
QEffWanResidualBlock,
)
from QEfficient.diffusers.models.normalization import (
QEffAdaLayerNormContinuous,
QEffAdaLayerNormZero,
@@ -54,6 +66,10 @@ class AttentionTransform(ModuleMappingTransform):
WanAttnProcessor: QEffWanAttnProcessor,
WanAttention: QEffWanAttention,
WanTransformer3DModel: QEffWanTransformer3DModel,
WanDecoder3d: QEffWanDecoder3d,
WanEncoder3d: QEffWanEncoder3d,
WanResidualBlock: QEffWanResidualBlock,
WanResample: QEffWanResample,
}


33 changes: 32 additions & 1 deletion QEfficient/diffusers/pipelines/pipeline_module.py
@@ -229,7 +229,7 @@ class QEffVAE(QEFFBaseModel):
_onnx_transforms (List): ONNX transformations applied after export
"""

_pytorch_transforms = [CustomOpsTransform]
_pytorch_transforms = [CustomOpsTransform, AttentionTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

@property
@@ -287,6 +287,37 @@ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tu

return example_inputs, dynamic_axes, output_names

def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
"""
Generate ONNX export configuration for the VAE decoder.

Args:
latent_height (int): Height of latent representation (default: 32)
latent_width (int): Width of latent representation (default: 32)

Returns:
Tuple containing:
- example_inputs (Dict): Sample inputs for ONNX export
- dynamic_axes (Dict): Specification of dynamic dimensions
- output_names (List[str]): Names of model outputs
"""
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE

# VAE decoder takes a latent representation as input; the sample shape is
# (batch, latent_channels=16, num_frames=21, latent_height=12, latent_width=16)
example_inputs = {
"latent_sample": torch.randn(bs, 16, 21, 12, 16),
"return_dict": False,
}

output_names = ["sample"]

# All dimensions except channels can be dynamic
dynamic_axes = {
"latent_sample": {0: "batch_size", 2: "num_frames", 3: "latent_height", 4: "latent_width"},
}

return example_inputs, dynamic_axes, output_names
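
# A minimal usage sketch (hypothetical wiring; the positional argument order
# for export() is an assumption based on its signature below):
#   example_inputs, dynamic_axes, output_names = qeff_vae.get_video_onnx_params()
#   qeff_vae.export(example_inputs, output_names, dynamic_axes)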

def export(
self,
inputs: Dict,