Skip to content

Improve recomposite color correct by fitting (s*v, v) instead of (s, v) #863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def INPUT_TYPES(s):
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"mask": ("MASK",),
"correction_method": (["None", "Uniform", "Linear"], )
"correction_method": (["None", "Uniform", "Linear", "Linear2"], )
}
}

Expand Down Expand Up @@ -246,6 +246,8 @@ def composite(self, destination, source, x, y, mask, correction_method):
source_section = color_correct_uniform(source_section, dest_section, inverse_mask)
elif correction_method == "Linear":
source_section = color_correct_linear(source_section, dest_section, inverse_mask)
elif correction_method == "Linear2":
source_section = color_correct_linear2(source_section, dest_section, inverse_mask)

source_portion = mask * source_section
destination_portion = inverse_mask * dest_section
Expand Down Expand Up @@ -278,29 +280,56 @@ def color_correct_linear(source_section: torch.Tensor, dest_section: torch.Tenso
if thresholded_sum > 50:
source_hsv = rgb2hsv(source_section)
dest_hsv = rgb2hsv(dest_section)
source_hsv_masked = source_hsv * thresholded
dest_hsv_masked = dest_hsv * thresholded
# Simple linear regression on dest as a function of source
source_mean = source_hsv_masked.sum(dim=[0, 2, 3]) / thresholded_sum
dest_mean = dest_hsv_masked.sum(dim=[0, 2, 3]) / thresholded_sum
source_mean = source_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
dest_mean = dest_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
source_deviation = (source_hsv - source_mean) * thresholded
dest_deviation = (dest_hsv - dest_mean) * thresholded
numerator = torch.sum(source_deviation * dest_deviation, (0, 2, 3))
denominator = torch.sum(source_deviation * source_deviation, (0, 2, 3))
# When all src the same color, we fall back to assuming m = 1 (uniform offset)
m = torch.where(denominator != 0, numerator / denominator, torch.tensor(1.0))
m = m.unsqueeze(0).unsqueeze(2).unsqueeze(2) # 3
b = dest_mean - source_mean * m
m[0][0][0][0] = 1.0
b[0][0][0][0] = 0.0
source_hsv = m * source_hsv + b
source_hsv = source_hsv.clamp(0, 1)
source_h = source_hsv[:, 0:1, :, :]
source_s = linear_fit(source_hsv[:, 1:2, :, :], dest_hsv[:, 1:2, :, :], thresholded)
source_v = linear_fit(source_hsv[:, 2:3, :, :], dest_hsv[:, 2:3, :, :], thresholded)
source_hsv = torch.cat([source_h, source_s, source_v], dim=1)
source_section = hsv2rgb(source_hsv)
return source_section


# like color_correct_linear, but fits s*v and v instead of s and v to avoid instability from dark pixels
def color_correct_linear2(source_section: torch.Tensor, dest_section: torch.Tensor, inverse_mask: torch.Tensor) -> torch.Tensor:
    # Binarize the inverse mask: only pixels whose mask value is essentially 1
    # serve as color references for the regression.
    reference_weights = (inverse_mask.clamp(0, 1) - 0.9999).clamp(0, 1) * 10000
    # With too few reference pixels the fit is meaningless — return the source untouched.
    if reference_weights.sum() > 50:
        src_hsv = rgb2hsv(source_section)
        dst_hsv = rgb2hsv(dest_section)
        hue = src_hsv[:, 0:1, :, :]
        src_v = src_hsv[:, 2:3, :, :]
        dst_v = dst_hsv[:, 2:3, :, :]
        # Fit the s*v product and v independently, then recover s = (s*v) / v.
        # Fitting s*v rather than raw s keeps near-black pixels (tiny v, noisy s)
        # from dominating the saturation fit.
        fitted_sv = linear_fit(src_hsv[:, 1:2, :, :] * src_v, dst_hsv[:, 1:2, :, :] * dst_v, reference_weights)
        fitted_v = linear_fit(src_v, dst_v, reference_weights)
        sat = torch.zeros_like(fitted_sv)
        nonzero_v = fitted_v != 0
        # Where fitted v is exactly 0, saturation stays 0 (the pixel is black anyway).
        sat[nonzero_v] = fitted_sv[nonzero_v] / fitted_v[nonzero_v]
        sat = sat.clamp(0, 1)
        source_section = hsv2rgb(torch.cat([hue, sat, fitted_v], dim=1))
    return source_section


def linear_fit(source_component: torch.Tensor, dest_component: torch.Tensor, thresholded: torch.Tensor) -> torch.Tensor:
    """Remap source_component by a least-squares linear fit of dest as a function of source.

    Fits dest ≈ m * source + b over the pixels selected by the 0/1 weight mask
    `thresholded`, independently per channel (reduction over dims 0, 2, 3 of the
    NCHW tensors), then returns `m * source_component + b` clamped to [0, 1].
    """
    thresholded_sum = thresholded.sum()
    # Degenerate mask (no reference pixels): no fit is possible. Return the input
    # unchanged (clamped, to match the normal path's output range) instead of
    # producing NaNs from the divisions below.
    if thresholded_sum == 0:
        return source_component.clamp(0, 1)
    source_masked = source_component * thresholded
    dest_masked = dest_component * thresholded
    # Simple linear regression on dest as a function of source
    source_mean = source_masked.sum(dim=[0, 2, 3]) / thresholded_sum
    dest_mean = dest_masked.sum(dim=[0, 2, 3]) / thresholded_sum
    source_mean = source_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    dest_mean = dest_mean.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    # Deviations are masked so unselected pixels contribute nothing to the fit.
    source_deviation = (source_component - source_mean) * thresholded
    dest_deviation = (dest_component - dest_mean) * thresholded
    numerator = torch.sum(source_deviation * dest_deviation, (0, 2, 3))
    denominator = torch.sum(source_deviation * source_deviation, (0, 2, 3))
    # When all masked src pixels are the same color, fall back to m = 1 (uniform offset).
    # ones_like keeps the fallback on the same device/dtype as the data; a bare
    # torch.tensor(1.0) would be a CPU tensor even when the inputs live on CUDA.
    m = torch.where(denominator != 0, numerator / denominator, torch.ones_like(denominator))
    m = m.unsqueeze(0).unsqueeze(2).unsqueeze(2)
    b = dest_mean - source_mean * m
    source_component = m * source_component + b
    return source_component.clamp(0, 1)


# from https://github.com/limacv/RGB_HSV_HSL
def rgb2hsv(rgb: torch.Tensor) -> torch.Tensor:
cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
Expand Down
2 changes: 1 addition & 1 deletion src/Text2Image/T2IParamTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ static List<string> listVaes(Session s)
"8", Min: 4, Max: 4096, Step: 4, Toggleable: true, IsAdvanced: true, Group: GroupAdvancedSampling, OrderPriority: -4.4
));
ColorCorrectionBehavior = Register<string>(new("Color Correction Behavior", "Experimental: How to correct color when compositing a masked image.\n'None' = Do not attempt color correction.\n'Uniform' = Compute a fixed offset HSV correction for all pixels.\n'Linear' = Compute a linear correction that depends on each pixel's S and V.\nThis is useful for example when doing inpainting with Flux models, as the Flux VAE does not retain consistent colors - 'Linear' may help correct for this misbehavior.",
"None", IgnoreIf: "None", IsAdvanced: true, GetValues: (_) => ["None", "Uniform", "Linear"], Group: GroupAdvancedSampling, OrderPriority: -3
"None", IgnoreIf: "None", IsAdvanced: true, GetValues: (_) => ["None", "Uniform", "Linear", "Linear2"], Group: GroupAdvancedSampling, OrderPriority: -3
));
RemoveBackground = Register<bool>(new("Remove Background", "If enabled, removes the background from the generated image.\nThis internally uses RemBG.",
"false", IgnoreIf: "false", IsAdvanced: true, Group: GroupAdvancedSampling, OrderPriority: -2
Expand Down