Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for histogram matching #129

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions fast_neural_style/HistoLoss.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
require 'torch'
require 'nn'

local threads = require 'threads'

--- Build a discretised inverse CDF (quantile lookup table) for a tensor.
--
-- The values of `img` are histogrammed into `bins` buckets, the histogram is
-- turned into a cumulative distribution, and for every quantised CDF level the
-- mean of the values that map to that level is stored. Levels that received
-- no value are filled by linear interpolation between their neighbours.
--
-- @param img   tensor of values; it is flattened with `img:view(-1)`
-- @param bins  number of histogram bins, and length of the returned table
-- @return      1-D tensor of size `bins`; entry k holds a representative
--              value for CDF level k
local function makeCdfInv(img, bins)
  local imgmin = img:min()
  -- Value range (max - min); used to normalise values before binning.
  local imgrange = img:max() - imgmin
  local cdfinv = torch.zeros(bins)
  -- Start the per-level counts at zero so the division below yields a true
  -- mean (starting at one would divide every sum by n + 1).
  local cdfinvcount = torch.zeros(bins)
  local cdfsum = 0

  local imgview = img:view(-1)

  -- Histogram -> probability mass function -> running sum (CDF), in place.
  local cdf = torch.histc(img, bins):div(img:nElement())
  cdf:apply(function(p)
    cdfsum = cdfsum + p
    return cdfsum
  end)

  -- Scale and floor the CDF to integer levels 0..bins-1 so that level + 1
  -- can be used as an index into cdfinv.
  cdf:mul(bins - 1):floor()

  -- Accumulate the sum and count of the values landing on each CDF level.
  -- The callback returns nothing, so imgview itself is left untouched.
  imgview:apply(function(x)
    local bin = math.floor(((x - imgmin) / (imgrange + 1e-11)) * (bins - 1) + 1)
    local level = cdf[bin] + 1
    cdfinv[level] = cdfinv[level] + x
    cdfinvcount[level] = cdfinvcount[level] + 1
  end)

  -- Turn sums into means. Empty levels have sum 0; clamping their count to 1
  -- avoids 0/0 and leaves them at 0, which marks them as unfilled below.
  cdfinv:cdiv(cdfinvcount:clamp(1, math.huge))

  -- Replace every unfilled (zero) level with a linearly interpolated value.
  -- The last level is pinned to the maximum so each gap has a right anchor.
  cdfinv[bins] = cdfinv:max()
  if math.ceil(cdfinv:max()) ~= 0 then
    local i = 2
    while i <= bins do
      if cdfinv[i] == 0 then
        -- Find the next filled level; the guard above guarantees that
        -- cdfinv[bins] is non-zero, so this scan stays in range.
        local k = i
        while cdfinv[k] == 0 do
          k = k + 1
        end
        local left = cdfinv[i - 1]
        local right = cdfinv[k]
        local span = k - i + 1
        -- Fill the gap i..k-1: the weight of `right` grows with the distance
        -- from the left anchor at i-1.
        for j = i, k - 1 do
          cdfinv[j] = left + (right - left) * ((j - i + 1) / span)
        end
        i = k
      else
        i = i + 1
      end
    end
  end

  return cdfinv
end

--- Remap the values of `img` so that its histogram matches a target.
--
-- Each value is looked up in the CDF of `img` itself, and that CDF level is
-- then used to pick the replacement value out of `cdfinv`.
--
-- @param img     tensor to remap; it is overwritten through a flattened view
--                of itself and then returned
-- @param cdfinv  1-D lookup table indexed by CDF level (1..bins)
-- @param bins    number of histogram bins
-- @return        `img`, after remapping
local function histoMatch(img, cdfinv, bins)
  local lowest = img:min()
  local span = img:max() - img:min()
  local flat = img:view(-1)

  -- Histogram, normalised to a probability mass function.
  local cdf = torch.histc(img, bins)
  cdf:div(img:nElement())

  -- Running sum turns the PMF into a CDF, written back into the same tensor.
  local running = 0
  cdf:apply(function(p)
    running = running + p
    return running
  end)

  -- value -> source bin -> CDF level -> matched target value.
  flat:apply(function(value)
    local bin = math.floor(((value - lowest) / (span + 1e-11)) * (bins - 1) + 1)
    local level = math.floor(cdf[bin] * (bins - 1) + 1)
    return cdfinv[level]
  end)

  return img
