
Commit a690e96

add kandinsky5 (#1388)
1 parent 4107a9d commit a690e96

13 files changed: +1955 −2 lines changed

13 files changed

+1955
-2
lines changed

docs/diffusers/api/loaders/lora.md

Lines changed: 2 additions & 0 deletions
@@ -63,4 +63,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 
 ::: mindone.diffusers.loaders.lora_pipeline.QwenImageLoraLoaderMixin
 
+::: mindone.diffusers.loaders.lora_pipeline.KandinskyLoraLoaderMixin
+
 ::: mindone.diffusers.loaders.lora_base.LoraBaseMixin

mindone/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -95,6 +95,7 @@
         "HunyuanVideoTransformer3DModel",
         "I2VGenXLUNet",
         "Kandinsky3UNet",
+        "Kandinsky5Transformer3DModel",
         "LatteTransformer3DModel",
         "LTXVideoTransformer3DModel",
         "Lumina2Transformer2DModel",
@@ -236,6 +237,7 @@
         "KandinskyV22PriorPipeline",
         "Kandinsky3Img2ImgPipeline",
         "Kandinsky3Pipeline",
+        "Kandinsky5T2VPipeline",
         "KolorsPAGPipeline",
         "KolorsPipeline",
         "KolorsImg2ImgPipeline",
@@ -478,6 +480,7 @@
         HunyuanVideoTransformer3DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
+        Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
         Lumina2Transformer2DModel,
@@ -613,6 +616,7 @@
         IFSuperResolutionPipeline,
         Kandinsky3Img2ImgPipeline,
         Kandinsky3Pipeline,
+        Kandinsky5T2VPipeline,
         KandinskyCombinedPipeline,
         KandinskyImg2ImgCombinedPipeline,
         KandinskyImg2ImgPipeline,
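
These `__init__.py` hunks only register the new names in the package's import tables, so their practical effect is that both classes resolve from the top-level package. A minimal sketch, assuming an installed build of `mindone` that includes this commit:

```py
# Minimal sketch: confirm the exports added in this commit resolve from the
# top-level package (assumes mindone is installed at or after commit a690e96).
from mindone.diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel

print(Kandinsky5T2VPipeline.__name__)         # "Kandinsky5T2VPipeline"
print(Kandinsky5Transformer3DModel.__name__)  # "Kandinsky5Transformer3DModel"
```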

mindone/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
         "SanaLoraLoaderMixin",
         "Lumina2LoraLoaderMixin",
         "WanLoraLoaderMixin",
+        "KandinskyLoraLoaderMixin",
         "HiDreamImageLoraLoaderMixin",
         "SkyReelsV2LoraLoaderMixin",
     ],
@@ -97,6 +98,7 @@ def text_encoder_attn_modules(text_encoder):
         FluxLoraLoaderMixin,
         HiDreamImageLoraLoaderMixin,
         HunyuanVideoLoraLoaderMixin,
+        KandinskyLoraLoaderMixin,
         LoraLoaderMixin,
         LTXVideoLoraLoaderMixin,
         Lumina2LoraLoaderMixin,

mindone/diffusers/loaders/lora_pipeline.py

Lines changed: 262 additions & 0 deletions
@@ -4667,6 +4667,268 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components, **kwargs)
 
 
+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Kandinsky5Transformer3DModel`],
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* of a pretrained model hosted on the Hub.
+                    - A path to a *directory* containing the model weights.
+                    - A [mindspore state dict].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
+            use_safetensors (`bool`, *optional*):
+                Whether to use safetensors for loading.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata.
+        """
+        # Load the main state dict first which has the LoRA layers
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." # noqa
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+        """
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        # Load LoRA into transformer
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        Load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters.
+            transformer (`Kandinsky5Transformer3DModel`):
+                The transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata.
+        """
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[ms.nn.Cell, ms.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer and text encoders.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to.
+            transformer_lora_layers (`Dict[str, ms.nn.Cell]` or `Dict[str, ms.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process.
+            save_function (`Callable`):
+                The function to use to save the state dictionary.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer.
+        """
+        lora_layers = {}
+        lora_metadata = {}
+
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers`")
+
+        cls._save_lora_weights(
+            save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing.
+
+        Example:
+
+        ```py
+        from mindone.diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of [`pipe.fuse_lora()`].
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
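
The new mixin follows the same pattern as the other per-pipeline LoRA loaders: `lora_state_dict` fetches and filters the checkpoint, `load_lora_weights` loads it into the pipeline's transformer, and `fuse_lora`/`unfuse_lora` merge the adapter into the base weights and restore them. A minimal usage sketch built only from the methods added above; the repo id is the one used in the `fuse_lora` docstring example and the LoRA file path is hypothetical:

```py
# Sketch of the KandinskyLoraLoaderMixin surface added in this commit.
# "ai-forever/Kandinsky-5.0-T2V" comes from the fuse_lora docstring example;
# "path/to/kandinsky_lora.safetensors" is a hypothetical local checkpoint.
from mindone.diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")

# Load a LoRA adapter into the transformer, naming it for later reference.
pipe.load_lora_weights("path/to/kandinsky_lora.safetensors", adapter_name="style")

# Optionally merge the adapter into the base weights, then undo the merge.
pipe.fuse_lora(lora_scale=0.7)
pipe.unfuse_lora()
```

Saving goes the other way through `save_lora_weights`, which requires `transformer_lora_layers` and writes safetensors by default.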

mindone/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,7 @@
     "transformers.transformer_hidream_image": ["HiDreamImageTransformer2DModel"],
     "transformers.transformer_hunyuan_video": ["HunyuanVideoTransformer3DModel"],
     "transformers.transformer_hunyuan_video_framepack": ["HunyuanVideoFramepackTransformer3DModel"],
+    "transformers.transformer_kandinsky": ["Kandinsky5Transformer3DModel"],
     "transformers.transformer_ltx": ["LTXVideoTransformer3DModel"],
     "transformers.transformer_lumina2": ["Lumina2Transformer2DModel"],
     "transformers.transformer_mochi": ["MochiTransformer3DModel"],
@@ -161,6 +162,7 @@
         HunyuanDiT2DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
+        Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
         Lumina2Transformer2DModel,

0 commit comments

Comments
 (0)