diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 9e24e5ebfd7..ce745cae41b 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -30,7 +30,7 @@
     pack,
     unpack,
 )
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ def _run_diffusion(
                     LoRAPatcher.apply_lora_patches(
                         model=transformer,
                         patches=self._lora_iterator(context),
-                        prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                        prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                         cached_weights=cached_weights,
                     )
                 )
@@ -220,7 +220,7 @@
                     LoRAPatcher.apply_lora_sidecar_patches(
                         model=transformer,
                         patches=self._lora_iterator(context),
-                        prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                        prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                         dtype=inference_dtype,
                     )
                 )
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index a306a8aa95e..209988ce807 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -10,7 +10,7 @@
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
                 LoRAPatcher.apply_lora_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
-                    prefix=FLUX_KOHYA_CLIP_PREFIX,
+                    prefix=FLUX_LORA_CLIP_PREFIX,
                     cached_weights=cached_weights,
                 )
             )
diff --git a/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
index 4f0335af3ac..67e393f23ce 100644
--- a/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
+++ b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -190,7 +190,7 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
 
-    layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}
+    layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
 
     return LoRAModelRaw(layers=layers_with_prefix)
diff --git a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
index 94df3d5567d..bc2515cfccd 100644
--- a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
+++ b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@
 FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
 
-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
-FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
-FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
-
-
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
 
@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, layer_state_dict in transformer_grouped_sd.items():
-        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
     for layer_key, layer_state_dict in clip_grouped_sd.items():
-        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
diff --git a/invokeai/backend/lora/conversions/flux_lora_constants.py b/invokeai/backend/lora/conversions/flux_lora_constants.py
new file mode 100644
index 00000000000..4f854d14421
--- /dev/null
+++ b/invokeai/backend/lora/conversions/flux_lora_constants.py
@@ -0,0 +1,3 @@
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
+FLUX_LORA_CLIP_PREFIX = "lora_clip-"
diff --git a/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
index 7f23fb22388..9b02a04f766 100644
--- a/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
+++ b/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
@@ -5,7 +5,7 @@
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -51,7 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
-    assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())
+    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
 
 
 def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
diff --git a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py
index 7878137130d..1c947884cfd 100644
--- a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py
+++ b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py
@@ -5,12 +5,11 @@
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    FLUX_KOHYA_CLIP_PREFIX,
-    FLUX_KOHYA_TRANFORMER_PREFIX,
     _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     expected_layer_keys: set[str] = set()
     for k in sd_keys:
         # Replace prefixes.
-        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
-        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
         # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
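
Note (reviewer sketch, not part of the diff): the relocated prefixes namespace converted layer keys so that a single LoRAModelRaw can hold transformer and CLIP text encoder layers side by side, and each patcher can filter by prefix. The standalone Python sketch below illustrates this; the layer keys are hypothetical examples, and the real conversion utilities map keys to AnyLoRALayer objects rather than strings.

# Values copied from the new invokeai/backend/lora/conversions/flux_lora_constants.py.
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"

# Hypothetical converted layer keys (the real code stores AnyLoRALayer values).
transformer_layers = {"double_blocks_0_img_attn_qkv": "transformer layer"}
clip_layers = {"text_model_encoder_layers_0_self_attn_q_proj": "clip layer"}

# Namespace each sub-model's keys, mirroring lora_model_from_flux_kohya_state_dict.
layers = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in transformer_layers.items()}
layers.update({f"{FLUX_LORA_CLIP_PREFIX}{k}": v for k, v in clip_layers.items()})

# A patcher applying LoRA to the transformer keeps only keys with the transformer
# prefix; the CLIP text encoder patcher filters on the CLIP prefix.
assert all(k.startswith((FLUX_LORA_TRANSFORMER_PREFIX, FLUX_LORA_CLIP_PREFIX)) for k in layers)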