From 243c50a9aa653eef3bbb173239658a77b956384d Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Tue, 27 May 2025 09:44:32 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method=20`Aut?= =?UTF-8?q?oencoderKLWan.clear=5Fcache`=20by=20886%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Key optimizations:** - Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization**, store in `self._cached_conv_counts`. This removes unnecessary repeated tree traversals at every `clear_cache` call, which was the main bottleneck (from profiling). - The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency. All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines. **Function signatures and outputs remain unchanged.** --- .../models/autoencoders/autoencoder_kl_wan.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index fafb1fe867e3..7a7592516a71 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -730,19 +730,19 @@ def __init__( base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout ) + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + 'decoder': self._count_conv3d_fast(self.decoder), + 'encoder': self._count_conv3d_fast(self.encoder) + } + def clear_cache(self): - def _count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.decoder) + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts['decoder'] self._conv_idx = [0] self._feat_map = [None] * self._conv_num # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_num = self._cached_conv_counts['encoder'] self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num @@ -853,3 +853,8 @@ def forward( z = posterior.mode() dec = self.decode(z, return_dict=return_dict) return dec + + @staticmethod + def _count_conv3d_fast(model): + # Fast version: relies on model.modules() being a generator; avoids Python loop overhead by using sum + generator expression + return sum(isinstance(m, WanCausalConv3d) for m in model.modules())