Skip to content

Commit 81ccd1c

Browse files
committed
support Wan2.2 LoRA finetune
1 parent eb6a005 commit 81ccd1c

File tree

16 files changed

+800
-14
lines changed

16 files changed

+800
-14
lines changed

examples/mmada/README.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,6 @@ The following experiments are tested on Ascend Atlas 800T A2 machines with minds
190190
|:-:|:-:|:-:|:-:|:-:|:-:|
191191
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.29 |
192192

193-
The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.6.0** under **pynative** mode:
194-
195-
| model | # card(s) | batch size | parallelism | task | per batch time (seconds) |
196-
|:-:|:-:|:-:|:-:|:-:|:-:|
197-
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.30 |
198-
199-
200193

201194
## 🤝 Acknowledgments
202195

examples/wan2_2/finetune.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Wan2.2 LoRA Finetune
2+
3+
We provide an example of how to finetune Wan2.2 model (5B) Text-to-Video task using LoRA (Low-Rank Adaptation) technique.
4+
5+
## Prerequisites
6+
7+
Before running the finetuning script, ensure you have the following prerequisites:
8+
9+
#### Requirements
10+
| mindspore | ascend driver | firmware | cann toolkit/kernel|
11+
| :-------: | :-----------: | :---------: | :----------------: |
12+
| 2.7.0 | 25.2.0 | 7.7.0.6.236 | 8.2.RC1 |
13+
14+
#### Dataset Preparation
15+
16+
Prepare your dataset following the format shown in `https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset`
17+
18+
## Start Finetuning
19+
20+
We use the script `scripts/train_lora_2p.sh` to start the finetuning process. You can modify the parameters in the script as needed.
21+
22+
```bash
23+
bash scripts/train_lora_2p.sh
24+
```
25+
26+
The lora checkpoint and the visualization results will be saved in the `output` directory by default.
27+
28+
## Fintune Result
29+
30+
After finetuning on 2 Ascend devices, we obtained the following results:
31+
32+
Training loss curve:
33+
34+
35+
36+
## Performance
37+
38+
|model | precision | task | resolution | card | batch size | recompute | s/step |
39+
|--------------|-----------|---------------|------------|------| ---------- | --------- |--------|
40+
|Wan2.2-TI2V-5B| bf16 | Text-To-Video | 704x1280 | 2 | 2 | ON | 27 |

examples/wan2_2/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def _parse_args():
102102
"--frame_num", type=int, default=None, help="How many frames of video are generated. The number should be 4n+1"
103103
)
104104
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.")
105+
parser.add_argument("--lora_dir", type=str, default=None, help="The path to the LoRA checkpoint directory.")
105106
parser.add_argument(
106107
"--offload_model",
107108
type=str2bool,
@@ -302,6 +303,7 @@ def generate(args):
302303
use_sp=(args.ulysses_size > 1),
303304
t5_cpu=args.t5_cpu,
304305
convert_model_dtype=args.convert_model_dtype,
306+
lora_dir=args.lora_dir,
305307
)
306308

307309
logging.info("Generating video ...")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
msrun --worker_num=2 --local_worker_num=2 train.py --task ti2v-5B --size 1280*704 --t5_zero3 \
2+
--ckpt_dir ./model/Wan2.2-TI2V-5B \
3+
--data_root ./data/Disney-VideoGeneration-Dataset \
4+
--caption_column prompt.txt \
5+
--video_column videos.txt

