diff --git a/README.md b/README.md index 20ba5f6..fdfabfb 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,11 @@ img1 = Variable( img1, requires_grad=False) img2 = Variable( img2, requires_grad = True) -# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) +# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True,reduction='mean') ssim_value = pytorch_ssim.ssim(img1, img2).data[0] print("Initial ssim:", ssim_value) -# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) +# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True,reduction='mean') ssim_loss = pytorch_ssim.SSIM() optimizer = optim.Adam([img2], lr=0.01) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py index 738e803..61af1b8 100644 --- a/pytorch_ssim/__init__.py +++ b/pytorch_ssim/__init__.py @@ -14,7 +14,7 @@ def create_window(window_size, channel): window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window -def _ssim(img1, img2, window, window_size, channel, size_average = True): +def _ssim(img1, img2, window, window_size, channel, size_average = True, reduction = 'mean'): mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) @@ -30,19 +30,22 @@ def _ssim(img1, img2, window, window_size, channel, size_average = True): C2 = 0.03**2 ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) - - if size_average: - return ssim_map.mean() - else: - return ssim_map.mean(1).mean(1).mean(1) - + if reduction == 'mean': + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + elif reduction == 'none': + return ssim_map +# reduction can be 'mean' or 'none' class SSIM(torch.nn.Module): - def __init__(self, window_size = 11, size_average = True): + def __init__(self, window_size = 11, size_average = True, reduction= 'mean'): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = 1 self.window = create_window(window_size, self.channel) + self.reduction = reduction def forward(self, img1, img2): (_, channel, _, _) = img1.size() @@ -60,9 +63,9 @@ def forward(self, img1, img2): self.channel = channel - return _ssim(img1, img2, window, self.window_size, channel, self.size_average) - -def ssim(img1, img2, window_size = 11, size_average = True): + return _ssim(img1, img2, window, self.window_size, channel, self.size_average , self.reduction) +# reduction can be 'mean' or 'none' +def ssim(img1, img2, window_size = 11, size_average = True, reduction='mean'): (_, channel, _, _) = img1.size() window = create_window(window_size, channel) @@ -70,4 +73,4 @@ def ssim(img1, img2, window_size = 11, size_average = True): window = window.cuda(img1.get_device()) window = window.type_as(img1) - return _ssim(img1, img2, window, window_size, channel, size_average) + return _ssim(img1, img2, window, window_size, channel, size_average, reduction)