Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 233 additions & 24 deletions MultiLoraLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@
import os
import re

# Block-selection keywords accepted in lora description lines, plus their
# abbreviated forms. The long forms are matched against the second
# dot-separated component of lora state-dict keys (see LoraItem.get_filtered_lora).
KEY_BLOCKS_ALL = "all_blocks"
KEY_BLOCKS_ALL_ABBR = "allb"
KEY_BLOCKS_SINGLE = "single_blocks"
# NOTE(review): "msb"/"mdb" read as "model single/double blocks" — confirm intended spelling.
KEY_BLOCKS_SINGLE_ABBR = "msb"
KEY_BLOCKS_DOUBLE = "double_blocks"
KEY_BLOCKS_DOUBLE_ABBR = "mdb"

class MultiLoraLoader:
def __init__(self):
    # SelectedLoras keeps the LoraItem objects parsed from the prompt,
    # preserving already-loaded lora tensors when the text changes.
    self.selected_loras = SelectedLoras()

@classmethod
def INPUT_TYPES(s):
    """ComfyUI input schema.

    `model` and the lora `text` are required; `clip` is optional so the node
    can apply model-only loras (load_loras accepts clip=None).
    """
    return {
        "required": {
            "model": ("MODEL",),
            "text": ("STRING", {
                "multiline": True,
                "default": "",
            }),
        },
        "optional": {
            "clip": ("CLIP",),
        },
    }

# ComfyUI node metadata: output sockets, entry-point method name, menu category.
RETURN_TYPES = ("MODEL", "CLIP")
FUNCTION = "load_loras"
CATEGORY = "loaders"

def load_loras(self, model, clip, text):
def load_loras(self, model, text, clip = None):
result = (model, clip)

lora_items = self.selected_loras.updated_lora_items_with_text(text)
Expand All @@ -33,6 +40,78 @@ def load_loras(self, model, clip, text):
result = item.apply_lora(result[0], result[1])

return result

class MultiLoraParser:
    """Node that turns multi-lora prompt text into a list of plain lora-info dicts."""

    RETURN_TYPES = ("LORA_INFO",)
    RETURN_NAMES = ("Lora Info List",)
    FUNCTION = "parse_loras"
    CATEGORY = "utils"

    @classmethod
    def INPUT_TYPES(s):
        text_spec = ("STRING", {"multiline": True, "default": ""})
        return {"required": {"text": text_spec}}

    def __init__(self):
        # Shared parser/cache also used by the loader node.
        self.selected_loras = SelectedLoras()

    @staticmethod
    def get_block_list(blocks):
        # Placeholder: block filters are not yet expanded into explicit lists.
        return []

    def parse_loras(self, text):
        """Parse the lora text and return a one-tuple holding a list of dicts
        with keys: path, strength, name, blocks, low_mem_load."""
        items = self.selected_loras.updated_lora_items_with_text(text)

        parsed_info = [
            {
                "path": entry.get_lora_path(),
                "strength": entry.strength_model,
                "name": entry.lora_name,
                "blocks": list() if KEY_BLOCKS_ALL in entry.blocks else MultiLoraParser.get_block_list(entry.blocks),
                "low_mem_load": False  # Can be expanded later if needed
            }
            for entry in items
        ]

        return (parsed_info,)

class LoraInfoToHunyuanVidLora:
    """Adapter node: relabels a LORA_INFO list as HYVIDLORA without changing it."""

    RETURN_TYPES = ("HYVIDLORA",)
    FUNCTION = "passthrough"
    CATEGORY = "utils"

    @classmethod
    def INPUT_TYPES(s):
        required = {"lora_info": ("LORA_INFO",)}
        return {"required": required}

    def passthrough(self, lora_info):
        # No conversion is needed; only the declared socket type differs.
        return (lora_info,)