examples/wan2_2/train.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2+
import argparse
3+
import logging
4+
import os
5+
import random
6+
import sys
7+
8+
from PIL import Image
9+
10+
import mindspore as ms
11+
import mindspore.mint.distributed as dist
12+
13+
__dir__ = os.path.dirname(os.path.abspath(__file__))
14+
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../"))
15+
sys.path.insert(0, mindone_lib_path)
16+
17+
import wan
18+
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19+
from wan.distributed.util import init_distributed_group
20+
from wan.trainer import LoRATrainer, create_video_dataset
21+
from wan.utils.utils import str2bool
22+
23+
EXAMPLE_PROMPT = {
24+
"t2v-A14B": {
25+
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
26+
},
27+
"i2v-A14B": {
28+
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
29+
"The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
30+
"Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. "
31+
"The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. "
32+
"A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
33+
"image": "examples/i2v_input.JPG",
34+
},
35+
"ti2v-5B": {
36+
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
37+
},
38+
}
39+
40+
41+
def _validate_args(args):
42+
# Basic check
43+
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
44+
assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
45+
assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
46+
47+
if args.prompt is None:
48+
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
49+
if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
50+
args.image = EXAMPLE_PROMPT[args.task]["image"]
51+
52+
if args.task == "i2v-A14B":
53+
assert args.image is not None, "Please specify the image path for i2v."
54+
55+
cfg = WAN_CONFIGS[args.task]
56+
57+
if args.sample_steps is None:
58+
args.sample_steps = cfg.sample_steps
59+
60+
if args.sample_shift is None:
61+
args.sample_shift = cfg.sample_shift
62+
63+
if args.sample_guide_scale is None:
64+
args.sample_guide_scale = cfg.sample_guide_scale
65+
66+
if args.frame_num is None:
67+
args.frame_num = cfg.frame_num
68+
69+
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)
70+
# Size check
71+
assert (
72+
args.size in SUPPORTED_SIZES[args.task]
73+
), f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
74+
75+
76+
def _parse_args():
77+
parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan")
78+
parser.add_argument(
79+
"--task", type=str, default="t2v-A14B", choices=list(WAN_CONFIGS.keys()), help="The task to run."
80+
)
81+
parser.add_argument(
82+
"--size",
83+
type=str,
84+
default="1280*720",
85+
choices=list(SIZE_CONFIGS.keys()),
86+
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.",
87+
)
88+
parser.add_argument(
89+
"--frame_num", type=int, default=None, help="How many frames of video are generated. The number should be 4n+1"
90+
)
91+
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.")
92+
parser.add_argument(
93+
"--offload_model",
94+
type=str2bool,
95+
default=None,
96+
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.",
97+
)
98+
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
99+
parser.add_argument("--t5_zero3", action="store_true", default=False, help="Whether to use ZeRO3 for T5.")
100+
parser.add_argument("--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.")
101+
parser.add_argument("--save_file", type=str, default=None, help="The file to save the generated video to.")
102+
parser.add_argument("--prompt", type=str, default=None, help="The prompt to generate the video from.")
103+
parser.add_argument("--base_seed", type=int, default=-1, help="The seed to use for generating the video.")
104+
parser.add_argument("--image", type=str, default=None, help="The image to generate the video from.")
105+
parser.add_argument(
106+
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++"], help="The solver used to sample."
107+
)
108+
parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.")
109+
parser.add_argument(
110+
"--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers."
111+
)
112+
parser.add_argument("--sample_guide_scale", type=float, default=None, help="Classifier free guidance scale.")
113+
parser.add_argument("--text_dropout_rate", type=float, default=0.1, help="The dropout rate for text encoder.")
114+
parser.add_argument("--validation_interval", type=int, default=100, help="The interval for validation.")
115+
parser.add_argument("--save_interval", type=int, default=100, help="The interval for saving checkpoints.")
116+
parser.add_argument("--output_dir", type=str, default="./output", help="The output directory to save checkpoints.")
117+
parser.add_argument("--learning_rate", type=float, default=1e-4, help="The learning rate for training.")
118+
parser.add_argument("--weight_decay", type=float, default=0.01, help="The weight decay for training.")
119+
parser.add_argument("--data_root", type=str, default=None, help="The root directory of training data.")
120+
parser.add_argument("--dataset_file", type=str, default=None, help="The dataset file for training data.")
121+
parser.add_argument("--caption_column", type=str, default="caption", help="The caption column in the dataset file.")
122+
parser.add_argument("--video_column", type=str, default="video", help="The video column in the dataset file.")
123+
parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
124+
parser.add_argument(
125+
"--text_drop_prob", type=float, default=0.1, help="The probability of dropping text input during training."
126+
)
127+
parser.add_argument("--num_epochs", type=int, default=100, help="The number of epochs for training.")
128+
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="The maximum gradient norm for clipping.")
129+
130+
args = parser.parse_args()
131+
132+
_validate_args(args)
133+
134+
return args
135+
136+
137+
def _init_logging(rank):
138+
# logging
139+
if rank == 0:
140+
# set format
141+
logging.basicConfig(
142+
level=logging.INFO,
143+
format="[%(asctime)s] %(levelname)s: %(message)s",
144+
handlers=[logging.StreamHandler(stream=sys.stdout)],
145+
)
146+
else:
147+
logging.basicConfig(level=logging.ERROR)
148+
149+
150+
def train(args):
151+
rank = int(os.getenv("RANK_ID", 0))
152+
world_size = int(os.getenv("RANK_SIZE", 1))
153+
_init_logging(rank)
154+
155+
if args.offload_model is None:
156+
args.offload_model = False
157+
logging.info(f"offload_model is not specified, set to {args.offload_model}.")
158+
159+
if args.offload_model:
160+
raise ValueError("offload_model is not supported in training currently.")
161+
162+
if world_size > 1:
163+
dist.init_process_group(backend="hccl", init_method="env://", rank=rank, world_size=world_size)
164+
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
165+
else:
166+
assert not (args.t5_zero3), "t5_zero3 are not supported in non-distributed environments."
167+
assert not (args.ulysses_size > 1), "sequence parallel are not supported in non-distributed environments."
168+
169+
if args.ulysses_size > 1:
170+
assert args.ulysses_size == world_size, "The number of ulysses_size should be equal to the world size."
171+
init_distributed_group()
172+
173+
cfg = WAN_CONFIGS[args.task]
174+
if args.ulysses_size > 1:
175+
assert (
176+
cfg.num_heads % args.ulysses_size == 0
177+
), f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
178+
179+
logging.info(f"Training job args: {args}")
180+
logging.info(f"Training model config: {cfg}")
181+
182+
if dist.is_initialized():
183+
base_seed = [args.base_seed] if rank == 0 else [None]
184+
dist.broadcast_object_list(base_seed, src=0)
185+
args.base_seed = base_seed[0]
186+
187+
logging.info(f"Input prompt: {args.prompt}")
188+
img = None
189+
if args.image is not None:
190+
img = Image.open(args.image).convert("RGB")
191+
logging.info(f"Input image: {args.image}")
192+
193+
if "t2v" in args.task:
194+
raise NotImplementedError
195+
elif "ti2v" in args.task:
196+
logging.info("Creating WanTI2V pipeline.")
197+
wan_ti2v = wan.WanTI2V(
198+
config=cfg,
199+
checkpoint_dir=args.ckpt_dir,
200+
rank=rank,
201+
t5_zero3=args.t5_zero3,
202+
use_sp=(args.ulysses_size > 1),
203+
t5_cpu=args.t5_cpu,
204+
convert_model_dtype=True,
205+
)
206+
207+
logging.info("Prepare trainer ...")
208+
size_buckets = tuple(SIZE_CONFIGS[size] for size in SUPPORTED_SIZES[args.task])
209+
train_loader = create_video_dataset(
210+
data_root=args.data_root,
211+
dataset_file=args.dataset_file,
212+
caption_column=args.caption_column,
213+
video_column=args.video_column,
214+
frame_num=args.frame_num,
215+
size_buckets=size_buckets,
216+
batch_size=args.batch_size,
217+
num_shards=world_size,
218+
shard_id=rank,
219+
text_drop_prob=args.text_drop_prob,
220+
)
221+
training_config = dict(
222+
learning_rate=args.learning_rate,
223+
weight_decay=args.weight_decay,
224+
validation_interval=args.validation_interval,
225+
save_interval=args.save_interval,
226+
output_dir=os.path.join(args.output_dir, "ckpt"),
227+
frame_num=args.frame_num,
228+
num_train_timesteps=cfg.num_train_timesteps,
229+
vae_stride=cfg.vae_stride,
230+
patch_size=cfg.patch_size,
231+
max_grad_norm=args.max_grad_norm,
232+
)
233+
generation_config = dict(
234+
input_prompt=args.prompt,
235+
img=img,
236+
size=SIZE_CONFIGS[args.size],
237+
max_area=MAX_AREA_CONFIGS[args.size],
238+
frame_num=args.frame_num,
239+
shift=args.sample_shift,
240+
sample_solver=args.sample_solver,
241+
sampling_steps=args.sample_steps,
242+
guide_scale=args.sample_guide_scale,
243+
seed=args.base_seed,
244+
offload_model=args.offload_model,
245+
output_dir=os.path.join(args.output_dir, "visual"),
246+
sample_fps=cfg.sample_fps,
247+
)
248+
trainer = LoRATrainer(wan_ti2v, train_loader, training_config, generation_config)
249+
250+
logging.info("Start training ...")
251+
trainer.train(args.num_epochs)
252+
else:
253+
raise NotImplementedError
254+
255+
256+
if __name__ == "__main__":
257+
args = _parse_args()
258+
train(args)

examples/wan2_2/wan/modules/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,11 @@ def construct(
498498

499499
x = x.to(self.dtype)
500500
for block in self.blocks:
501-
x = block(x, **kwargs)
501+
if self.training:
502+
# recompute to save memory
503+
x = ms.recompute(block, x, **kwargs)
504+
else:
505+
x = block(x, **kwargs)
502506

503507
# head
504508
x = self.head(x, e)

examples/wan2_2/wan/modules/vae2_1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def construct(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
484484
return x_recon, mu, log_var
485485

486486
def encode(
487-
self, x: ms.Tensor, scale: List[Union[ms.Tensor, float]]
487+
self, x: ms.Tensor, scale: List[Union[ms.Tensor, float]], return_log_var: bool = False
488488
) -> Union[Tuple[ms.Tensor, ms.Tensor], ms.Tensor]:
489489
self.clear_cache()
490490
# cache
@@ -508,6 +508,9 @@ def encode(
508508
else:
509509
mu = (mu - scale[0]) * scale[1]
510510
self.clear_cache()
511+
512+
if return_log_var:
513+
return mu, log_var
511514
return mu
512515

513516
def decode(self, z: ms.Tensor, scale: List[Union[ms.Tensor, float]]) -> ms.Tensor:

examples/wan2_2/wan/modules/vae2_2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
22
import logging
33
import math
4-
from typing import Any, List, Optional, Tuple
4+
from typing import Any, List, Optional, Tuple, Union
55

66
import numpy as np
77

@@ -684,7 +684,9 @@ def construct(self, x: ms.Tensor, scale: List[Any] = [0, 1]) -> Tuple[ms.Tensor,
684684
x_recon = self.decode(mu, scale)
685685
return x_recon, mu
686686

687-
def encode(self, x: ms.Tensor, scale: List[Any]) -> ms.Tensor:
687+
def encode(
688+
self, x: ms.Tensor, scale: List[Any], return_log_var: bool = False
689+
) -> Union[ms.Tensor, Tuple[ms.Tensor, ms.Tensor]]:
688690
self.clear_cache()
689691
x = patchify(x, patch_size=2)
690692
t = x.shape[2]
@@ -706,6 +708,9 @@ def encode(self, x: ms.Tensor, scale: List[Any]) -> ms.Tensor:
706708
else:
707709
mu = (mu - scale[0]) * scale[1]
708710
self.clear_cache()
711+
712+
if return_log_var:
713+
return mu, log_var
709714
return mu
710715

711716
def decode(self, z: ms.Tensor, scale: List[Any]) -> ms.Tensor:

0 commit comments

Comments
 (0)