7 changes: 0 additions & 7 deletions examples/mmada/README.md
@@ -190,13 +190,6 @@ The following experiments are tested on Ascend Atlas 800T A2 machines with minds
|:-:|:-:|:-:|:-:|:-:|:-:|
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.29 |

The following experiments are tested on Ascend Atlas 800T A2 machines with mindspore **2.6.0** under **pynative** mode:

| model | # card(s) | batch size | parallelism | task | per batch time (seconds) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| MMaDA-8B-Base | 8 | 4 | zero2 | finetune | 1.30 |



## 🤝 Acknowledgments

40 changes: 40 additions & 0 deletions examples/wan2_2/finetune.md
@@ -0,0 +1,40 @@
# Wan2.2 LoRA Finetune

We provide an example of how to finetune the Wan2.2 (5B) model on the Text-to-Video task using LoRA (Low-Rank Adaptation).
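As a refresher, LoRA freezes the pretrained weights and learns a low-rank additive update, so only a small fraction of the parameters is trained. Below is a minimal, self-contained MindSpore sketch of the idea; the class and parameter names are illustrative, not the actual Wan2.2 implementation.

```python
import mindspore as ms
from mindspore import nn

class LoRADense(nn.Cell):
    """Frozen base projection plus a trainable low-rank update: y = Wx + (alpha/r) * B(Ax)."""

    def __init__(self, base: nn.Dense, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.trainable_params():
            p.requires_grad = False  # freeze the pretrained weights
        # A is randomly initialized; B starts at zero so training begins at the base model.
        self.lora_a = nn.Dense(base.in_channels, rank, has_bias=False)
        self.lora_b = nn.Dense(rank, base.out_channels, weight_init="zeros", has_bias=False)
        self.scaling = alpha / rank

    def construct(self, x):
        return self.base(x) + self.lora_b(self.lora_a(x)) * self.scaling
```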

## Prerequisites

Before running the finetuning script, ensure you have the following prerequisites:

#### Requirements
| mindspore | ascend driver | firmware | cann toolkit/kernel|
| :-------: | :-----------: | :---------: | :----------------: |
| 2.7.0 | 25.2.0 | 7.7.0.6.236 | 8.2.RC1 |

#### Dataset Preparation

Prepare your dataset following the format of [Wild-Heart/Disney-VideoGeneration-Dataset](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset).
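A minimal layout, assuming the `prompt.txt`/`videos.txt` convention used by that dataset (and by `scripts/train_lora_2p.sh` below), looks like this:

```text
data/Disney-VideoGeneration-Dataset/
├── prompt.txt    # one caption per line, aligned line-by-line with videos.txt
├── videos.txt    # one relative video path per line, e.g. videos/00000.mp4
└── videos/
    ├── 00000.mp4
    └── ...
```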

## Start Finetuning

We use the script `scripts/train_lora_2p.sh` to start the finetuning process. You can modify the parameters in the script as needed.

```bash
bash scripts/train_lora_2p.sh
```

The LoRA checkpoint and the visualization results will be saved in the `output` directory by default.
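After training, the LoRA weights can be loaded at inference time through the `--lora_dir` flag that this PR adds to `generate.py` (see the diff below). A sketch, assuming the default output layout above; the prompt is illustrative:

```bash
python generate.py --task ti2v-5B --size 1280*704 \
    --ckpt_dir ./model/Wan2.2-TI2V-5B \
    --lora_dir ./output/ckpt \
    --prompt "Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage."
```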

## Finetune Results

After finetuning on 2 Ascend devices, we obtained the following results:

Training loss curve:



## Performance

| model | precision | task | resolution | # card(s) | batch size | recompute | s/step |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| Wan2.2-TI2V-5B | bf16 | Text-to-Video | 704x1280 | 2 | 2 | ON | 27 |
2 changes: 2 additions & 0 deletions examples/wan2_2/generate.py
Expand Up @@ -102,6 +102,7 @@ def _parse_args():
"--frame_num", type=int, default=None, help="How many frames of video are generated. The number should be 4n+1"
)
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.")
parser.add_argument("--lora_dir", type=str, default=None, help="The path to the LoRA checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
@@ -302,6 +303,7 @@ def generate(args):
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
lora_dir=args.lora_dir,
)

logging.info("Generating video ...")
5 changes: 5 additions & 0 deletions examples/wan2_2/scripts/train_lora_2p.sh
@@ -0,0 +1,5 @@
msrun --worker_num=2 --local_worker_num=2 train.py --task ti2v-5B --size 1280*704 --t5_zero3 \
--ckpt_dir ./model/Wan2.2-TI2V-5B \
--data_root ./data/Disney-VideoGeneration-Dataset \
--caption_column prompt.txt \
--video_column videos.txt
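`msrun` here launches two workers on a single node. Scaling to more devices on one node should just be a matter of raising both worker counts, since `train.py` shards the dataset by `world_size` — a sketch with all other arguments unchanged:

```bash
msrun --worker_num=8 --local_worker_num=8 train.py --task ti2v-5B --size 1280*704 --t5_zero3 \
    --ckpt_dir ./model/Wan2.2-TI2V-5B \
    --data_root ./data/Disney-VideoGeneration-Dataset \
    --caption_column prompt.txt \
    --video_column videos.txt
```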
258 changes: 258 additions & 0 deletions examples/wan2_2/train.py
@@ -0,0 +1,258 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import random
import sys

from PIL import Image

import mindspore as ms
import mindspore.mint.distributed as dist

__dir__ = os.path.dirname(os.path.abspath(__file__))
mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../"))
sys.path.insert(0, mindone_lib_path)

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.trainer import LoRATrainer, create_video_dataset
from wan.utils.utils import str2bool

EXAMPLE_PROMPT = {
"t2v-A14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"i2v-A14B": {
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
"The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
"Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. "
"The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. "
"A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image": "examples/i2v_input.JPG",
},
"ti2v-5B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
}


def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
args.image = EXAMPLE_PROMPT[args.task]["image"]

if args.task == "i2v-A14B":
assert args.image is not None, "Please specify the image path for i2v."

cfg = WAN_CONFIGS[args.task]

if args.sample_steps is None:
args.sample_steps = cfg.sample_steps

if args.sample_shift is None:
args.sample_shift = cfg.sample_shift

if args.sample_guide_scale is None:
args.sample_guide_scale = cfg.sample_guide_scale

if args.frame_num is None:
args.frame_num = cfg.frame_num

args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)
# Size check
assert (
args.size in SUPPORTED_SIZES[args.task]
), f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"


def _parse_args():
parser = argparse.ArgumentParser(description="Generate an image or video from a text prompt or image using Wan")
parser.add_argument(
"--task", type=str, default="t2v-A14B", choices=list(WAN_CONFIGS.keys()), help="The task to run."
)
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.",
)
parser.add_argument(
"--frame_num", type=int, default=None, help="How many frames of video are generated. The number should be 4n+1"
)
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.",
)
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
parser.add_argument("--t5_zero3", action="store_true", default=False, help="Whether to use ZeRO3 for T5.")
parser.add_argument("--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.")
parser.add_argument("--save_file", type=str, default=None, help="The file to save the generated video to.")
parser.add_argument("--prompt", type=str, default=None, help="The prompt to generate the video from.")
parser.add_argument("--base_seed", type=int, default=-1, help="The seed to use for generating the video.")
parser.add_argument("--image", type=str, default=None, help="The image to generate the video from.")
parser.add_argument(
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++"], help="The solver used to sample."
)
parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers."
)
parser.add_argument("--sample_guide_scale", type=float, default=None, help="Classifier free guidance scale.")
parser.add_argument("--text_dropout_rate", type=float, default=0.1, help="The dropout rate for text encoder.")
parser.add_argument("--validation_interval", type=int, default=100, help="The interval for validation.")
parser.add_argument("--save_interval", type=int, default=100, help="The interval for saving checkpoints.")
parser.add_argument("--output_dir", type=str, default="./output", help="The output directory to save checkpoints.")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="The learning rate for training.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="The weight decay for training.")
parser.add_argument("--data_root", type=str, default=None, help="The root directory of training data.")
parser.add_argument("--dataset_file", type=str, default=None, help="The dataset file for training data.")
parser.add_argument("--caption_column", type=str, default="caption", help="The caption column in the dataset file.")
parser.add_argument("--video_column", type=str, default="video", help="The video column in the dataset file.")
parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
parser.add_argument(
"--text_drop_prob", type=float, default=0.1, help="The probability of dropping text input during training."
)
parser.add_argument("--num_epochs", type=int, default=100, help="The number of epochs for training.")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="The maximum gradient norm for clipping.")

args = parser.parse_args()

_validate_args(args)

return args


def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)],
)
else:
logging.basicConfig(level=logging.ERROR)


def train(args):
rank = int(os.getenv("RANK_ID", 0))
world_size = int(os.getenv("RANK_SIZE", 1))
_init_logging(rank)

if args.offload_model is None:
args.offload_model = False
logging.info(f"offload_model is not specified, set to {args.offload_model}.")

if args.offload_model:
raise ValueError("offload_model is not supported in training currently.")

if world_size > 1:
dist.init_process_group(backend="hccl", init_method="env://", rank=rank, world_size=world_size)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
else:
assert not (args.t5_zero3), "t5_zero3 is not supported in non-distributed environments."
assert not (args.ulysses_size > 1), "sequence parallelism is not supported in non-distributed environments."

if args.ulysses_size > 1:
assert args.ulysses_size == world_size, "ulysses_size must equal the world size."
init_distributed_group()

cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert (
cfg.num_heads % args.ulysses_size == 0
), f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

logging.info(f"Training job args: {args}")
logging.info(f"Training model config: {cfg}")

if dist.is_initialized():
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]

logging.info(f"Input prompt: {args.prompt}")
img = None
if args.image is not None:
img = Image.open(args.image).convert("RGB")
logging.info(f"Input image: {args.image}")

if "t2v" in args.task:
raise NotImplementedError
elif "ti2v" in args.task:
logging.info("Creating WanTI2V pipeline.")
wan_ti2v = wan.WanTI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
rank=rank,
t5_zero3=args.t5_zero3,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=True,
)

logging.info("Prepare trainer ...")
size_buckets = tuple(SIZE_CONFIGS[size] for size in SUPPORTED_SIZES[args.task])
train_loader = create_video_dataset(
data_root=args.data_root,
dataset_file=args.dataset_file,
caption_column=args.caption_column,
video_column=args.video_column,
frame_num=args.frame_num,
size_buckets=size_buckets,
batch_size=args.batch_size,
num_shards=world_size,
shard_id=rank,
text_drop_prob=args.text_drop_prob,
)
training_config = dict(
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
validation_interval=args.validation_interval,
save_interval=args.save_interval,
output_dir=os.path.join(args.output_dir, "ckpt"),
frame_num=args.frame_num,
num_train_timesteps=cfg.num_train_timesteps,
vae_stride=cfg.vae_stride,
patch_size=cfg.patch_size,
max_grad_norm=args.max_grad_norm,
)
generation_config = dict(
input_prompt=args.prompt,
img=img,
size=SIZE_CONFIGS[args.size],
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model,
output_dir=os.path.join(args.output_dir, "visual"),
sample_fps=cfg.sample_fps,
)
trainer = LoRATrainer(wan_ti2v, train_loader, training_config, generation_config)

logging.info("Start training ...")
trainer.train(args.num_epochs)
else:
raise NotImplementedError


if __name__ == "__main__":
args = _parse_args()
train(args)
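For a quick single-device dry run, `train.py` can presumably also be launched directly without `msrun`, as long as the distributed-only options `--t5_zero3` and `--ulysses_size > 1` are left off (the assertions above enforce this). Paths are illustrative:

```bash
python train.py --task ti2v-5B --size 1280*704 \
    --ckpt_dir ./model/Wan2.2-TI2V-5B \
    --data_root ./data/Disney-VideoGeneration-Dataset \
    --caption_column prompt.txt --video_column videos.txt
```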
6 changes: 5 additions & 1 deletion examples/wan2_2/wan/modules/model.py
Expand Up @@ -498,7 +498,11 @@ def construct(

x = x.to(self.dtype)
for block in self.blocks:
x = block(x, **kwargs)
if self.training:
# recompute to save memory
x = ms.recompute(block, x, **kwargs)
else:
x = block(x, **kwargs)

# head
x = self.head(x, e)
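The functional `mindspore.recompute` call above trades compute for memory: the block's intermediate activations are dropped during the forward pass and recomputed during backprop. A standalone sketch of the pattern, assuming the functional recompute API; the toy network is illustrative:

```python
import mindspore as ms
from mindspore import nn, ops

block = nn.SequentialCell(nn.Dense(8, 32), nn.GELU(), nn.Dense(32, 8))

def loss_fn(x):
    y = ms.recompute(block, x)   # activations inside `block` are not kept
    return ops.square(y).mean()

x = ops.randn(4, 8)
grads = ms.grad(loss_fn)(x)      # backward re-runs the block's forward
```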
5 changes: 4 additions & 1 deletion examples/wan2_2/wan/modules/vae2_1.py
Expand Up @@ -484,7 +484,7 @@ def construct(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
return x_recon, mu, log_var

def encode(
self, x: ms.Tensor, scale: List[Union[ms.Tensor, float]]
self, x: ms.Tensor, scale: List[Union[ms.Tensor, float]], return_log_var: bool = False
) -> Union[Tuple[ms.Tensor, ms.Tensor], ms.Tensor]:
self.clear_cache()
# cache
@@ -508,6 +508,9 @@ def encode(
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()

if return_log_var:
return mu, log_var
return mu

def decode(self, z: ms.Tensor, scale: List[Union[ms.Tensor, float]]) -> ms.Tensor:
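Returning `log_var` alongside `mu` lets a training loop sample latents with the reparameterization trick rather than using the posterior mean only. A hypothetical sketch (the diff does not show how `LoRATrainer` actually consumes the flag):

```python
import mindspore as ms
from mindspore import ops

def sample_latent(vae, video: ms.Tensor, scale) -> ms.Tensor:
    # z = mu + sigma * eps, with eps ~ N(0, I)
    mu, log_var = vae.encode(video, scale, return_log_var=True)
    eps = ops.randn_like(mu)
    return mu + ops.exp(0.5 * log_var) * eps
```

The `vae2_2.py` change below adds the same optional flag to the Wan2.2 VAE's `encode`.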
9 changes: 7 additions & 2 deletions examples/wan2_2/wan/modules/vae2_2.py
@@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import numpy as np

@@ -684,7 +684,9 @@ def construct(self, x: ms.Tensor, scale: List[Any] = [0, 1]) -> Tuple[ms.Tensor,
x_recon = self.decode(mu, scale)
return x_recon, mu

def encode(self, x: ms.Tensor, scale: List[Any]) -> ms.Tensor:
def encode(
self, x: ms.Tensor, scale: List[Any], return_log_var: bool = False
) -> Union[ms.Tensor, Tuple[ms.Tensor, ms.Tensor]]:
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
@@ -706,6 +708,9 @@ def encode(self, x: ms.Tensor, scale: List[Any]) -> ms.Tensor:
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()

if return_log_var:
return mu, log_var
return mu

def decode(self, z: ms.Tensor, scale: List[Any]) -> ms.Tensor: