diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/SwarmImages.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/SwarmImages.py index b7db701b..037808ed 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/SwarmImages.py +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmComfyCommon/SwarmImages.py @@ -209,7 +209,7 @@ def INPUT_TYPES(s): "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "mask": ("MASK",), - "correction_method": (["None", "Uniform", "Linear"], ) + "correction_method": (["None", "Uniform", "Linear", "Linear2"], ) } } @@ -246,6 +246,8 @@ def composite(self, destination, source, x, y, mask, correction_method): source_section = color_correct_uniform(source_section, dest_section, inverse_mask) elif correction_method == "Linear": source_section = color_correct_linear(source_section, dest_section, inverse_mask) + elif correction_method == "Linear2": + source_section = color_correct_linear2(source_section, dest_section, inverse_mask) source_portion = mask * source_section destination_portion = inverse_mask * dest_section @@ -278,29 +280,56 @@ def color_correct_linear(source_section: torch.Tensor, dest_section: torch.Tenso if thresholded_sum > 50: source_hsv = rgb2hsv(source_section) dest_hsv = rgb2hsv(dest_section) - source_hsv_masked = source_hsv * thresholded - dest_hsv_masked = dest_hsv * thresholded - # Simple linear regression on dest as a function of source - source_mean = source_hsv_masked.sum(dim=[0, 2, 3]) / thresholded_sum - dest_mean = dest_hsv_masked.sum(dim=[0, 2, 3]) / thresholded_sum - source_mean = source_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2) - dest_mean = dest_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2) - source_deviation = (source_hsv - source_mean) * thresholded - dest_deviation = (dest_hsv - dest_mean) * thresholded - numerator = torch.sum(source_deviation * 
def color_correct_linear2(source_section: torch.Tensor, dest_section: torch.Tensor, inverse_mask: torch.Tensor) -> torch.Tensor:
    """Color-correct ``source_section`` towards ``dest_section`` in HSV space.

    Like ``color_correct_linear``, but fits ``s * v`` and ``v`` instead of
    ``s`` and ``v`` to avoid instability from dark pixels. Hue (channel 0 of
    the HSV tensor) is passed through unchanged.

    Args:
        source_section: RGB tensor to correct; indexed as a 4-D tensor with
            channels on dim 1 after conversion by ``rgb2hsv``.
        dest_section: RGB tensor providing the reference colors.
        inverse_mask: per-pixel weights; only pixels whose value is above
            0.9999 (after clamping to [0, 1]) contribute to the fit.

    Returns:
        The corrected RGB tensor, or ``source_section`` itself when the
        thresholded weights do not sum to more than 50.
    """
    # Maps mask values <= 0.9999 to 0 and a value of 1.0 to ~1, so only
    # (almost) fully unmasked pixels are used as regression samples.
    thresholded = (inverse_mask.clamp(0, 1) - 0.9999).clamp(0, 1) * 10000
    # Too few samples for a meaningful fit: leave the source untouched.
    if not (thresholded.sum() > 50):
        return source_section

    source_hsv = rgb2hsv(source_section)
    dest_hsv = rgb2hsv(dest_section)
    source_h = source_hsv[:, 0:1, :, :]
    source_v = source_hsv[:, 2:3, :, :]
    dest_v = dest_hsv[:, 2:3, :, :]

    fitted_sv = linear_fit(source_hsv[:, 1:2, :, :] * source_v, dest_hsv[:, 1:2, :, :] * dest_v, thresholded)
    fitted_v = linear_fit(source_v, dest_v, thresholded)

    # Recover s = (s * v) / v; where v == 0 the saturation is set to 0.
    # The divisor is swapped for 1 at those positions so no inf/nan is formed.
    nonzero_v = fitted_v != 0
    safe_v = torch.where(nonzero_v, fitted_v, torch.ones_like(fitted_v))
    fitted_s = torch.where(nonzero_v, fitted_sv / safe_v, torch.zeros_like(fitted_sv)).clamp(0, 1)

    return hsv2rgb(torch.cat([source_h, fitted_s, fitted_v], dim=1))


def linear_fit(source_component: torch.Tensor, dest_component: torch.Tensor, thresholded: torch.Tensor) -> torch.Tensor:
    """Apply a per-channel linear regression of dest on source to the source.

    Fits ``dest ~= m * source + b`` over the pixels weighted by
    ``thresholded`` (statistics are reduced over dims 0, 2 and 3, i.e. one
    ``m`` and ``b`` per entry of dim 1), then returns ``m * source + b``
    clamped to [0, 1].

    Args:
        source_component: 4-D tensor to be corrected.
        dest_component: 4-D reference tensor.
        thresholded: per-pixel weights multiplied into both tensors.

    Returns:
        The corrected component clamped to [0, 1]. When the weights sum to
        zero (no samples) the source is returned clamped, without any fit,
        instead of a NaN-filled tensor.
    """
    thresholded_sum = thresholded.sum()
    # No usable samples: the means below would be 0 / 0 = NaN.
    if not (thresholded_sum > 0):
        return source_component.clamp(0, 1)

    # Weighted means, reshaped so they broadcast against the 4-D inputs.
    source_mean = ((source_component * thresholded).sum(dim=[0, 2, 3]) / thresholded_sum).reshape(1, -1, 1, 1)
    dest_mean = ((dest_component * thresholded).sum(dim=[0, 2, 3]) / thresholded_sum).reshape(1, -1, 1, 1)

    # Simple linear regression on dest as a function of source.
    # NOTE: both deviations carry the weights, so the sums below use squared
    # weights; for 0/1 weights this is identical to plain weighting.
    source_deviation = (source_component - source_mean) * thresholded
    dest_deviation = (dest_component - dest_mean) * thresholded
    numerator = torch.sum(source_deviation * dest_deviation, (0, 2, 3))
    denominator = torch.sum(source_deviation * source_deviation, (0, 2, 3))

    # When all src the same color, we fall back to assuming m = 1 (uniform offset).
    # ones_like keeps the fallback on the same device/dtype as the data.
    ones = torch.ones_like(denominator)
    has_spread = denominator != 0
    m = torch.where(has_spread, numerator / torch.where(has_spread, denominator, ones), ones)
    m = m.reshape(1, -1, 1, 1)
    b = dest_mean - source_mean * m

    return (m * source_component + b).clamp(0, 1)
the Flux VAE does not retain consistent colors - 'Linear' may help correct for this misbehavior.", - "None", IgnoreIf: "None", IsAdvanced: true, GetValues: (_) => ["None", "Uniform", "Linear"], Group: GroupAdvancedSampling, OrderPriority: -3 + "None", IgnoreIf: "None", IsAdvanced: true, GetValues: (_) => ["None", "Uniform", "Linear", "Linear2"], Group: GroupAdvancedSampling, OrderPriority: -3 )); RemoveBackground = Register(new("Remove Background", "If enabled, removes the background from the generated image.\nThis internally uses RemBG.", "false", IgnoreIf: "false", IsAdvanced: true, Group: GroupAdvancedSampling, OrderPriority: -2