|
| 1 | +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. |
| 2 | +import argparse |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import random |
| 6 | +import sys |
| 7 | + |
| 8 | +from PIL import Image |
| 9 | + |
| 10 | +import mindspore as ms |
| 11 | +import mindspore.mint.distributed as dist |
| 12 | + |
| 13 | +__dir__ = os.path.dirname(os.path.abspath(__file__)) |
| 14 | +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) |
| 15 | +sys.path.insert(0, mindone_lib_path) |
| 16 | + |
| 17 | +import wan |
| 18 | +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS |
| 19 | +from wan.distributed.util import init_distributed_group |
| 20 | +from wan.trainer import LoRATrainer, create_video_dataset |
| 21 | +from wan.utils.utils import str2bool |
| 22 | + |
| 23 | +EXAMPLE_PROMPT = { |
| 24 | + "t2v-A14B": { |
| 25 | + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", |
| 26 | + }, |
| 27 | + "i2v-A14B": { |
| 28 | + "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " |
| 29 | + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " |
| 30 | + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. " |
| 31 | + "The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. " |
| 32 | + "A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", |
| 33 | + "image": "examples/i2v_input.JPG", |
| 34 | + }, |
| 35 | + "ti2v-5B": { |
| 36 | + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", |
| 37 | + }, |
| 38 | +} |
| 39 | + |
| 40 | + |
| 41 | +def _validate_args(args): |
| 42 | + # Basic check |
| 43 | + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." |
| 44 | + assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}" |
| 45 | + assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}" |
| 46 | + |
| 47 | + if args.prompt is None: |
| 48 | + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] |
| 49 | + if args.image is None and "image" in EXAMPLE_PROMPT[args.task]: |
| 50 | + args.image = EXAMPLE_PROMPT[args.task]["image"] |
| 51 | + |
| 52 | + if args.task == "i2v-A14B": |
| 53 | + assert args.image is not None, "Please specify the image path for i2v." |
| 54 | + |
| 55 | + cfg = WAN_CONFIGS[args.task] |
| 56 | + |
| 57 | + if args.sample_steps is None: |
| 58 | + args.sample_steps = cfg.sample_steps |
| 59 | + |
| 60 | + if args.sample_shift is None: |
| 61 | + args.sample_shift = cfg.sample_shift |
| 62 | + |
| 63 | + if args.sample_guide_scale is None: |
| 64 | + args.sample_guide_scale = cfg.sample_guide_scale |
| 65 | + |
| 66 | + if args.frame_num is None: |
| 67 | + args.frame_num = cfg.frame_num |
| 68 | + |
| 69 | + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) |
| 70 | + # Size check |
| 71 | + assert ( |
| 72 | + args.size in SUPPORTED_SIZES[args.task] |
| 73 | + ), f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" |
| 74 | + |
| 75 | + |
| 76 | +def _parse_args(): |
| 77 | + parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") |
| 78 | + parser.add_argument( |
| 79 | + "--task", type=str, default="t2v-A14B", choices=list(WAN_CONFIGS.keys()), help="The task to run." |
| 80 | + ) |
| 81 | + parser.add_argument( |
| 82 | + "--size", |
| 83 | + type=str, |
| 84 | + default="1280*720", |
| 85 | + choices=list(SIZE_CONFIGS.keys()), |
| 86 | + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.", |
| 87 | + ) |
| 88 | + parser.add_argument( |
| 89 | + "--frame_num", type=int, default=None, help="How many frames of video are generated. The number should be 4n+1" |
| 90 | + ) |
| 91 | + parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.") |
| 92 | + parser.add_argument( |
| 93 | + "--offload_model", |
| 94 | + type=str2bool, |
| 95 | + default=None, |
| 96 | + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.", |
| 97 | + ) |
| 98 | + parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.") |
| 99 | + parser.add_argument("--t5_zero3", action="store_true", default=False, help="Whether to use ZeRO3 for T5.") |
| 100 | + parser.add_argument("--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.") |
| 101 | + parser.add_argument("--save_file", type=str, default=None, help="The file to save the generated video to.") |
| 102 | + parser.add_argument("--prompt", type=str, default=None, help="The prompt to generate the video from.") |
| 103 | + parser.add_argument("--base_seed", type=int, default=-1, help="The seed to use for generating the video.") |
| 104 | + parser.add_argument("--image", type=str, default=None, help="The image to generate the video from.") |
| 105 | + parser.add_argument( |
| 106 | + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++"], help="The solver used to sample." |
| 107 | + ) |
| 108 | + parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.") |
| 109 | + parser.add_argument( |
| 110 | + "--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers." |
| 111 | + ) |
| 112 | + parser.add_argument("--sample_guide_scale", type=float, default=None, help="Classifier free guidance scale.") |
| 113 | + parser.add_argument("--text_dropout_rate", type=float, default=0.1, help="The dropout rate for text encoder.") |
| 114 | + parser.add_argument("--validation_interval", type=int, default=100, help="The interval for validation.") |
| 115 | + parser.add_argument("--save_interval", type=int, default=100, help="The interval for saving checkpoints.") |
| 116 | + parser.add_argument("--output_dir", type=str, default="./output", help="The output directory to save checkpoints.") |
| 117 | + parser.add_argument("--learning_rate", type=float, default=1e-4, help="The learning rate for training.") |
| 118 | + parser.add_argument("--weight_decay", type=float, default=0.01, help="The weight decay for training.") |
| 119 | + parser.add_argument("--data_root", type=str, default=None, help="The root directory of training data.") |
| 120 | + parser.add_argument("--dataset_file", type=str, default=None, help="The dataset file for training data.") |
| 121 | + parser.add_argument("--caption_column", type=str, default="caption", help="The caption column in the dataset file.") |
| 122 | + parser.add_argument("--video_column", type=str, default="video", help="The video column in the dataset file.") |
| 123 | + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.") |
| 124 | + parser.add_argument( |
| 125 | + "--text_drop_prob", type=float, default=0.1, help="The probability of dropping text input during training." |
| 126 | + ) |
| 127 | + parser.add_argument("--num_epochs", type=int, default=100, help="The number of epochs for training.") |
| 128 | + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="The maximum gradient norm for clipping.") |
| 129 | + |
| 130 | + args = parser.parse_args() |
| 131 | + |
| 132 | + _validate_args(args) |
| 133 | + |
| 134 | + return args |
| 135 | + |
| 136 | + |
| 137 | +def _init_logging(rank): |
| 138 | + # logging |
| 139 | + if rank == 0: |
| 140 | + # set format |
| 141 | + logging.basicConfig( |
| 142 | + level=logging.INFO, |
| 143 | + format="[%(asctime)s] %(levelname)s: %(message)s", |
| 144 | + handlers=[logging.StreamHandler(stream=sys.stdout)], |
| 145 | + ) |
| 146 | + else: |
| 147 | + logging.basicConfig(level=logging.ERROR) |
| 148 | + |
| 149 | + |
| 150 | +def train(args): |
| 151 | + rank = int(os.getenv("RANK_ID", 0)) |
| 152 | + world_size = int(os.getenv("RANK_SIZE", 1)) |
| 153 | + _init_logging(rank) |
| 154 | + |
| 155 | + if args.offload_model is None: |
| 156 | + args.offload_model = False |
| 157 | + logging.info(f"offload_model is not specified, set to {args.offload_model}.") |
| 158 | + |
| 159 | + if args.offload_model: |
| 160 | + raise ValueError("offload_model is not supported in training currently.") |
| 161 | + |
| 162 | + if world_size > 1: |
| 163 | + dist.init_process_group(backend="hccl", init_method="env://", rank=rank, world_size=world_size) |
| 164 | + ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL) |
| 165 | + else: |
| 166 | + assert not (args.t5_zero3), "t5_zero3 are not supported in non-distributed environments." |
| 167 | + assert not (args.ulysses_size > 1), "sequence parallel are not supported in non-distributed environments." |
| 168 | + |
| 169 | + if args.ulysses_size > 1: |
| 170 | + assert args.ulysses_size == world_size, "The number of ulysses_size should be equal to the world size." |
| 171 | + init_distributed_group() |
| 172 | + |
| 173 | + cfg = WAN_CONFIGS[args.task] |
| 174 | + if args.ulysses_size > 1: |
| 175 | + assert ( |
| 176 | + cfg.num_heads % args.ulysses_size == 0 |
| 177 | + ), f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." |
| 178 | + |
| 179 | + logging.info(f"Training job args: {args}") |
| 180 | + logging.info(f"Training model config: {cfg}") |
| 181 | + |
| 182 | + if dist.is_initialized(): |
| 183 | + base_seed = [args.base_seed] if rank == 0 else [None] |
| 184 | + dist.broadcast_object_list(base_seed, src=0) |
| 185 | + args.base_seed = base_seed[0] |
| 186 | + |
| 187 | + logging.info(f"Input prompt: {args.prompt}") |
| 188 | + img = None |
| 189 | + if args.image is not None: |
| 190 | + img = Image.open(args.image).convert("RGB") |
| 191 | + logging.info(f"Input image: {args.image}") |
| 192 | + |
| 193 | + if "t2v" in args.task: |
| 194 | + raise NotImplementedError |
| 195 | + elif "ti2v" in args.task: |
| 196 | + logging.info("Creating WanTI2V pipeline.") |
| 197 | + wan_ti2v = wan.WanTI2V( |
| 198 | + config=cfg, |
| 199 | + checkpoint_dir=args.ckpt_dir, |
| 200 | + rank=rank, |
| 201 | + t5_zero3=args.t5_zero3, |
| 202 | + use_sp=(args.ulysses_size > 1), |
| 203 | + t5_cpu=args.t5_cpu, |
| 204 | + convert_model_dtype=True, |
| 205 | + ) |
| 206 | + |
| 207 | + logging.info("Prepare trainer ...") |
| 208 | + size_buckets = tuple(SIZE_CONFIGS[size] for size in SUPPORTED_SIZES[args.task]) |
| 209 | + train_loader = create_video_dataset( |
| 210 | + data_root=args.data_root, |
| 211 | + dataset_file=args.dataset_file, |
| 212 | + caption_column=args.caption_column, |
| 213 | + video_column=args.video_column, |
| 214 | + frame_num=args.frame_num, |
| 215 | + size_buckets=size_buckets, |
| 216 | + batch_size=args.batch_size, |
| 217 | + num_shards=world_size, |
| 218 | + shard_id=rank, |
| 219 | + text_drop_prob=args.text_drop_prob, |
| 220 | + ) |
| 221 | + training_config = dict( |
| 222 | + learning_rate=args.learning_rate, |
| 223 | + weight_decay=args.weight_decay, |
| 224 | + validation_interval=args.validation_interval, |
| 225 | + save_interval=args.save_interval, |
| 226 | + output_dir=os.path.join(args.output_dir, "ckpt"), |
| 227 | + frame_num=args.frame_num, |
| 228 | + num_train_timesteps=cfg.num_train_timesteps, |
| 229 | + vae_stride=cfg.vae_stride, |
| 230 | + patch_size=cfg.patch_size, |
| 231 | + max_grad_norm=args.max_grad_norm, |
| 232 | + ) |
| 233 | + generation_config = dict( |
| 234 | + input_prompt=args.prompt, |
| 235 | + img=img, |
| 236 | + size=SIZE_CONFIGS[args.size], |
| 237 | + max_area=MAX_AREA_CONFIGS[args.size], |
| 238 | + frame_num=args.frame_num, |
| 239 | + shift=args.sample_shift, |
| 240 | + sample_solver=args.sample_solver, |
| 241 | + sampling_steps=args.sample_steps, |
| 242 | + guide_scale=args.sample_guide_scale, |
| 243 | + seed=args.base_seed, |
| 244 | + offload_model=args.offload_model, |
| 245 | + output_dir=os.path.join(args.output_dir, "visual"), |
| 246 | + sample_fps=cfg.sample_fps, |
| 247 | + ) |
| 248 | + trainer = LoRATrainer(wan_ti2v, train_loader, training_config, generation_config) |
| 249 | + |
| 250 | + logging.info("Start training ...") |
| 251 | + trainer.train(args.num_epochs) |
| 252 | + else: |
| 253 | + raise NotImplementedError |
| 254 | + |
| 255 | + |
| 256 | +if __name__ == "__main__": |
| 257 | + args = _parse_args() |
| 258 | + train(args) |
0 commit comments