
set_lora_device reports KeyError even though lora is loaded #11833

Closed
@hrazjan

Describe the bug

Goal

Build an image-generation service with StableDiffusionXLPipeline that:

  1. Keeps a large number of LoRA adapters in CPU RAM
  2. For each request:
    • loads ≤ 5 specific LoRAs onto the GPU (cuda) using pipeline.set_lora_device
    • runs inference
    • offloads the LoRAs back to the CPU.

Issue

The pipeline.set_lora_device method crashes with KeyError (see logs below)

What I've tried

  • Loading/offloading only one LoRA at a time: it still crashes with KeyError
  • Installing diffusers from main: same error
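
One more check that may help narrow this down is confirming that every component actually reports the adapters as loaded before set_lora_device is called. Below is a minimal sketch, assuming pipeline.get_list_adapters() returns a dict mapping component names (e.g. "unet") to lists of adapter names. The helper name adapters_missing_from_components is hypothetical and not part of diffusers.

```python
from typing import Dict, Iterable, List


def adapters_missing_from_components(pipeline, requested: Iterable[str]) -> Dict[str, List[str]]:
    """Return, per pipeline component, the requested adapters it does not list.

    Hypothetical diagnostic helper. It assumes `pipeline.get_list_adapters()`
    returns a mapping such as {"unet": ["cereal", "ikea"], ...}.
    """
    requested = list(requested)
    report: Dict[str, List[str]] = {}
    for component, loaded in pipeline.get_list_adapters().items():
        absent = [name for name in requested if name not in loaded]
        if absent:
            report[component] = absent
    return report
```

An empty result for adapters_missing_from_components(pipeline, ["cereal", "ikea", "feng"]) would suggest the adapters are registered at the component level, so the KeyError would have to come from individual layers rather than from a failed load.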

Reproduction

import random
import traceback
from typing import Dict, List
from pydantic import BaseModel
from diffusers import StableDiffusionXLPipeline, AutoencoderTiny
import torch
from diffusers.utils import logging
logging.disable_progress_bar()

pipeline = None
loaded_loras = []
lora_paths: Dict[str, str] = {
    "cereal": "ostris/super-cereal-sdxl-lora",
    "ikea": "ostris/ikea-instructions-lora-sdxl",
    "feng": "lordjia/by-feng-zikai",
}
class Lora(BaseModel):
    name: str
    strength: float

def set_lora_device(pipeline, lora_names: List[str], device="cuda"):
    print(f"[LoRA] Setting LoRA device to {device} for adapters: {lora_names}")
    pipeline.set_lora_device(adapter_names=lora_names, device=device)

def load_loras(pipeline: StableDiffusionXLPipeline, loras: List[Lora]):
    for lora in loras:
        if lora.name in loaded_loras:
            print(f"[LoRA] {lora.name} already loaded.")
            continue

        print(f"[LoRA] Loading new lora: {lora.name} (strength={lora.strength})")
        try:
            print(f"[LoRA] Applying {lora.name} (strength={lora.strength})")
            pipeline.load_lora_weights(
                pretrained_model_name_or_path_or_dict=lora_paths[lora.name],
                adapter_name=lora.name,
            )
            loaded_loras.append(lora.name)
        except Exception:
            print(f"[LoRA] Failed to load {lora.name} (strength={lora.strength})")
            print(traceback.format_exc(limit=5))
            continue

    set_lora_device(pipeline, lora_names=[lora.name for lora in loras], device="cuda")

def generate_images(pipeline, loras: List[Lora], prompt: str) -> list:
    pipeline.enable_lora()
    load_loras(pipeline, loras)
    pipeline.set_adapters(
        adapter_names=[lora.name for lora in loras],
        adapter_weights=[lora.strength for lora in loras],
    )

    images = pipeline(
        prompt=prompt,
        negative_prompt="",
        width=256,
        height=256,
        num_inference_steps=10,
        guidance_scale=7,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images

    # Offload LoRA to CPU after generation
    set_lora_device(pipeline, [lora.name for lora in loras], device="cpu")
    
    return images

def test_lora_group(pipeline, lora_group: List[Lora]):    
    try:
        generate_images(
            pipeline=pipeline, 
            loras=[Lora(name=lora_name, strength=0.8) for lora_name in lora_group], 
            prompt="a simple test image" + str(random.randint(0, 1000))
        )
    except Exception as e:
        print(traceback.format_exc(limit=5))
        print(f"Error testing LoRA group {lora_group}: {e}")

if __name__ == "__main__":
    print("Loading pipeline for LoRA tests...")
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        vae=AutoencoderTiny.from_pretrained(
            'madebyollin/taesdxl',
            use_safetensors=True,
            torch_dtype=torch.float16,
        )
    ).to("cuda")
    pipeline.set_progress_bar_config(disable=True)
    
    lora_groups = [[name for name in lora_paths.keys()]] * 3
    for i, lora_group in enumerate(lora_groups, 1):
        print(f"Testing group {i}/{len(lora_groups)}: {lora_group}")
        test_lora_group(pipeline, lora_group)

Logs

Testing group 1/3: ['cereal', 'ikea', 'feng']
[LoRA] Loading new lora: cereal (strength=0.8)
[LoRA] Applying cereal (strength=0.8)
[LoRA] Loading new lora: ikea (strength=0.8)
[LoRA] Applying ikea (strength=0.8)
[LoRA] Loading new lora: feng (strength=0.8)
[LoRA] Applying feng (strength=0.8)
[LoRA] Setting LoRA device to cuda for adapters: []
[LoRA] Setting LoRA device to cpu for adapters: ['cereal', 'ikea', 'feng']
Traceback (most recent call last):
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 75, in test_lora_group
    generate_images(
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 69, in generate_images
    set_lora_device(pipeline, [lora.name for lora in loras], device="cpu")
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 23, in set_lora_device
    pipeline.set_lora_device(adapter_names=lora_names, device=device)
  File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/diffusers/loaders/lora_base.py", line 952, in set_lora_device
    module.lora_A[adapter_name].to(device)
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/torch/nn/modules/container.py", line 492, in __getitem__
    return self._modules[key]
           ~~~~~~~~~~~~~^^^^^
KeyError: 'cereal'
Error testing LoRA group ['cereal', 'ikea', 'feng']: 'cereal'
Testing group 2/3: ['cereal', 'ikea', 'feng']
[LoRA] cereal already loaded.
[LoRA] ikea already loaded.
[LoRA] feng already loaded.
[LoRA] Setting LoRA device to cuda for adapters: ['cereal', 'ikea', 'feng']
Traceback (most recent call last):
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 75, in test_lora_group
    generate_images(
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 52, in generate_images
    load_loras(pipeline, loras)
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 48, in load_loras
    set_lora_device(pipeline, lora_names=switch_device, device="cuda")
  File "/aibabe-diffusers/aibabe_diffusers/main.py", line 23, in set_lora_device
    pipeline.set_lora_device(adapter_names=lora_names, device=device)
  File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/diffusers/loaders/lora_base.py", line 952, in set_lora_device
    module.lora_A[adapter_name].to(device)
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'cereal'
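
The traceback shows the failure at module.lora_A[adapter_name], which is consistent with (but does not prove) a situation where an adapter is registered on only some of the LoRA-wrapped layers, so indexing a layer that lacks it raises KeyError. Below is a minimal workaround sketch under that assumption. It walks a component's modules and skips layers that do not hold the adapter. The helper name move_lora_adapters is hypothetical, it only handles the lora_A and lora_B containers, and it is not the diffusers implementation.

```python
from typing import Iterable, List


def move_lora_adapters(component, adapter_names: Iterable[str], device: str) -> List[str]:
    """Move LoRA A/B weights of the named adapters to `device`.

    Hypothetical helper: layers that do not contain a given adapter are
    skipped instead of raising KeyError. Returns the adapter names that
    were not found on any layer of `component`.
    """
    names = list(adapter_names)
    found = set()
    for module in component.modules():
        lora_a = getattr(module, "lora_A", None)
        lora_b = getattr(module, "lora_B", None)
        if lora_a is None or lora_b is None:
            continue  # not a LoRA-wrapped layer
        for name in names:
            if name not in lora_a or name not in lora_b:
                continue  # this layer was not targeted by this adapter
            lora_a[name].to(device)
            lora_b[name].to(device)
            found.add(name)
    return [name for name in names if name not in found]
```

In the reproduction this could be called once per component, for example move_lora_adapters(pipeline.unet, ["cereal", "ikea", "feng"], "cpu"), in place of pipeline.set_lora_device. A non-empty return value would indicate adapters that are not present on that component at all.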

System Info

  • 🤗 Diffusers version: 0.35.0.dev0
  • Platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: 4.53.0
  • Accelerate version: 1.8.1
  • PEFT version: 0.15.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul @BenjaminBossan

Labels: bug
