Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported.
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
Expand Down
143 changes: 127 additions & 16 deletions src/maxdiffusion/checkpointing/wan_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,37 @@
limitations under the License.
"""

from abc import ABC
from abc import ABC, abstractmethod
import json

import jax
import numpy as np
from typing import Optional, Tuple
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2
from .. import max_logging, max_utils
import orbax.checkpoint as ocp
from etils import epath


WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

def __init__(self, config, checkpoint_type):
def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT):
self.config = config
self.checkpoint_type = checkpoint_type
self.opt_state = None

self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
self.checkpoint_manager: ocp.CheckpointManager = (
create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
)
)

def _create_optimizer(self, model, config, learning_rate):
Expand All @@ -51,6 +54,25 @@ def _create_optimizer(self, model, config, learning_rate):
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

# Abstract hook: restore a saved WAN checkpoint for `step`.
# Per the annotation, implementations return a `(restored_checkpoint, step)`
# pair in which either element may be None. The base class only raises, so
# every concrete checkpointer has to override this.
@abstractmethod
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
raise NotImplementedError

# Abstract hook: build and return a pipeline without a restored checkpoint.
# The concrete checkpointers in this file implement it by calling their
# pipeline class's `from_pretrained(self.config)`. The base class only raises.
@abstractmethod
def load_diffusers_checkpoint(self):
raise NotImplementedError

# Abstract hook: load a pipeline for `step` (default None).
# Per the annotation, implementations return `(pipeline, opt_state, step)`,
# where each element may be None. The concrete checkpointers in this file
# fall back to `load_diffusers_checkpoint()` when nothing was restored.
@abstractmethod
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
raise NotImplementedError

# Abstract hook: save a checkpoint for `train_step`.
# `pipeline` supplies the model whose config is stored and `train_states` is a
# dict of states to persist (see the concrete implementations in this file for
# the keys they read). The base class only raises.
@abstractmethod
def save_checkpoint(self, train_step, pipeline, train_states: dict):
raise NotImplementedError


class WanCheckpointer2_1(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
step = self.checkpoint_manager.latest_step()
Expand Down Expand Up @@ -85,24 +107,24 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipeline.from_pretrained(self.config)
pipeline = WanPipeline2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint["wan_state"].keys():
opt_state = restored_checkpoint["wan_state"]["opt_state"]
pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step

def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
Expand All @@ -120,7 +142,96 @@ def config_to_json(model_or_config):
max_logging.log(f"Checkpoint for step {train_step} saved.")


def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
class WanCheckpointer2_2(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restore the WAN 2.2 transformer states and config from a saved checkpoint.

    Args:
      step: Checkpoint step to restore. When None, the checkpoint manager's
        latest step is used.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when no step was
      given and the checkpoint manager reports no latest step.
    """
    manager = self.checkpoint_manager

    if step is None:
        step = manager.latest_step()
        max_logging.log(f"Latest WAN checkpoint step: {step}")
    if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None

    max_logging.log(f"Loading WAN checkpoint from step {step}")
    metadatas = manager.item_metadata(step)

    def numpy_restore_for(state_metadata):
        # Mirror the saved tree's structure and request every leaf as a numpy array.
        abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state_metadata)
        return ocp.args.PyTreeRestore(
            restore_args=jax.tree.map(
                lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
                abstract_tree,
            )
        )

    # Low-noise transformer first, then high-noise, each from its own metadata.
    low_params_restore = numpy_restore_for(metadatas.low_noise_transformer_state)
    high_params_restore = numpy_restore_for(metadatas.high_noise_transformer_state)

    max_logging.log("Restoring WAN 2.2 checkpoint")
    restored = manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            low_noise_transformer_state=low_params_restore,
            high_noise_transformer_state=high_params_restore,
            wan_config=ocp.args.JsonRestore(),
        ),
    )

    # Diagnostic logging of what came back, including whether optimizer state is present.
    max_logging.log(f"restored checkpoint {restored.keys()}")
    max_logging.log(f"restored checkpoint low_noise_transformer_state {restored.low_noise_transformer_state.keys()}")
    max_logging.log(f"restored checkpoint high_noise_transformer_state {restored.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored.low_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored.high_noise_transformer_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored, step

def load_diffusers_checkpoint(self):
    """Return the pipeline built by ``WanPipeline2_2.from_pretrained`` from ``self.config``."""
    return WanPipeline2_2.from_pretrained(self.config)

def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
    """Load a WAN 2.2 pipeline, preferring a saved checkpoint over the default one.

    Args:
      step: Checkpoint step to restore; passed to ``load_wan_configs_from_orbax``.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` is None unless one of the
      restored transformer states contains an ``"opt_state"`` entry.
    """
    restored, step = self.load_wan_configs_from_orbax(step)

    if not restored:
        max_logging.log("No checkpoint found, loading default pipeline.")
        return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = WanPipeline2_2.from_checkpoint(self.config, restored)

    # The low-noise state takes precedence; the high-noise state is only
    # consulted when the low-noise one carries no optimizer state.
    low_state = restored.low_noise_transformer_state
    if "opt_state" in low_state.keys():
        return pipeline, low_state["opt_state"], step

    high_state = restored.high_noise_transformer_state
    if "opt_state" in high_state.keys():
        return pipeline, high_state["opt_state"], step

    return pipeline, None, step

def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict):
    """Save both transformer train states and the model config for ``train_step``.

    Args:
      train_step: Step number passed to the checkpoint manager.
      pipeline: Pipeline whose ``low_noise_transformer`` config is stored as
        the ``wan_config`` JSON item.
      train_states: Dict read at the keys ``"low_noise_transformer"`` and
        ``"high_noise_transformer"``.
    """
    max_logging.log(f"Saving checkpoint for step {train_step}")

    # Round-trip through the JSON string so the config is stored as plain JSON data.
    transformer_config = json.loads(pipeline.low_noise_transformer.to_json_string())

    checkpoint_args = ocp.args.Composite(
        wan_config=ocp.args.JsonSave(transformer_config),
        low_noise_transformer_state=ocp.args.PyTreeSave(train_states["low_noise_transformer"]),
        high_noise_transformer_state=ocp.args.PyTreeSave(train_states["high_noise_transformer"]),
    )

    self.checkpoint_manager.save(train_step, args=checkpoint_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")

def save_checkpoint_orig(self, train_step, pipeline, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
Expand Down
Loading