end

-- nn.HistoLoss: a pass-through nn.Module that records a histogram target in
-- 'capture' mode and computes a histogram-matching MSE loss in 'loss' mode.
local HistoLoss, parent = torch.class('nn.HistoLoss', 'nn.Module')

--- Construct a histogram loss layer.
-- @param strength   multiplier applied to the loss and to its gradient
-- @param bins       number of histogram bins used per channel
-- @param n_threads  size of the worker thread pool (defaults to 6)
function HistoLoss:__init(strength, bins, n_threads)
  parent.__init(self)

  -- Caller-supplied configuration.
  self.strength = strength
  self.bins = bins
  self.threads = n_threads or 6

  -- Runtime state: the layer starts out as a plain pass-through.
  self.mode = 'none'
  self.loss = 0
  self.target = nil
  self.H = nil

  -- Criterion comparing the input against its histogram-matched version.
  local criterion = nn.MSECriterion()
  criterion.sizeAverage = true
  self.crit = criterion
end

--- Forward pass. The input is always returned unchanged as the output.
--
-- * 'capture': for every input[i][j] slice, build an inverse-CDF table with
--   makeCdfInv and store it in self.target[i][j].
-- * 'loss': histogram-match every input[i][j] slice against self.target[1][j],
--   store the result in self.H, and set self.loss to strength * MSE(input, H).
-- * any other mode: pass-through only.
--
-- @param input  tensor indexed as input[i][j]
--               (assumed to be batch x channel x H x W -- TODO confirm)
-- @return       `input`
function HistoLoss:updateOutput(input)
  local mode = self.mode

  -- Only spin up worker threads when there is work for them. Creating the
  -- pool unconditionally would leak it in pass-through mode, where it is
  -- never synchronized or terminated.
  if mode == 'capture' or mode == 'loss' then
    -- There is no GPU kernel for histogram matching, so the per-channel work
    -- is spread over a pool of CPU worker threads instead.
    local pool = threads.Threads(self.threads)
    -- Copy into a local so the worker closures capture a plain number
    -- rather than `self`.
    local bins_thread = self.bins
    local n_batch = input:size(1)
    local n_channel = input:size(2)

    if mode == 'capture' then
      self.target = torch.Tensor(n_batch, n_channel, self.bins)
      for i = 1, n_batch do
        for j = 1, n_channel do
          local input_thread = input[i][j]:clone()
          pool:addjob(
            -- Runs in a worker thread.
            function()
              return makeCdfInv(input_thread, bins_thread)
            end,
            -- Runs in the main thread with the worker's result.
            function(cdfinv)
              self.target[i][j] = cdfinv
            end
          )
        end
      end
    else
      self.H = input:clone()
      for i = 1, n_batch do
        for j = 1, n_channel do
          -- Every batch element is matched against the first captured target.
          local target_thread = self.target[1][j]:clone()
          local input_thread = input[i][j]:clone()
          pool:addjob(
            -- Runs in a worker thread.
            function()
              return histoMatch(input_thread, target_thread, bins_thread)
            end,
            -- Runs in the main thread with the worker's result.
            function(matched)
              self.H[i][j] = matched
            end
          )
        end
      end
    end

    pool:synchronize()
    pool:terminate()

    if mode == 'loss' then
      self.loss = self.crit:forward(input, self.H) * self.strength
    end
  end

  self.output = input
  return self.output
end

--- Backward pass.
--
-- In 'loss' mode the gradient of the MSE criterion between `input` and the
-- histogram-matched tensor self.H is scaled by self.strength and added to
-- `gradOutput`. In 'capture' and 'none' modes `gradOutput` is passed through.
--
-- @param input       tensor given to the preceding forward pass
-- @param gradOutput  gradient arriving from the next layer
-- @return            self.gradInput
function HistoLoss:updateGradInput(input, gradOutput)
  local mode = self.mode
  if mode == 'loss' then
    self.gradInput = self.crit:backward(input, self.H)
    -- Scale the criterion's gradient, then accumulate the upstream one.
    self.gradInput:mul(self.strength)
    self.gradInput:add(gradOutput)
  elseif mode == 'none' or mode == 'capture' then
    self.gradInput = gradOutput
  end
  return self.gradInput
