Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 81 additions & 76 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,8 @@ def slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen):

def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
if img is None:
img = model.cond_hint_original
hint = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8)
if isinstance(model, comfy.controlnet.ControlLora):
model.cond_hint = hint.float().to(model.device)
else:
model.cond_hint = hint.to(model.control_model.dtype).to(model.device)
img = model.backup
model.cond_hint_original = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8)

def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
model.control_input = None
Expand All @@ -104,6 +100,8 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
end_at_step = min(end_at_step, steps)
device = comfy.model_management.get_torch_device()
samples = latent_image["samples"]
samples = comfy.sample.fix_empty_latent_channels(model, samples)

noise_mask = latent_image["noise_mask"] if "noise_mask" in latent_image else None
force_full_denoise = return_with_leftover_noise == "enable"
if add_noise == "disable":
Expand Down Expand Up @@ -132,8 +130,7 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
modelPatches, inference_memory = comfy.sampler_helpers.get_additional_models(conds, model.model_dtype())

comfy.model_management.load_models_gpu([model] + modelPatches, model.memory_required(noise.shape) + inference_memory)
real_model = model.model


sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

if tiling_strategy != 'padded':
Expand Down Expand Up @@ -223,73 +220,81 @@ def callback(step, x0, x, total_steps):

if tiling_strategy == "random strict":
samples_next = samples.clone()
for img_pass in tiles:
for i in range(len(img_pass)):
for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
tiled_mask = None
if noise_mask is not None:
tiled_mask = tiling.get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
if tile_mask is not None:
if tiled_mask is not None:
tiled_mask *= tile_mask.to(device)

for m in cnets:
m.backup = m.cond_hint_original

try:
for img_pass in tiles:
for i in range(len(img_pass)):
for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
tiled_mask = None
if noise_mask is not None:
tiled_mask = tiling.get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
if tile_mask is not None:
if tiled_mask is not None:
tiled_mask *= tile_mask.to(device)
else:
tiled_mask = tile_mask.to(device)

if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = tiling.mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len,
tile_height, tile_width, samples.shape[-2], samples.shape[-1],
tiled_mask, device)


if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
continue

tiled_latent = tiling.get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)

if tiling_strategy == 'padded':
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
else:
tiled_mask = tile_mask.to(device)

if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = tiling.mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len,
tile_height, tile_width, samples.shape[-2], samples.shape[-1],
tiled_mask, device)


if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
continue

tiled_latent = tiling.get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)

if tiling_strategy == 'padded':
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
else:
if tiled_mask is None or noise_mask is None:
tiled_noise = torch.zeros_like(tiled_latent)
if tiled_mask is None or noise_mask is None:
tiled_noise = torch.zeros_like(tiled_latent)
else:
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)

#TODO: all other condition based stuff like area sets and GLIGEN should also happen here

#cnets
for m, img in zip(cnets, cnet_imgs):
slice_cnet(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

#T2I
for m, img in zip(T2Is, T2I_imgs):
slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

pos = copy_cond(positive)
neg = copy_cond(negative)

#cond areas
pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(pos, spatial_conds_pos)]
pos = [c for c, ignore in pos if not ignore]
neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(neg, spatial_conds_neg)]
neg = [c for c, ignore in neg if not ignore]

#gligen
for cond, gligen in zip(pos, gligen_pos):
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
for cond, gligen in zip(neg, gligen_neg):
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)

tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
tile_result = tile_result.cpu()
if tiled_mask is not None:
tiled_mask = tiled_mask.cpu()
if tiling_strategy == "random strict":
tiling.set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
else:
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)

#TODO: all other condition based stuff like area sets and GLIGEN should also happen here

#cnets
for m, img in zip(cnets, cnet_imgs):
slice_cnet(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

#T2I
for m, img in zip(T2Is, T2I_imgs):
slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

pos = [c.copy() for c in positive]#copy_cond(positive_copy)
neg = [c.copy() for c in negative]#copy_cond(negative_copy)

#cond areas
pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(pos, spatial_conds_pos)]
pos = [c for c, ignore in pos if not ignore]
neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(neg, spatial_conds_neg)]
neg = [c for c, ignore in neg if not ignore]

#gligen
for cond, gligen in zip(pos, gligen_pos):
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
for cond, gligen in zip(neg, gligen_neg):
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)

tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
tile_result = tile_result.cpu()
if tiled_mask is not None:
tiled_mask = tiled_mask.cpu()
tiling.set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
if tiling_strategy == "random strict":
tiling.set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
else:
tiling.set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
if tiling_strategy == "random strict":
samples = samples_next.clone()

samples = samples_next.clone()
finally:
for m in cnets:
m.cond_hint_original = m.backup
del m.backup

comfy.sampler_helpers.cleanup_additional_models(modelPatches)

Expand All @@ -303,8 +308,8 @@ def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
"tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
"tile_width": ("INT", {"default": 512, "min": 128, "max": MAX_RESOLUTION, "step": 8}),
"tile_height": ("INT", {"default": 512, "min": 128, "max": MAX_RESOLUTION, "step": 8}),
"tiling_strategy": (["random", "random strict", "padded", 'simple'], ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
Expand Down Expand Up @@ -332,8 +337,8 @@ def INPUT_TYPES(s):
{"model": ("MODEL",),
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
"tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
"tile_width": ("INT", {"default": 512, "min": 128, "max": MAX_RESOLUTION, "step": 64}),
"tile_height": ("INT", {"default": 512, "min": 128, "max": MAX_RESOLUTION, "step": 64}),
"tiling_strategy": (["random", "random strict", "padded", 'simple'], ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
Expand Down