@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanTransformer3DModel`]):
+        transformer ([`WanVACETransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
         scheduler ([`UniPCMultistepScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]

     def __init__(
         self,
@@ -170,6 +180,8 @@ def __init__(
         transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: WanVACETransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()

@@ -178,9 +190,10 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             transformer=transformer,
+            transformer_2=transformer_2,
             scheduler=scheduler,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
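For intuition, here is a minimal sketch (not part of the diff) of how the registered `boundary_ratio` turns into a timestep boundary that splits the schedule between the two transformers, as described in the docstring above. The 1000 training timesteps and the example ratio are assumptions for illustration, not values taken from this PR.

```python
# Illustrative sketch only: how `boundary_ratio` splits the denoising schedule.
num_train_timesteps = 1000       # assumed scheduler config value
boundary_ratio = 0.875           # hypothetical value passed to register_to_config
boundary_timestep = boundary_ratio * num_train_timesteps  # -> 875.0

# A flow-matching schedule runs from high noise (t near 1000) down to 0.
timesteps = list(range(999, -1, -200))  # e.g. [999, 799, 599, 399, 199]

for t in timesteps:
    if t >= boundary_timestep:
        stage = "transformer (high-noise stage)"
    else:
        stage = "transformer_2 (low-noise stage)"
    print(f"t={t:4d} -> {stage}")
```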
@@ -321,6 +334,7 @@ def check_inputs(
         video=None,
         mask=None,
         reference_images=None,
+        guidance_scale_2=None,
     ):
         base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ def check_inputs(
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -667,6 +683,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
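As a usage sketch of the new argument, the call below shows how `guidance_scale_2` would be passed alongside `guidance_scale`; the checkpoint id, prompt, and sizes are placeholders, and the loading call is just the standard `from_pretrained` pattern rather than anything introduced by this PR.

```python
# Hypothetical usage sketch. Only the `guidance_scale_2` / `boundary_ratio`
# semantics come from this PR's docstrings; the model id is a placeholder.
import torch
from diffusers import WanVACEPipeline

pipe = WanVACEPipeline.from_pretrained(
    "some-org/wan-vace-checkpoint",  # placeholder model id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

video = pipe(
    prompt="a cat surfing a wave at sunset",
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,    # used by `transformer` (high-noise stage)
    guidance_scale_2=3.0,  # used by `transformer_2` (low-noise stage); falls back to
                           # `guidance_scale` when None and `boundary_ratio` is set
).frames[0]
```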
@@ -793,6 +814,7 @@ def __call__(
             video,
             mask,
             reference_images,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -896,36 +922,53 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    control_hidden_states=conditioning_latents,
-                    control_hidden_states_scale=conditioning_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         control_hidden_states=conditioning_latents,
                         control_hidden_states_scale=conditioning_scale,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            control_hidden_states=conditioning_latents,
+                            control_hidden_states_scale=conditioning_scale,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
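To make the per-step dispatch easier to follow in isolation, here is a condensed, self-contained sketch of the logic in the hunk above. The toy callables and tensor shapes are placeholders rather than the real transformer API, and `cache_context` plus the control inputs are omitted.

```python
# Condensed sketch of the two-stage dispatch + classifier-free guidance above.
# `high_noise_model` / `low_noise_model` stand in for transformer / transformer_2.
import torch

def denoise_step(t, latents, boundary_timestep,
                 high_noise_model, low_noise_model,
                 guidance_scale, guidance_scale_2,
                 do_cfg=True):
    # Pick the model and guidance scale for this timestep, as in the loop above.
    if boundary_timestep is None or t >= boundary_timestep:
        model, scale = high_noise_model, guidance_scale    # Wan 2.1 or high-noise stage
    else:
        model, scale = low_noise_model, guidance_scale_2   # low-noise stage (Wan 2.2)

    noise_pred = model(latents, t, cond=True)
    if do_cfg:
        noise_uncond = model(latents, t, cond=False)
        noise_pred = noise_uncond + scale * (noise_pred - noise_uncond)
    return noise_pred

# Toy stand-ins so the sketch runs end to end.
toy = lambda latents, t, cond: latents * (1.0 if cond else 0.5)
out = denoise_step(t=900.0, latents=torch.randn(1, 4, 8, 8),
                   boundary_timestep=875.0,
                   high_noise_model=toy, low_noise_model=toy,
                   guidance_scale=5.0, guidance_scale_2=3.0)
print(out.shape)
```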