end

--- Switch the layer's operating mode.
-- @param mode  one of 'capture', 'loss' or 'none'; any other value raises
--              an error
function HistoLoss:setMode(mode)
  local is_known = (mode == 'capture') or (mode == 'loss') or (mode == 'none')
  if not is_known then
    error(string.format('Invalid mode "%s"', mode))
  end
  self.mode = mode
end
36 changes: 33 additions & 3 deletions fast_neural_style/PerceptualCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ require 'nn'
require 'fast_neural_style.ContentLoss'
require 'fast_neural_style.StyleLoss'
require 'fast_neural_style.DeepDreamLoss'
require 'fast_neural_style.HistoLoss'

local layer_utils = require 'fast_neural_style.layer_utils'

Expand All @@ -27,12 +28,14 @@ Input: args is a table with the following keys:
function crit:__init(args)
args.content_layers = args.content_layers or {}
args.style_layers = args.style_layers or {}
args.histo_layers = args.histo_layers or {}
args.deepdream_layers = args.deepdream_layers or {}

self.net = args.cnn
self.net:evaluate()
self.content_loss_layers = {}
self.style_loss_layers = {}
self.histo_loss_layers = {}
self.deepdream_loss_layers = {}

-- Set up content loss layers
Expand All @@ -50,6 +53,14 @@ function crit:__init(args)
layer_utils.insert_after(self.net, layer_string, style_loss_layer)
table.insert(self.style_loss_layers, style_loss_layer)
end

-- Set up histo loss layers
for i, layer_string in ipairs(args.histo_layers) do
local weight = args.histo_weights[i]
local histo_loss_layers = nn.HistoLoss(weight, args.histo_bins, args.histo_threads)
layer_utils.insert_after(self.net, layer_string, histo_loss_layers)
table.insert(self.histo_loss_layers, histo_loss_layers)
end

-- Set up DeepDream layers
for i, layer_string in ipairs(args.deepdream_layers) do
Expand All @@ -75,17 +86,22 @@ function crit:setStyleTarget(target)
for i, style_loss_layer in ipairs(self.style_loss_layers) do
style_loss_layer:setMode('capture')
end
for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
histo_loss_layer:setMode('capture')
end
self.net:forward(target)
end


--[[
target: Tensor of shape (N, 3, H, W) giving pixels for content target images
--]]
function crit:setContentTarget(target)
for i, style_loss_layer in ipairs(self.style_loss_layers) do
style_loss_layer:setMode('none')
end
for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
histo_loss_layer:setMode('none')
end
for i, content_loss_layer in ipairs(self.content_loss_layers) do
content_loss_layer:setMode('capture')
end
Expand All @@ -106,6 +122,11 @@ function crit:setContentWeight(weight)
end
end

--- Assign the same strength to every histogram loss layer.
-- @param weight  value written to each layer's `strength` field
function crit:setHistoWeight(weight)
  local layers = self.histo_loss_layers
  for idx = 1, #layers do
    layers[idx].strength = weight
  end
end

