diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index fe00d8c078ff..c27c12c6f648 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -748,6 +748,12 @@ def __init__(
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
+
+        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+        self._cached_conv_counts = {
+            'decoder': self._count_conv3d_fast(self.decoder),
+            'encoder': self._count_conv3d_fast(self.encoder),
+        }
 
     def enable_tiling(
         self,
@@ -801,18 +807,12 @@ def disable_slicing(self) -> None:
         self.use_slicing = False
 
     def clear_cache(self):
-        def _count_conv3d(model):
-            count = 0
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts['decoder']
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num = _count_conv3d(self.encoder)
+        self._enc_conv_num = self._cached_conv_counts['encoder']
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
@@ -1083,3 +1083,8 @@ def forward(
             z = posterior.mode()
         dec = self.decode(z, return_dict=return_dict)
         return dec
+
+    @staticmethod
+    def _count_conv3d_fast(model):
+        """Return how many of the modules yielded by `model.modules()` are `WanCausalConv3d` instances."""
+        return sum(isinstance(m, WanCausalConv3d) for m in model.modules())