Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split node WD14Tagger|pysssss #80

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 96 additions & 34 deletions wd14tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,40 @@ def get_installed_models():
return models


async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None):
if model_name.endswith(".onnx"):
model_name = model_name[0:-5]
installed = list(get_installed_models())
if not any(model_name + ".onnx" in s for s in installed):
await download_model(model_name, client_id, node)
class WD14Model:

name = os.path.join(models_dir, model_name + ".onnx")
model = InferenceSession(name, providers=defaults["ortProviders"])
def __init__(self, model_name, replace_underscore=False):
if model_name.endswith(".onnx"):
model_name = model_name[0:-5]
installed = list(get_installed_models())
if not any(model_name + ".onnx" in s for s in installed):
wait_for_async(lambda: download_model(model_name, None, None))

input = model.get_inputs()[0]
name = os.path.join(models_dir, model_name + ".onnx")
self.model = InferenceSession(name, providers=defaults["ortProviders"])

tags = []
general_index = None
character_index = None
with open(os.path.join(models_dir, model_name + ".csv")) as f:
reader = csv.reader(f)
next(reader)
for row in reader:
if general_index is None and row[2] == "0":
general_index = reader.line_num - 2
elif character_index is None and row[2] == "4":
character_index = reader.line_num - 2
if replace_underscore:
tags.append(row[1].replace("_", " "))
else:
tags.append(row[1])
self.tags = tags
self.general_index = general_index
self.character_index = character_index


def tag(image, model: WD14Model, threshold=0.35, character_threshold=0.85, exclude_tags="", trailing_comma=False, client_id=None, node=None):
input = model.model.get_inputs()[0]
height = input.shape[1]

# Reduce to max size and pad with white
Expand All @@ -71,31 +94,14 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu
image = image[:, :, ::-1] # RGB -> BGR
image = np.expand_dims(image, 0)

# Read all tags from csv and locate start of each category
tags = []
general_index = None
character_index = None
with open(os.path.join(models_dir, model_name + ".csv")) as f:
reader = csv.reader(f)
next(reader)
for row in reader:
if general_index is None and row[2] == "0":
general_index = reader.line_num - 2
elif character_index is None and row[2] == "4":
character_index = reader.line_num - 2
if replace_underscore:
tags.append(row[1].replace("_", " "))
else:
tags.append(row[1])

label_name = model.get_outputs()[0].name
probs = model.run([label_name], {input.name: image})[0]

result = list(zip(tags, probs[0]))
label_name = model.model.get_outputs()[0].name
probs = model.model.run([label_name], {input.name: image})[0]

result = list(zip(model.tags, probs[0]))

# rating = max(result[:general_index], key=lambda x: x[1])
general = [item for item in result[general_index:character_index] if item[1] > threshold]
character = [item for item in result[character_index:] if item[1] > character_threshold]
general = [item for item in result[model.general_index:model.character_index] if item[1] > threshold]
character = [item for item in result[model.character_index:] if item[1] > character_threshold]

all = character + general
remove = [s.strip() for s in exclude_tags.lower().split(",")]
Expand All @@ -115,6 +121,7 @@ async def download_model(model, client_id, node):
hf_endpoint = hf_endpoint.rstrip("/")

url = config["models"][model]
# https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/resolve/main/model.onnx
url = url.replace("{HF_ENDPOINT}", hf_endpoint)
url = f"{url}/resolve/main/"
async with aiohttp.ClientSession(loop=asyncio.get_event_loop()) as session:
Expand Down Expand Up @@ -162,7 +169,7 @@ async def get_tags(request):

models = get_installed_models()
default = defaults["model"] + ".onnx"
model = default if default in models else models[0]
model = WD14Model(default if default in models else models[0])

return web.json_response(await tag(image, model, client_id=request.rel_url.query.get("clientId", ""), node=request.rel_url.query.get("node", "")))

Expand Down Expand Up @@ -195,15 +202,70 @@ def tag(self, image, model, threshold, character_threshold, exclude_tags="", rep

pbar = comfy.utils.ProgressBar(tensor.shape[0])
tags = []
wd14_model = WD14Model(model, replace_underscore)
for i in range(tensor.shape[0]):
image = Image.fromarray(tensor[i])
tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)))
tags.append(tag(image, wd14_model, threshold, character_threshold, exclude_tags, trailing_comma))
pbar.update(1)
return {"ui": {"tags": tags}, "result": (tags,)}


class WD14ModelLoader:

@classmethod
def INPUT_TYPES(s):
extra = [name for name, _ in (os.path.splitext(m) for m in get_installed_models()) if name not in known_models]
models = known_models + extra
return {"required": {
"model_name": (models, { "default": defaults["model"] }),
"replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}),
}}

RETURN_TYPES = ("WD14_MODEL", )
RETURN_NAMES = ("model", )
FUNCTION = "load"

CATEGORY = "image"

def load(self, model_name, replace_underscore):
return (WD14Model(model_name, replace_underscore), )


class WD14TaggerOnly:

@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE", ),
"model": ("WD14_MODEL", ),
"threshold": ("FLOAT", {"default": defaults["threshold"], "min": 0.0, "max": 1, "step": 0.05}),
"character_threshold": ("FLOAT", {"default": defaults["character_threshold"], "min": 0.0, "max": 1, "step": 0.05}),
"trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}),
"exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}),
}}

RETURN_TYPES = ("STRING",)
OUTPUT_IS_LIST = (True,)
FUNCTION = "tag"

CATEGORY = "image"

def tag(self, image, model, threshold, character_threshold, exclude_tags="", trailing_comma=False):
tensor = image*255
tensor = np.array(tensor, dtype=np.uint8)

pbar = comfy.utils.ProgressBar(tensor.shape[0])
tags = []
for i in range(tensor.shape[0]):
image = Image.fromarray(tensor[i])
tags.append(tag(image, model, threshold, character_threshold, exclude_tags, trailing_comma))
pbar.update(1)
return (tags, )

NODE_CLASS_MAPPINGS = {
"WD14Tagger|pysssss": WD14Tagger,
"WD14ModelLoader|pysssss": WD14ModelLoader,
"WD14TaggerOnly|pysssss": WD14TaggerOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WD14Tagger|pysssss": "WD14 Tagger 🐍",
Expand Down