--[[
Inputs:
Expand All @@ -119,7 +140,7 @@ function crit:updateOutput(input, target)
self:setContentTarget(target.content_target)
end
if target.style_target then
self.setStyleTarget(target.style_target)
self:setStyleTarget(target.style_target)
end

-- Make sure to set all content and style loss layers to loss mode before
Expand All @@ -130,6 +151,9 @@ function crit:updateOutput(input, target)
for i, style_loss_layer in ipairs(self.style_loss_layers) do
style_loss_layer:setMode('loss')
end
for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
histo_loss_layer:setMode('loss')
end

local output = self.net:forward(input)

Expand All @@ -141,6 +165,8 @@ function crit:updateOutput(input, target)
self.content_losses = {}
self.total_style_loss = 0
self.style_losses = {}
self.total_histo_loss = 0
self.histo_losses = {}
for i, content_loss_layer in ipairs(self.content_loss_layers) do
self.total_content_loss = self.total_content_loss + content_loss_layer.loss
table.insert(self.content_losses, content_loss_layer.loss)
Expand All @@ -149,8 +175,12 @@ function crit:updateOutput(input, target)
self.total_style_loss = self.total_style_loss + style_loss_layer.loss
table.insert(self.style_losses, style_loss_layer.loss)
end
for i, histo_loss_layer in ipairs(self.histo_loss_layers) do
self.total_histo_loss = self.total_histo_loss + histo_loss_layer.loss
table.insert(self.histo_losses, histo_loss_layer.loss)
end

self.output = self.total_style_loss + self.total_content_loss
self.output = self.total_style_loss + self.total_content_loss + self.total_histo_loss
return self.output
end

Expand Down
10 changes: 10 additions & 0 deletions slow_neural_style.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ cmd:option('-content_layers', '16')
cmd:option('-style_weights', '5.0')
cmd:option('-style_layers', '4,9,16,23')
cmd:option('-style_image_size', 512)
cmd:option('-histo_weights', '5.0')
cmd:option('-histo_layers', '2,21')
cmd:option('-histo_bins', 256)
cmd:option('-histo_threads', 4)

-- Options for DeepDream
cmd:option('-deepdream_layers', '')
Expand Down Expand Up @@ -80,6 +84,8 @@ local function main()
print(loss_net)
local style_layers, style_weights =
utils.parse_layers(opt.style_layers, opt.style_weights)
local histo_layers, histo_weights =
utils.parse_layers(opt.histo_layers, opt.histo_weights)
local content_layers, content_weights =
utils.parse_layers(opt.content_layers, opt.content_weights)
local deepdream_layers, deepdream_weights =
Expand All @@ -88,6 +94,10 @@ local function main()
cnn = loss_net,
style_layers = style_layers,
style_weights = style_weights,
histo_layers = histo_layers,
histo_weights = histo_weights,
histo_bins = opt.histo_bins,
histo_threads = opt.histo_threads,
content_layers = content_layers,
content_weights = content_weights,
deepdream_layers = deepdream_layers,
Expand Down
17 changes: 17 additions & 0 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ cmd:option('-style_image_size', 256)
cmd:option('-style_weights', '5.0')
cmd:option('-style_layers', '4,9,16,23')
cmd:option('-style_target_type', 'gram', 'gram|mean')
cmd:option('-histo_weights', '5.0')
cmd:option('-histo_layers', '4,23')
cmd:option('-histo_bins', 256)
cmd:option('-histo_threads', 4)

-- Upsampling options
cmd:option('-upsample_factor', 4)
Expand Down Expand Up @@ -75,6 +79,8 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
utils.parse_layers(opt.content_layers, opt.content_weights)
opt.style_layers, opt.style_weights =
utils.parse_layers(opt.style_layers, opt.style_weights)
opt.histo_layers, opt.histo_weights =
utils.parse_layers(opt.histo_layers, opt.histo_weights)

-- Figure out preprocessing
if not preprocess[opt.preprocessing] then
Expand Down Expand Up @@ -119,6 +125,10 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
cnn = loss_net,
style_layers = opt.style_layers,
style_weights = opt.style_weights,
histo_layers = opt.histo_layers,
histo_weights = opt.histo_weights,
histo_bins = opt.histo_bins,
histo_threads = opt.histo_threads,
content_layers = opt.content_layers,
content_weights = opt.content_weights,
agg_type = opt.style_target_type,
Expand Down Expand Up @@ -227,6 +237,9 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
for i, k in ipairs(opt.style_layers) do
style_loss_history[string.format('style-%d', k)] = {}
end
for i, k in ipairs(opt.histo_layers) do
style_loss_history[string.format('histo-%d', k)] = {}
end
for i, k in ipairs(opt.content_layers) do
style_loss_history[string.format('content-%d', k)] = {}
end
Expand All @@ -245,6 +258,10 @@ cmd:option('-backend', 'cuda', 'cuda|opencl')
table.insert(style_loss_history[string.format('style-%d', k)],
percep_crit.style_losses[i])
end
for i, k in ipairs(opt.histo_layers) do
table.insert(style_loss_history[string.format('histo-%d', k)],
percep_crit.histo_losses[i])
end
for i, k in ipairs(opt.content_layers) do
table.insert(style_loss_history[string.format('content-%d', k)],
percep_crit.content_losses[i])
Expand Down