class LoraInfoToWanVidLora:
    """Adapter node: relabels a LORA_INFO list as WANVIDLORA without changing it."""

    RETURN_TYPES = ("WANVIDLORA",)
    FUNCTION = "passthrough"
    CATEGORY = "utils"

    @classmethod
    def INPUT_TYPES(s):
        required = {"lora_info": ("LORA_INFO",)}
        return {"required": required}

    def passthrough(self, lora_info):
        # No conversion is needed; only the declared socket type differs.
        return (lora_info,)

# maintains a list of lora objects made from a prompt, preserving loaded loras across changes
class SelectedLoras:
Expand Down Expand Up @@ -88,35 +167,112 @@ def __init__(self, lora_text, loras_by_short_names, default_weight, weight_separ
self.comment_trim_re = re.compile("\s*#.*\Z")

def execute(self):
    """Build LoraItem objects from each line of the lora text.

    Lines whose description does not resolve to a lora name (blank lines,
    comments) are skipped. Returns a list of LoraItem.
    """
    # The pre-refactor list-comprehension return that preceded this loop was
    # leftover diff residue: it made the loop unreachable and expected the old
    # 3-tuple result from parse_lora_description, which now returns 4 values.
    out_loras = []

    for line in self.lora_text.splitlines():
        description = self.description_from_line(line)
        name, model_weight, clip_weight, block_type = self.parse_lora_description(description)
        if name is not None:
            out_loras.append(LoraItem(name, model_weight, clip_weight, block_type))

    return out_loras

def parse_lora_description(self, description):
    """Parse one lora description into (name, model_weight, clip_weight, blocks).

    A description is `name[:model_weight[:clip_weight]][:block_spec...]` using
    self.weight_separator between fields. Block specs are block-type keywords
    (single_blocks/double_blocks or their abbreviations), optionally with an
    index range like `double_blocks[1-10,even]`.

    Returns (None, None, None, None) for an empty description. `blocks` is a
    dict mapping a block-type key to a set of index strings; an absent range
    yields an empty set, and no block spec at all yields {all_blocks: []}.

    Raises ValueError for malformed weights or unknown block types.
    """
    if description is None:
        return (None, None, None, None)

    lora_name = None
    strength_model = self.default_weight
    strength_clip = None
    blocks = []

    parts = description.split(self.weight_separator)

    def is_block_param(param):
        # A field is a block spec when its text before any "[" is a known keyword.
        return param.split("[")[0] in [KEY_BLOCKS_ALL, KEY_BLOCKS_ALL_ABBR, KEY_BLOCKS_SINGLE, KEY_BLOCKS_SINGLE_ABBR, KEY_BLOCKS_DOUBLE, KEY_BLOCKS_DOUBLE_ABBR]

    try:
        if len(parts) == 1:  # Only lora name
            lora_name = parts[0]
        elif len(parts) == 2:  # lora name and model weight or blocks
            lora_name, last_param = parts
            if "blocks" in last_param:
                blocks = [last_param]
            else:
                strength_model = float(last_param)
        elif len(parts) == 3:  # lora name, model weight, and either clip weight or blocks
            lora_name, strength_model, last_param = parts
            strength_model = float(strength_model)
            if is_block_param(last_param):
                blocks = [last_param]
            else:
                strength_clip = float(last_param)
        elif len(parts) == 4:  # name, model weight, clip weight, blocks OR name, model weight, blocks, blocks
            lora_name, strength_model, second_to_last_param, last_param = parts
            strength_model = float(strength_model)
            if is_block_param(second_to_last_param):
                blocks = [second_to_last_param]
            else:
                strength_clip = float(second_to_last_param)
            blocks.append(last_param)
        elif len(parts) == 5:  # name, model weight, clip weight, single blocks, double blocks (order interchangeable)
            lora_name, strength_model, strength_clip, blocksA, blocksB = parts
            strength_model = float(strength_model)
            strength_clip = float(strength_clip)
            blocks = [blocksA, blocksB]
    except ValueError as e:
        raise ValueError(f"Invalid description format: {description}") from e

    # Clip weight defaults to the model weight when not given.
    if strength_clip is None:
        strength_clip = strength_model

    def parse_ranges(input_str, max_range=100):
        """Expand a comma-separated index spec ("1-5,even,7") into a set of strings."""
        result = set()
        for part in input_str.split(','):
            part = part.strip()  # Remove extra spaces
            if '-' in part:
                try:
                    start, end = map(int, part.split('-'))
                    # Add the inclusive range as strings, tolerating reversed bounds.
                    result.update(map(str, range(min(start, end), max(start, end) + 1)))
                except ValueError as e:
                    raise ValueError(f"Invalid numeric range: {e}")
            elif "even" in part:
                result.update(map(str, range(0, max_range + 1, 2)))
            elif "odd" in part:
                result.update(map(str, range(1, max_range + 1, 2)))
            elif part.isdigit():  # Single numeric value
                result.add(part)
            elif part:  # Keep non-numeric tokens verbatim
                result.add(part)
        return result

    valid_blocks = [KEY_BLOCKS_SINGLE, KEY_BLOCKS_SINGLE_ABBR, KEY_BLOCKS_DOUBLE, KEY_BLOCKS_DOUBLE_ABBR]
    if not blocks:
        # No block spec given: apply the lora to every block.
        blocks = {KEY_BLOCKS_ALL: []}
    else:
        # Normalize compound specs like `double_blocks[1-10]` into {type: index-set}.
        normalized_blocks = {}
        for block in blocks:
            # Strip whitespace and expand abbreviations to the long block names.
            block = block.strip().replace(KEY_BLOCKS_SINGLE_ABBR, KEY_BLOCKS_SINGLE).replace(KEY_BLOCKS_DOUBLE_ABBR, KEY_BLOCKS_DOUBLE)
            if "[" in block and "]" in block:
                block_type, indices = block.split("[", 1)
                block_type = block_type.strip()
                indices = indices.rstrip("]").strip()
                if block_type in valid_blocks:
                    normalized_blocks[block_type] = parse_ranges(indices)
                else:
                    # Consistency fix: previously an unknown bracketed type was
                    # silently dropped while an unknown bare type raised.
                    raise ValueError(f"Invalid block type or format: {block}")
            elif block in valid_blocks:
                normalized_blocks[block] = set()
            else:
                raise ValueError(f"Invalid block type or format: {block}")
        blocks = normalized_blocks

    return (self.loras_by_short_names.get(lora_name, lora_name), strength_model, strength_clip, blocks)


def description_from_line(self, line):
result = self.comment_trim_re.sub("", line.strip())
Expand All @@ -125,14 +281,15 @@ def description_from_line(self, line):


class LoraItem:
def __init__(self, lora_name, strength_model, strength_clip, blocks_type):
    """One lora reference parsed from a prompt line.

    blocks_type: dict mapping block-type keys to index sets, as produced by
    parse_lora_description ({all_blocks: []} means "no filtering").
    """
    # The old 3-argument def line left over from the diff is removed; two
    # consecutive `def __init__` headers were a syntax error.
    self.lora_name = lora_name
    self.strength_model = strength_model
    self.strength_clip = strength_clip
    self.blocks = blocks_type
    self._loaded_lora = None  # lazily loaded lora data (see lora_object)

def __eq__(self, other):
    # Items are interchangeable only when name, both strengths, and the block
    # filter all match (the stale 3-field comparison from the diff is removed).
    return (self.lora_name == other.lora_name
            and self.strength_model == other.strength_model
            and self.strength_clip == other.strength_clip
            and self.blocks == other.blocks)

def get_lora_path(self):
    """Resolve this item's lora name to a full file path via ComfyUI's folder registry."""
    lora_file = self.lora_name
    return folder_paths.get_full_path("loras", lora_file)
Expand All @@ -146,9 +303,55 @@ def move_resources_from(self, lora_items_by_name):
def apply_lora(self, model, clip):
    """Apply this lora, restricted to its selected blocks, to model and clip.

    Returns the patched (model, clip) pair; both are returned unchanged when
    self.is_noop is true.
    """
    if self.is_noop:
        return (model, clip)

    # Filter the raw lora weights down to the requested block types/indices.
    filtered_lora = self.get_filtered_lora()

    # The leftover pre-refactor call that passed the unfiltered self.lora_object
    # is removed; only the filtered weights are applied.
    model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, filtered_lora, self.strength_model, self.strength_clip)
    return (model_lora, clip_lora)

