
Commit b500140

Add Wan2.2 VACE - Fun (#12324)
* support Wan2.2-VACE-Fun-A14B
* support Wan2.2-VACE-Fun-A14B
* support Wan2.2-VACE-Fun-A14B
* Apply style fixes
* test

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f5c113e commit b500140

File tree: 3 files changed (+94, -17 lines)


scripts/convert_wan_to_diffusers.py

Lines changed: 34 additions & 1 deletion
```diff
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-VACE-Fun-14B":
+        config = {
+            "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan2.2-I2V-14B-720p":
         config = {
             "model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -975,7 +998,17 @@ def get_args():
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
-    elif "VACE" in args.model_type:
+    elif "Wan2.2-VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            transformer_2=transformer_2,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            boundary_ratio=0.875,
+        )
+    elif "Wan-VACE" in args.model_type:
         pipe = WanVACEPipeline(
             transformer=transformer,
             text_encoder=text_encoder,
```
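For orientation, here is a rough sketch of how a checkpoint converted with the new `Wan2.2-VACE-Fun-14B` branch could be used. The repository path, prompt, and `guidance_scale_2` value are placeholder assumptions, and the VACE conditioning inputs (`video`, `mask`, `reference_images`) are omitted for brevity; only the pipeline class and the new arguments come from this PR.

```python
# Hypothetical usage of a converted Wan2.2-VACE-Fun-A14B checkpoint.
import torch
from diffusers import WanVACEPipeline

pipe = WanVACEPipeline.from_pretrained(
    "path/to/Wan2.2-VACE-Fun-A14B-Diffusers",  # placeholder: wherever convert_wan_to_diffusers.py saved the pipeline
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

output = pipe(
    prompt="a corgi surfing a small wave at sunset",  # placeholder prompt
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,    # high-noise stage (transformer)
    guidance_scale_2=5.0,  # low-noise stage (transformer_2); falls back to guidance_scale when omitted
)
video = output.frames[0]
```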

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 59 additions & 16 deletions
```diff
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanTransformer3DModel`]):
+        transformer ([`WanVACETransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
         scheduler ([`UniPCMultistepScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]

     def __init__(
         self,
```
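The docstring above describes the switching rule; as a quick worked example (a standalone sketch, assuming the scheduler's default `num_train_timesteps=1000` and the converter's `boundary_ratio=0.875`), the boundary falls at timestep 875, so noisier steps go to `transformer` and the rest to `transformer_2`:

```python
# Sketch of the two-stage switching rule described in the docstring.
# Assumes num_train_timesteps=1000 (scheduler default) and boundary_ratio=0.875 (converter value).
boundary_ratio = 0.875
num_train_timesteps = 1000
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

def pick_stage(t: float) -> str:
    """Return which transformer would denoise timestep t."""
    return "transformer (high noise)" if t >= boundary_timestep else "transformer_2 (low noise)"

print(pick_stage(990))  # transformer (high noise)
print(pick_stage(500))  # transformer_2 (low noise)
```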
```diff
@@ -170,6 +180,8 @@ def __init__(
         transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: WanVACETransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()

@@ -178,9 +190,10 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             transformer=transformer,
+            transformer_2=transformer_2,
             scheduler=scheduler,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ def check_inputs(
         video=None,
         mask=None,
         reference_images=None,
+        guidance_scale_2=None,
     ):
         base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
```
```diff
@@ -667,6 +683,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
```
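To make the documented defaulting and validation behavior concrete, a small sketch follows; `pipe` is assumed to be a `WanVACEPipeline` constructed with a `boundary_ratio`, `pipe_wan21` one constructed without it, and the prompts are placeholders.

```python
# Sketch of the guidance_scale_2 semantics described above.

# boundary_ratio is not None: guidance_scale_2 falls back to guidance_scale when omitted.
out = pipe(prompt="...", guidance_scale=5.0)  # effectively also uses guidance_scale_2=5.0

# boundary_ratio is None: passing guidance_scale_2 is rejected by check_inputs.
try:
    pipe_wan21(prompt="...", guidance_scale=5.0, guidance_scale_2=4.0)
except ValueError as err:
    print(err)  # "`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None."
```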
```diff
@@ -793,6 +814,7 @@
             video,
             mask,
             reference_images,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -896,36 +922,53 @@
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    control_hidden_states=conditioning_latents,
-                    control_hidden_states_scale=conditioning_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         control_hidden_states=conditioning_latents,
                         control_hidden_states_scale=conditioning_scale,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            control_hidden_states=conditioning_latents,
+                            control_hidden_states_scale=conditioning_scale,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
```
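As a rough illustration of how the loop splits work between the two experts, here is a standalone sketch (not pipeline code; the timestep values are made up, since the real schedule comes from the scheduler). Each expert forward pass in the loop is also wrapped in `cache_context("cond")` / `cache_context("uncond")`, presumably so caching helpers can tell the conditional and unconditional branches apart.

```python
# Standalone sketch: partition a made-up descending timestep schedule between the two
# experts the way the loop above does, using the converter's boundary_ratio of 0.875.
boundary_timestep = 0.875 * 1000  # boundary_ratio * num_train_timesteps (assumed 1000)

timesteps = list(range(980, 0, -20))  # stand-in for the scheduler's schedule
high_noise_steps = [t for t in timesteps if t >= boundary_timestep]  # handled by transformer
low_noise_steps = [t for t in timesteps if t < boundary_timestep]    # handled by transformer_2

print(len(high_noise_steps), len(low_noise_steps))  # most steps land in the low-noise expert here
```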

tests/pipelines/wan/test_wan_vace.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -87,6 +87,7 @@ def get_dummy_components(self):
             "scheduler": scheduler,
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
+            "transformer_2": None,
         }
         return components
```