Skip to content

Commit 7df42b9

Browse files
Fix dora.
1 parent 5d8bbb7 commit 7df42b9

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

comfy/lora.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,26 @@ def model_lora_keys_unet(model, key_map={}):
327327
return key_map
328328

329329

330+
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
331+
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
332+
lora_diff *= alpha
333+
weight_calc = weight + lora_diff.type(weight.dtype)
334+
weight_norm = (
335+
weight_calc.transpose(0, 1)
336+
.reshape(weight_calc.shape[1], -1)
337+
.norm(dim=1, keepdim=True)
338+
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
339+
.transpose(0, 1)
340+
)
341+
342+
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
343+
if strength != 1.0:
344+
weight_calc -= weight
345+
weight += strength * (weight_calc)
346+
else:
347+
weight[:] = weight_calc
348+
return weight
349+
330350
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
331351
for p in patches:
332352
strength = p[0]

comfy/model_patcher.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,6 @@
3131
from comfy.types import UnetWrapperFunction
3232

3333

34-
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
35-
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
36-
lora_diff *= alpha
37-
weight_calc = weight + lora_diff.type(weight.dtype)
38-
weight_norm = (
39-
weight_calc.transpose(0, 1)
40-
.reshape(weight_calc.shape[1], -1)
41-
.norm(dim=1, keepdim=True)
42-
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
43-
.transpose(0, 1)
44-
)
45-
46-
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
47-
if strength != 1.0:
48-
weight_calc -= weight
49-
weight += strength * (weight_calc)
50-
else:
51-
weight[:] = weight_calc
52-
return weight
53-
54-
5534
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
5635
to = model_options["transformer_options"].copy()
5736

0 commit comments

Comments
 (0)