def get_filtered_lora(self):
    """Return the lora weights restricted to the blocks selected for this item.

    self.blocks maps block-type keys (single_blocks/double_blocks/all_blocks)
    to sets of index strings; an empty set means "every index of that type".
    Lora keys are assumed to look like "<prefix>.<layer>.<index>..." — keys
    with fewer than three dot-separated components are skipped.
    """
    # Early return if all blocks are selected: nothing to filter.
    if KEY_BLOCKS_ALL in self.blocks:
        return self.lora_object

    # Which block types are selected, and whether they carry explicit indices.
    use_single_blocks = KEY_BLOCKS_SINGLE in self.blocks
    use_single_block_indices = use_single_blocks and len(self.blocks[KEY_BLOCKS_SINGLE]) > 0
    use_double_blocks = KEY_BLOCKS_DOUBLE in self.blocks
    use_double_block_indices = use_double_blocks and len(self.blocks[KEY_BLOCKS_DOUBLE]) > 0

    filtered_lora = {}

    def has_matching_index(layer, index):
        # An empty index set means every index of that block type matches.
        if use_single_block_indices and layer == KEY_BLOCKS_SINGLE:
            return index in self.blocks[KEY_BLOCKS_SINGLE]
        if use_double_block_indices and layer == KEY_BLOCKS_DOUBLE:
            return index in self.blocks[KEY_BLOCKS_DOUBLE]
        return True

    for key, value in self.lora_object.items():
        components = key.split(".")
        try:
            # Normalize the layer name for matching; keep the index as a string.
            layer = components[1].strip().lower()
            index = components[2].strip()
        except IndexError:
            # Key does not follow the <prefix>.<layer>.<index> layout; skip it.
            # (Previously a bare `except: pass` swallowed every exception here.)
            continue

        if use_single_blocks and layer == KEY_BLOCKS_SINGLE and has_matching_index(layer, index):
            filtered_lora[key] = value
        elif use_double_blocks and layer == KEY_BLOCKS_DOUBLE and has_matching_index(layer, index):
            filtered_lora[key] = value

    return filtered_lora

@property
def lora_object(self):
Expand Down Expand Up @@ -194,9 +397,15 @@ def process_text(self, text):
# ComfyUI registration: stable node IDs (name plus hash suffix) to classes.
NODE_CLASS_MAPPINGS = {
    "MultiLoraLoader-70bf3d77": MultiLoraLoader,
    "LoraTextExtractor-b1f83aa2": LoraTextExtractor,
    "MultiLoraParser-8c12fa4b": MultiLoraParser,
    "LoraInfoToHunyuanVidLora-bb8cfde0": LoraInfoToHunyuanVidLora,
    "LoraInfoToWanVidLora-bd47e92c": LoraInfoToWanVidLora,
}

# Human-readable titles shown in the ComfyUI node picker for each node ID.
NODE_DISPLAY_NAME_MAPPINGS = {
    "MultiLoraLoader-70bf3d77": "MultiLora Loader",
    "LoraTextExtractor-b1f83aa2": "Lora Text Extractor",
    "MultiLoraParser-8c12fa4b": "Multi Lora Parser",
    "LoraInfoToHunyuanVidLora-bb8cfde0": "Lora Info To HunyuanVid Lora",
    "LoraInfoToWanVidLora-bd47e92c": "Lora Info To WanVid Lora",
}