Describe the bug
Goal
Build an image-generation service with StableDiffusionXLPipeline that:
- Keeps a large number of LoRA adapters in CPU RAM
- For each request (sketched below):
  • load ≤ 5 specific LoRAs onto CUDA using pipeline.set_lora_device
  • run inference
  • offload the LoRAs back to CPU
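A condensed sketch of the intended per-request flow (pipe, handle_request, and the single adapter shown are illustrative; load_lora_weights, set_lora_device, and set_adapters are the diffusers calls involved):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load adapters once, then park them in CPU RAM until a request needs them.
pipe.load_lora_weights("ostris/super-cereal-sdxl-lora", adapter_name="cereal")
pipe.set_lora_device(["cereal"], device="cpu")

def handle_request(prompt: str, adapters: list[str]):
    pipe.set_lora_device(adapters, device="cuda")  # bring ≤ 5 adapters onto the GPU
    pipe.set_adapters(adapters)
    image = pipe(prompt, num_inference_steps=10).images[0]
    pipe.set_lora_device(adapters, device="cpu")   # park them again
    return image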
Issue
pipeline.set_lora_device crashes with a KeyError (see logs below).
What I've tried
- Even if I load/offload only one LoRA at a time, it still crashes with the same KeyError (a minimal sketch follows this list)
- Installing diffusers from main
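Roughly what the single-adapter variant looked like (a sketch, not the verbatim code):

pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", adapter_name="cereal")
pipeline.set_lora_device(["cereal"], device="cpu")   # offload after the request
pipeline.set_lora_device(["cereal"], device="cuda")  # next request: same KeyError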
Reproduction
import random
import traceback
from typing import Dict, List

import torch
from pydantic import BaseModel

from diffusers import StableDiffusionXLPipeline, AutoencoderTiny
from diffusers.utils import logging

logging.disable_progress_bar()

pipeline = None
loaded_loras: List[str] = []

# Adapters available to the service; in production there are many more.
lora_paths: Dict[str, str] = {
    "cereal": "ostris/super-cereal-sdxl-lora",
    "ikea": "ostris/ikea-instructions-lora-sdxl",
    "feng": "lordjia/by-feng-zikai",
}


class Lora(BaseModel):
    name: str
    strength: float


def set_lora_device(pipeline, lora_names: List[str], device="cuda"):
    print(f"[LoRA] Setting LoRA device to {device} for adapters: {lora_names}")
    pipeline.set_lora_device(adapter_names=lora_names, device=device)


def load_loras(pipeline: StableDiffusionXLPipeline, loras: List[Lora]):
    for lora in loras:
        if lora.name in loaded_loras:
            print(f"[LoRA] {lora.name} already loaded.")
            continue
        print(f"[LoRA] Loading new lora: {lora.name} (strength={lora.strength})")
        try:
            print(f"[LoRA] Applying {lora.name} (strength={lora.strength})")
            pipeline.load_lora_weights(
                pretrained_model_name_or_path_or_dict=lora_paths[lora.name],
                adapter_name=lora.name,
            )
            loaded_loras.append(lora.name)
        except Exception:
            print(f"[LoRA] Failed to load {lora.name} (strength={lora.strength})")
            print(traceback.format_exc(limit=5))
            continue
    # Move all requested adapters onto the GPU for this request.
    set_lora_device(pipeline, lora_names=[lora.name for lora in loras], device="cuda")


def generate_images(pipeline, loras: List[Lora], prompt: str) -> list:
    pipeline.enable_lora()
    load_loras(pipeline, loras)
    pipeline.set_adapters(
        adapter_names=[lora.name for lora in loras],
        adapter_weights=[lora.strength for lora in loras],
    )
    images = pipeline(
        prompt=prompt,
        negative_prompt="",
        width=256,
        height=256,
        num_inference_steps=10,
        guidance_scale=7,
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images
    # Offload the LoRAs back to CPU after generation.
    set_lora_device(pipeline, [lora.name for lora in loras], device="cpu")
    return images


def test_lora_group(pipeline, lora_group: List[str]):
    try:
        generate_images(
            pipeline=pipeline,
            loras=[Lora(name=lora_name, strength=0.8) for lora_name in lora_group],
            prompt="a simple test image " + str(random.randint(0, 1000)),
        )
    except Exception as e:
        print(traceback.format_exc(limit=5))
        print(f"Error testing LoRA group {lora_group}: {e}")


if __name__ == "__main__":
    print("Loading pipeline for LoRA tests...")
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        vae=AutoencoderTiny.from_pretrained(
            "madebyollin/taesdxl",
            use_safetensors=True,
            torch_dtype=torch.float16,
        ),
    ).to("cuda")
    pipeline.set_progress_bar_config(disable=True)

    # Run the same three-adapter group three times in a row.
    lora_groups = [[name for name in lora_paths.keys()]] * 3
    for i, lora_group in enumerate(lora_groups, 1):
        print(f"Testing group {i}/{len(lora_groups)}: {lora_group}")
        test_lora_group(pipeline, lora_group)
Logs
Testing group 1/3: ['cereal', 'ikea', 'feng']
[LoRA] Loading new lora: cereal (strength=0.8)
[LoRA] Applying cereal (strength=0.8)
[LoRA] Loading new lora: ikea (strength=0.8)
[LoRA] Applying ikea (strength=0.8)
[LoRA] Loading new lora: feng (strength=0.8)
[LoRA] Applying feng (strength=0.8)
[LoRA] Setting LoRA device to cuda for adapters: []
[LoRA] Setting LoRA device to cpu for adapters: ['cereal', 'ikea', 'feng']
Traceback (most recent call last):
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 75, in test_lora_group
generate_images(
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 69, in generate_images
set_lora_device(pipeline, [lora.name for lora in loras], device="cpu")
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 23, in set_lora_device
pipeline.set_lora_device(adapter_names=lora_names, device=device)
File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/diffusers/loaders/lora_base.py", line 952, in set_lora_device
module.lora_A[adapter_name].to(device)
~~~~~~~~~~~~~^^^^^^^^^^^^^^
File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/torch/nn/modules/container.py", line 492, in __getitem__
return self._modules[key]
~~~~~~~~~~~~~^^^^^
KeyError: 'cereal'
Error testing LoRA group ['cereal', 'ikea', 'feng']: 'cereal'
Testing group 2/3: ['cereal', 'ikea', 'feng']
[LoRA] cereal already loaded.
[LoRA] ikea already loaded.
[LoRA] feng already loaded.
[LoRA] Setting LoRA device to cuda for adapters: ['cereal', 'ikea', 'feng']
Traceback (most recent call last):
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 75, in test_lora_group
generate_images(
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 52, in generate_images
load_loras(pipeline, loras)
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 48, in load_loras
set_lora_device(pipeline, lora_names=switch_device, device="cuda")
File "/aibabe-diffusers/aibabe_diffusers/main.py", line 23, in set_lora_device
pipeline.set_lora_device(adapter_names=lora_names, device=device)
File "/aibabe-diffusers/worker-env/lib/python3.11/site-packages/diffusers/loaders/lora_base.py", line 952, in set_lora_device
module.lora_A[adapter_name].to(device)
~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'cereal'
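The traceback shows set_lora_device iterating every requested adapter name over every LoRA layer in the pipeline and indexing module.lora_A[adapter_name] directly, so any layer that never received weights for that adapter (for example, a text-encoder layer when a LoRA only targets the UNet; whether that applies to these particular adapters is an assumption) raises the KeyError. Which components actually hold each adapter can be checked with the public get_list_adapters method:

# Prints something like {"unet": ["cereal", "ikea", "feng"], "text_encoder": [...]}
print(pipeline.get_list_adapters())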
System Info
- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.7.1+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.33.1
- Transformers version: 4.53.0
- Accelerate version: 1.8.1
- PEFT version: 0.15.2
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
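Possible workaround
Until this is fixed upstream, a guarded variant of set_lora_device avoids the crash by skipping (layer, adapter) pairs that were never injected. This is only a sketch: it relies on the pipeline's private _lora_loadable_modules attribute and PEFT's BaseTunerLayer, and it moves only the lora_A/lora_B weights (the real set_lora_device also handles other per-adapter tensors):

from peft.tuners.tuners_utils import BaseTunerLayer

def safe_set_lora_device(pipeline, adapter_names, device):
    # Mirror diffusers' set_lora_device, but check membership before indexing
    # instead of assuming every LoRA layer holds every adapter.
    for component in pipeline._lora_loadable_modules:  # e.g. unet, text_encoder
        model = getattr(pipeline, component, None)
        if model is None:
            continue
        for module in model.modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            for name in adapter_names:
                if name in getattr(module, "lora_A", {}):
                    module.lora_A[name].to(device)
                if name in getattr(module, "lora_B", {}):
                    module.lora_B[name].to(device)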