@@ -4667,6 +4667,268 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components, **kwargs)


+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Kandinsky5Transformer3DModel`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* of a pretrained model hosted on the Hub.
+                    - A path to a *directory* containing the model weights.
+                    - A [mindspore state dict].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
+            use_safetensors (`bool`, *optional*):
+                Whether to use safetensors for loading.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata.
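+
+        Example (a minimal usage sketch; the LoRA path below is a placeholder):
+
+        ```py
+        from mindone.diffusers import Kandinsky5T2VPipeline
+
+        state_dict = Kandinsky5T2VPipeline.lora_state_dict("path/to/lora.safetensors")
+        ```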
+        """
+        # Load the main state dict first which has the LoRA layers
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."  # noqa
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
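+
+        Example (a minimal usage sketch; the repository id and LoRA path are the illustrative values also used in
+        [`~loaders.KandinskyLoraLoaderMixin.fuse_lora`], and the adapter name is arbitrary):
+
+        ```py
+        from mindone.diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors", adapter_name="kandinsky-lora")
+        ```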
+        """
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        # Load LoRA into transformer
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        Load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters.
+            transformer (`Kandinsky5Transformer3DModel`):
+                The transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata.
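+
+        Example (normally called internally by [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]; a direct
+        call is sketched below with placeholder paths, assuming `pipeline` is a loaded `Kandinsky5T2VPipeline`):
+
+        ```py
+        state_dict = pipeline.lora_state_dict("path/to/lora.safetensors")
+        pipeline.load_lora_into_transformer(state_dict, transformer=pipeline.transformer, adapter_name="kandinsky-lora")
+        ```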
+        """
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[ms.nn.Cell, ms.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to.
+            transformer_lora_layers (`Dict[str, ms.nn.Cell]` or `Dict[str, ms.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process.
+            save_function (`Callable`):
+                The function to use to save the state dictionary.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer.
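+
+        Example (a minimal sketch; `transformer_lora_state_dict` stands in for LoRA weights obtained elsewhere, for
+        example from a PEFT training run, and the output directory is a placeholder):
+
+        ```py
+        from mindone.diffusers import Kandinsky5T2VPipeline
+
+        Kandinsky5T2VPipeline.save_lora_weights(
+            save_directory="path/to/output_dir",
+            transformer_lora_layers=transformer_lora_state_dict,
+            weight_name="lora.safetensors",
+        )
+        ```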
+        """
+        lora_layers = {}
+        lora_metadata = {}
+
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers`")
+
+        cls._save_lora_weights(
+            save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing.
+
+        Example:
+        ```py
+        from mindone.diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of [`pipe.fuse_lora()`].
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
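+
+        Example (a minimal sketch, continuing the [`~loaders.KandinskyLoraLoaderMixin.fuse_lora`] example):
+
+        ```py
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipeline.unfuse_lora()
+        ```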
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].