Skip to content

Commit 6d3e784

Browse files
committed
Added ppl
1 parent 5f60e58 commit 6d3e784

16 files changed

+1047
-21
lines changed

LICENSE-LPIPS

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2+
All rights reserved.
3+
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are met:
6+
7+
* Redistributions of source code must retain the above copyright notice, this
8+
list of conditions and the following disclaimer.
9+
10+
* Redistributions in binary form must reproduce the above copyright notice,
11+
this list of conditions and the following disclaimer in the documentation
12+
and/or other materials provided with the distribution.
13+
14+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24+

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ train.py supports Weights & Biases logging. If you want to use it, add --wandb a
2424

2525
![Sample with truncation](sample.png)
2626

27-
At 40,000 iterations. (trained on 1.28M images)
27+
At 110,000 iterations. (trained on 3.52M images)

lpips/__init__.py

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import numpy as np
7+
from skimage.measure import compare_ssim
8+
import torch
9+
from torch.autograd import Variable
10+
11+
from lpips import dist_model
12+
13+
class PerceptualLoss(torch.nn.Module):
    """Perceptual distance wrapped as a ``torch.nn.Module``.

    The actual computation is delegated to ``dist_model.DistModel``; this
    class only builds that model and optionally rescales the inputs.
    """

    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=None):
        """Build and initialise the underlying ``DistModel``.

        All arguments are forwarded unchanged to ``DistModel.initialize``.
        The original author's notes describe ``model='net-lin'`` as the
        perceptually-learned weights (LPIPS metric) and
        ``model='net', net='vgg'`` as the "default" way of using VGG as a
        perceptual loss.

        ``gpu_ids`` defaults to ``[0]``. A ``None`` sentinel is used so that
        every instance gets its own list instead of sharing one mutable
        default object.
        """
        super(PerceptualLoss, self).__init__()
        if gpu_ids is None:
            gpu_ids = [0]
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized'%self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """
        Return the distance between pred and target.

        If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
        If normalize is False, assumes the images are already between [-1,+1]

        Inputs pred and target are Nx3xHxW (per the original author's note)
        Output is whatever the underlying model's forward returns
        (described by the original author as N long)
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        # NOTE: the underlying model receives (target, pred), in that order.
        return self.model.forward(target, pred)
41+
42+
def normalize_tensor(in_feat, eps=1e-10):
    """Divide ``in_feat`` by its L2 norm taken along dimension 1.

    ``eps`` is added to the norm before dividing, so an all-zero slice
    yields zeros rather than a division by zero.
    """
    squared = in_feat ** 2
    l2_norm = torch.sqrt(squared.sum(dim=1, keepdim=True))
    return in_feat / (l2_norm + eps)
45+
46+
def l2(p0, p1, range=255.):
    """Return half the mean squared difference of p0 and p1 after dividing both by ``range``."""
    scaled_diff = p0 / range - p1 / range
    return .5 * np.mean(scaled_diff ** 2)
48+
49+
def psnr(p0, p1, peak=255.):
    """Return ``10 * log10(peak**2 / MSE)`` for the two inputs.

    Both inputs are multiplied by ``1.`` first so integer arrays are
    promoted to float before subtracting.
    """
    mse = np.mean((1. * p0 - 1. * p1) ** 2)
    return 10 * np.log10(peak ** 2 / mse)
51+
52+
def dssim(p0, p1, range=255.):
    """Return ``(1 - SSIM) / 2`` for the two images.

    SSIM is computed by ``compare_ssim`` with ``data_range=range`` and
    ``multichannel=True``.
    """
    # NOTE(review): `compare_ssim` comes from `skimage.measure`; newer
    # scikit-image releases may no longer provide it - confirm the pinned version.
    similarity = compare_ssim(p0, p1, data_range=range, multichannel=True)
    return (1 - similarity) / 2.
54+
55+
def rgb2lab(in_img, mean_cent=False):
    """Convert an RGB image to Lab with ``skimage.color.rgb2lab``.

    When ``mean_cent`` is true, 50 is subtracted from channel 0 of the result.

    NOTE: a later ``def rgb2lab(input)`` in this same module rebinds the
    name, so this two-argument version is not the one visible after import.
    """
    from skimage import color

    lab = color.rgb2lab(in_img)
    if mean_cent:
        lab[:, :, 0] -= 50
    return lab
61+
62+
def tensor2np(tensor_obj):
    """Take the first element of a tensor and return it as a float numpy array with axes (1, 2, 0)."""
    first = tensor_obj[0].cpu().float().numpy()
    # move the leading axis of the selected element to the end
    return np.transpose(first, (1, 2, 0))
65+
66+
def np2tensor(np_obj):
    """Turn a 3-axis numpy array into a 4-axis ``torch.Tensor``.

    A trailing singleton axis is appended and the axes are permuted with
    (3, 2, 0, 1), so the new singleton axis comes first and the original
    last axis comes second.
    """
    with_batch = np_obj[:, :, :, np.newaxis]
    reordered = with_batch.transpose((3, 2, 0, 1))
    return torch.Tensor(reordered)
69+
70+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    """Convert an image tensor to a Lab tensor.

    The tensor is turned into an image with ``tensor2im``, converted with
    ``skimage.color.rgb2lab`` and wrapped back with ``np2tensor``.

    * ``mc_only``: only subtract 50 from channel 0.
    * ``to_norm`` (and not ``mc_only``): subtract 50 from channel 0, then
      divide every channel by 100.
    """
    from skimage import color

    lab = color.rgb2lab(tensor2im(image_tensor))
    if mc_only:
        lab[:, :, 0] -= 50
    elif to_norm:
        lab[:, :, 0] -= 50
        lab = lab / 100.

    return np2tensor(lab)
83+
84+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
    """Convert a Lab tensor back to an RGB image tensor.

    The input is multiplied by 100 and 50 is added to channel 0 before
    calling ``skimage.color.lab2rgb``; the result is clipped to [0, 1] and
    scaled by 255.

    With ``return_inbnd`` a ``(tensor, mask)`` tuple is returned. The mask
    is 1 where converting the RGB result back to Lab lands within
    ``atol=2`` of the input on every channel, and 0 elsewhere.
    """
    from skimage import color
    import warnings
    # NOTE: this installs an "ignore" filter in the process-wide warnings
    # filter list, not just for this call.
    warnings.filterwarnings("ignore")

    lab_img = tensor2np(lab_tensor) * 100.
    lab_img[:, :, 0] += 50

    rgb_unit = np.clip(color.lab2rgb(lab_img.astype('float')), 0, 1)
    rgb_img = 255. * rgb_unit
    if not return_inbnd:
        return im2tensor(rgb_img)

    # convert back to lab, see if we match
    roundtrip = color.rgb2lab(rgb_img.astype('uint8'))
    close = 1. * np.isclose(roundtrip, lab_img, atol=2.)
    in_bound = np2tensor(np.prod(close, axis=2)[:, :, np.newaxis])
    return (im2tensor(rgb_img), in_bound)
101+
102+
def rgb2lab(input):
    """Divide ``input`` by 255 and convert it to Lab with ``skimage.color.rgb2lab``.

    NOTE: this rebinds the name of the earlier ``rgb2lab(in_img, mean_cent)``
    defined in this module.
    """
    from skimage import color

    unit_rgb = input / 255.
    return color.rgb2lab(unit_rgb)
105+
106+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert the first element of an image tensor to a numpy image.

    The element's axes are permuted with (1, 2, 0), then every value is
    mapped through ``(value + cent) * factor`` and cast to ``imtype``.
    """
    first = image_tensor[0].cpu().float().numpy()
    channels_last = np.transpose(first, (1, 2, 0))
    return ((channels_last + cent) * factor).astype(imtype)
110+
111+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert a 3-axis numpy image to a 4-axis ``torch.Tensor``.

    Values are mapped through ``value / factor - cent``; a trailing
    singleton axis is appended and the axes are permuted with (3, 2, 0, 1).
    ``imtype`` is accepted but not used.
    """
    centred = image / factor - cent
    reordered = centred[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
    return torch.Tensor(reordered)
114+
115+
def tensor2vec(vector_tensor):
    """Return ``[:, :, 0, 0]`` of the tensor's data as a numpy array."""
    as_array = vector_tensor.data.cpu().numpy()
    return as_array[:, :, 0, 0]
117+
118+
def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall arrays.

    With ``use_07_metric`` the VOC 07 11-point method is used; otherwise
    (default) the area under the precision envelope is summed.
    """
    if use_07_metric:
        # 11 point metric: average the best precision at each recall threshold
        ap = 0.
        for threshold in np.arange(0., 1.1, 0.1):
            reached = rec >= threshold
            best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
            ap = ap + best / 11.
        return ap

    # pad both curves with sentinel values at the ends
    padded_rec = np.concatenate(([0.], rec, [1.]))
    padded_prec = np.concatenate(([0.], prec, [0.]))

    # precision envelope: each entry becomes the max of itself and everything after it
    envelope = np.maximum.accumulate(padded_prec[::-1])[::-1]

    # indices where the recall value changes
    steps = np.where(padded_rec[1:] != padded_rec[:-1])[0]

    # sum (\Delta recall) * prec
    return np.sum((padded_rec[steps + 1] - padded_rec[steps]) * envelope[steps + 1])
150+
151+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert the first element of an image tensor to a numpy image.

    Axes of the element are permuted with (1, 2, 0); values are mapped
    through ``(value + cent) * factor`` and cast to ``imtype``.
    (The original author left ``factor=1.`` as a commented-out alternative.)
    """
    element = image_tensor[0].cpu().float().numpy()
    shifted = np.transpose(element, (1, 2, 0)) + cent
    return (shifted * factor).astype(imtype)
156+
157+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """Convert a 3-axis numpy image to a 4-axis ``torch.Tensor``.

    Values are mapped through ``value / factor - cent``, a trailing
    singleton axis is appended, and the axes are permuted with (3, 2, 0, 1).
    ``imtype`` is accepted but not used.
    (The original author left ``factor=1.`` as a commented-out alternative.)
    """
    rescaled = image / factor - cent
    expanded = rescaled[:, :, :, np.newaxis]
    return torch.Tensor(expanded.transpose((3, 2, 0, 1)))

lpips/base_model.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import torch
3+
from torch.autograd import Variable
4+
from pdb import set_trace as st
5+
from IPython import embed
6+
7+
class BaseModel():
    """Skeleton model class providing no-op hooks and save/load helpers.

    Several methods read attributes this class never sets itself
    (``self.input``, ``self.save_dir``, ``self.image_paths``); they must be
    assigned elsewhere (presumably by subclasses) before those methods are
    called.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Record the device configuration.

        ``gpu_ids`` defaults to ``[0]``; a ``None`` sentinel is used so each
        call gets its own list rather than a shared mutable default.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        """Return ``self.input`` (not set by this class)."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict."""
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Save ``network.state_dict()`` to ``<path>/<epoch_label>_net_<network_label>.pth``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load a state dict from ``<self.save_dir>/<epoch_label>_net_<network_label>.pth`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        # `self` was missing originally, so any instance call raised TypeError.
        pass

    def get_image_paths(self):
        """Return ``self.image_paths`` (not set by this class).

        An earlier no-op definition of this method was dead code (it was
        overridden by this one) and has been dropped.
        """
        return self.image_paths

    def save_done(self, flag=False):
        """Write ``flag`` under ``self.save_dir`` as ``done_flag.npy`` and as the text file ``done_flag``."""
        # Imported locally: this file's import block does not bring in numpy,
        # so the original body failed with NameError on `np`.
        import numpy as np

        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58+

0 commit comments

Comments
 (0)