def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict):
    """Convert a non-diffusers HiDream LoRA state dict to the diffusers key layout.

    Two input layouts are handled:

    * Kohya: *every* key starts with ``diffusion_model.``. That prefix is swapped for
      ``transformer.`` and the values are passed through untouched.
    * OMI: keys are namespaced by component (``clip_g.``, ``clip_l.``, ``t5.``, ``llama.``,
      ``transformer.``) and every LoRA module is stored as ``<module>.lora_down.weight``,
      ``<module>.lora_up.weight`` and ``<module>.alpha``. Only the ``transformer.`` entries are
      converted; entries of the other components are ignored. ``alpha / rank`` is folded into
      the two matrices and they are emitted as ``<module>.lora_A.weight`` (down) and
      ``<module>.lora_B.weight`` (up).

    ``lora_state_dict`` routes a checkpoint here when any key contains ``diffusion_model`` or
    starts with one of the OMI component prefixes.

    Args:
        state_dict: Mapping from parameter name to weight. It is not modified.

    Returns:
        A new dict whose keys all start with ``transformer.``.

    Raises:
        ValueError: If the keys match neither layout, if an OMI module has a down projection
            without its up projection, or if unexpected ``transformer.`` keys are left over
            after all modules were converted.
    """
    kohya_prefix = "diffusion_model."
    # NOTE: `all()` is True for an empty dict, so an empty input yields an empty output.
    if all(k.startswith(kohya_prefix) for k in state_dict):
        return {f"transformer.{k.removeprefix(kohya_prefix)}": v for k, v in state_dict.items()}

    omi_prefixes = ("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")
    # Raise instead of `assert`: assertions are stripped when Python runs with `-O`.
    if not any(k.startswith(omi_prefixes) for k in state_dict):
        raise ValueError("Invalid LoRA state dict for HiDream.")

    component = "transformer"
    # Work on a private copy so the caller's dict is left intact while entries are popped.
    component_sd = {k: v for k, v in state_dict.items() if k.startswith(f"{component}.")}

    down_suffix = ".lora_down.weight"
    up_suffix = ".lora_up.weight"
    converted_state_dict = {}

    # Snapshot the down keys first: `component_sd` shrinks while we iterate.
    for down_key in [k for k in component_sd if k.endswith(down_suffix)]:
        # `module` keeps its leading "transformer." prefix, so it must not be prepended again.
        module = down_key.removesuffix(down_suffix)
        up_key = f"{module}{up_suffix}"
        if up_key not in component_sd:
            raise ValueError(f"Invalid LoRA state dict for HiDream: `{up_key}` is missing.")

        down_weight = component_sd.pop(down_key)
        up_weight = component_sd.pop(up_key)
        lora_rank = down_weight.shape[0]

        # A missing alpha is treated as alpha == rank, i.e. the weights are not rescaled.
        alpha = component_sd.pop(f"{module}.alpha", None)
        # `float()` accepts Python numbers as well as single-element tensors/arrays and keeps
        # the balancing loop below in plain Python arithmetic.
        scale = 1.0 if alpha is None else float(alpha) / lora_rank

        # Spread the scale over both matrices by powers of two so neither factor gets tiny.
        # `scale_down * scale_up == scale` holds throughout; the positivity guard keeps the
        # loop from spinning forever on a non-positive alpha.
        scale_down, scale_up = scale, 1.0
        while scale_down > 0 and scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        converted_state_dict[f"{module}.lora_A.weight"] = down_weight * scale_down
        converted_state_dict[f"{module}.lora_B.weight"] = up_weight * scale_up

    # Anything still here was not part of a (down, up, alpha) triple; fail loudly rather than
    # silently dropping weights.
    if component_sd:
        raise ValueError(f"Invalid LoRA state dict for HiDream; unexpected keys: {sorted(component_sd)}.")

    return converted_state_dict