Fix regression with FLUX diffusers LoRA models where lora keys were not given the expected prefix.
RyanJDick authored and hipsterusername committed Oct 1, 2024
1 parent bd3d1dc commit 68dbe45
Showing 2 changed files with 6 additions and 1 deletion.
File 1 of 2:

@@ -2,6 +2,7 @@
 
 import torch
 
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -189,7 +190,9 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
 
-    return LoRAModelRaw(layers=layers)
+    layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}
+
+    return LoRAModelRaw(layers=layers_with_prefix)
 
 
 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
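
For context, a minimal standalone sketch of what the new remapping line does. Everything below is illustrative: the prefix value and the layer keys are placeholders, not the real FLUX_KOHYA_TRANFORMER_PREFIX constant or InvokeAI layer names.

# Standalone sketch of the fix (assumed values, not InvokeAI code).
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer_"  # hypothetical placeholder value

# Stand-ins for the converted LoRA layers, keyed by internal module name.
layers = {
    "double_blocks.0.img_attn.qkv": "layer_a",
    "single_blocks.0.linear1": "layer_b",
}

# Before this fix the keys above were returned as-is; the fix remaps them so
# every key carries the expected kohya transformer prefix.
layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}

assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in layers_with_prefix)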
File 2 of 2:

@@ -5,6 +5,7 @@
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -50,6 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
+    assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())
 
 
 def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
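
For reference, a hedged sketch of the flow the updated test exercises end to end. The conversion function's module path, its alpha parameter, the dummy tensor shape, and the assumption that the fixture is an iterable of key names are guesses based on the imports shown in the diff, not confirmed details.

# Hedged sketch of the test flow; see the assumptions noted above.
import torch

from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (  # assumed module path
    lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)

# Build a dummy state dict covering every key in the fixture (shapes assumed).
state_dict = {k: torch.empty(8, 8) for k in flux_diffusers_state_dict_keys}

model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)  # alpha assumed

# The regression fix guarantees every converted layer key carries the prefix.
assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())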
