This repository has been archived by the owner on Jul 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
111 lines (88 loc) · 3.46 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# noinspection PyPackageRequirements
import torch

# noinspection PyUnresolvedReferences,PyPackageRequirements
import comfy.sd
# noinspection PyUnresolvedReferences,PyPackageRequirements
import comfy.utils
# noinspection PyUnresolvedReferences,PyPackageRequirements
import folder_paths

from .base import BaseNode, GLOBAL_CATEGORY

# Category path shared by the nodes in this module (assigned to their CATEGORY).
MODULE_CATEGORY = "{}/models".format(GLOBAL_CATEGORY)
class HelperNodes_CheckpointSelector(BaseNode):
    """
    Pass-through node for picking a checkpoint (model) file by name.

    The chosen filename is emitted unchanged so it can be wired into a
    conditioner or a LoRA loader. LoRA selection itself is not handled here;
    use the standard Load LoRA nodes for that.
    """

    CATEGORY = MODULE_CATEGORY

    # The output "type" is the checkpoint filename list itself, captured once
    # when this class body runs, instead of a type-name string.
    # NOTE(review): presumably this lets the output connect to combo inputs
    # offering the same list; confirm against the ComfyUI version in use.
    RETURN_TYPES = (folder_paths.get_filename_list("checkpoints"),)
    RETURN_NAMES = ("chkpt_name",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        """Offer the current checkpoint filenames as a single required choice."""
        checkpoint_choices = folder_paths.get_filename_list("checkpoints")
        return {"required": {"chkpt_name": (checkpoint_choices,)}}

    def process(self, chkpt_name) -> tuple:
        """Return the selected checkpoint name as a one-element tuple."""
        return chkpt_name,
class HelperNodes_VAESelector(BaseNode):
    """
    Selector node that loads the chosen VAE and outputs it.

    The choices are the files in the "vae" folder plus the pseudo-entries
    "taesd" / "taesdxl" whenever both the encoder and the decoder file of that
    tiny autoencoder exist in "vae_approx". The returned VAE should then be
    passed to a VAE decoder node.
    """

    # Tiny-autoencoder pseudo-entries, in the order they are offered, mapped to
    # the value load_taesd stores under "vae_scale".
    _TAESD_SCALES = {
        "taesd": 0.18215,
        "taesdxl": 0.13025,
    }

    @staticmethod
    def vae_list():
        """
        Return the selectable VAE names.

        Adapted from ComfyUI's VAE loader. A tiny autoencoder is listed only
        when both "<name>_encoder." and "<name>_decoder." files are present in
        "vae_approx", because load_taesd needs both halves.
        """
        # Copy before appending so a list that folder_paths may hand out from
        # a cache is never mutated in place.
        vaes = list(folder_paths.get_filename_list("vae"))
        approx_vaes = folder_paths.get_filename_list("vae_approx")

        def has_part(name, part):
            # The trailing "." keeps "taesd_" from matching "taesdxl_" files.
            prefix = "{}_{}.".format(name, part)
            return any(a.startswith(prefix) for a in approx_vaes)

        for name in HelperNodes_VAESelector._TAESD_SCALES:
            if has_part(name, "encoder") and has_part(name, "decoder"):
                vaes.append(name)
        return vaes

    @staticmethod
    def load_taesd(name):
        """
        Assemble a VAE state dict for a tiny autoencoder.

        Adapted from ComfyUI's VAE loader.

        :param name: variant prefix used in "vae_approx", e.g. "taesd" or "taesdxl".
        :return: dict holding the encoder tensors under "taesd_encoder.<key>",
            the decoder tensors under "taesd_decoder.<key>" and, for the known
            variants, a scalar "vae_scale" tensor.
        :raises FileNotFoundError: if the encoder or decoder file is missing
            (e.g. removed after the node's choices were built).
        """
        approx_vaes = folder_paths.get_filename_list("vae_approx")

        def find_part(part):
            prefix = "{}_{}.".format(name, part)
            found = next((a for a in approx_vaes if a.startswith(prefix)), None)
            if found is None:
                # A bare next() would leak StopIteration and hide the cause.
                raise FileNotFoundError(
                    "No file starting with '{}' found in vae_approx".format(prefix)
                )
            return found

        # Resolve both file names up front so a missing half fails before any
        # weights are loaded.
        encoder = find_part("encoder")
        decoder = find_part("decoder")

        sd = {}
        for part, filename in (("encoder", encoder), ("decoder", decoder)):
            weights = comfy.utils.load_torch_file(
                folder_paths.get_full_path("vae_approx", filename)
            )
            for k in weights:
                sd["taesd_{}.{}".format(part, k)] = weights[k]

        scale = HelperNodes_VAESelector._TAESD_SCALES.get(name)
        if scale is not None:
            sd["vae_scale"] = torch.tensor(scale)
        return sd

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        """Offer the names from vae_list() as a single required choice."""
        return {
            "required": {
                "vae_name": (cls.vae_list(),)
            }
        }

    CATEGORY = MODULE_CATEGORY
    RETURN_TYPES = ("VAE",)
    RETURN_NAMES = ("VAE",)

    def process(self, vae_name) -> tuple:
        """
        Load the selected VAE.

        :param vae_name: a file name from the "vae" folder, or one of the
            tiny-autoencoder pseudo-entries produced by vae_list().
        :return: one-element tuple containing the comfy.sd.VAE instance.
        """
        if vae_name in self._TAESD_SCALES:
            sd = self.load_taesd(vae_name)
        else:
            vae_path = folder_paths.get_full_path("vae", vae_name)
            sd = comfy.utils.load_torch_file(vae_path)
        return (comfy.sd.VAE(sd=sd),)