Skip to content

Commit 44d08b2

Browse files
authored
implement generic ROF model using Chambolle04 primal-dual method (#233)
1 parent 1bbf666 commit 44d08b2

11 files changed

+394
-4
lines changed

.gitignore

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
.vscode
55
docs/build
66
docs/site
7-
docs/Manifest.toml
87
docs/src/democards
9-
/Manifest.toml
8+
Manifest.toml
109
/.benchmarkci
1110
/benchmark/*.json

Project.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
FFTViews = "4f61f5a4-77b1-5117-aa51-3ab5ef4ef0cd"
1111
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
12+
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
1213
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
@@ -24,6 +25,7 @@ ComputationalResources = "0.3"
2425
DataStructures = "0.17.7, 0.18"
2526
FFTViews = "0.3"
2627
FFTW = "0.3, 1"
28+
ImageBase = "0.1.5"
2729
ImageCore = "0.9"
2830
OffsetArrays = "1.9"
2931
Reexport = "1.1"
@@ -33,10 +35,14 @@ julia = "1"
3335

3436
[extras]
3537
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
38+
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
39+
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
3640
ImageMetadata = "bc367c6b-8a6b-528e-b4bd-a4b897500b49"
41+
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
3742
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
3843
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3944
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45+
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
4046

4147
[targets]
42-
test = ["AxisArrays", "ImageMetadata", "Logging", "Random", "Test"]
48+
test = ["AxisArrays", "ImageIO", "ImageMagick", "ImageMetadata", "ImageQualityIndexes", "Logging", "Random", "Test", "TestImages"]

benchmark/benchmarks.jl

+12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ImageFiltering, ImageCore
22
using PkgBenchmark
33
using BenchmarkTools
44
using Statistics: quantile, mean, median!
5+
using ImageFiltering.Models
56

67
function makeimages(sz)
78
imgF32 = rand(Float32, sz)
@@ -56,3 +57,14 @@ let grp = SUITE["imfilter"]
5657
end
5758
end
5859
end
60+
61+
62+
SUITE["ROF"] = BenchmarkGroup()
63+
let grp = SUITE["ROF"]
64+
for sz in ((100, 100), (256, 256), (2048, 2048), (256, 256, 30))
65+
for (aname, img) in makeimages(sz)
66+
szstr = sz2str(sz)
67+
grp["PrimalDual"*"_"*aname*"_"*szstr] = @benchmarkable solve_ROF_PD($img, 0.1, 10)
68+
end
69+
end
70+
end

docs/src/function_reference.md

+6
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ Algorithm.IIR
7272
Algorithm.Mixed
7373
```
7474

75+
# Solvers for predefined models
76+
77+
```@autodocs
78+
Modules = [ImageFiltering.Models]
79+
```
80+
7581
# Internal machinery
7682

7783
```@docs

src/ImageFiltering.jl

+2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ include("mapwindow.jl")
9292
using .MapWindow
9393
include("extrema.jl")
9494

95+
include("models.jl")
96+
9597
function __init__()
9698
# See ComputationalResources README for explanation
9799
push!(LOAD_PATH, dirname(@__FILE__))

src/models.jl

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
module Models
2+
3+
using ImageBase
4+
using ImageBase.ImageCore.MappedArrays: of_eltype
5+
using ImageBase.FiniteDiff
6+
7+
# Introduced in ColorVectorSpace v0.9.3
8+
# https://github.com/JuliaGraphics/ColorVectorSpace.jl/pull/172
9+
using ImageBase.ImageCore.ColorVectorSpace.Future: abs2
10+
11+
"""
12+
This submodule provides predefined image-related models and its solvers that can be reused
13+
by many image processing tasks.
14+
15+
- solve the Rudin Osher Fatemi (ROF) model using the primal-dual method: [`solve_ROF_PD`](@ref) and [`solve_ROF_PD!`](@ref)
16+
"""
17+
Models
18+
19+
export solve_ROF_PD, solve_ROF_PD!
20+
21+
22+
##### implementation details
23+
24+
"""
25+
solve_ROF_PD([T], img::AbstractArray, λ; kwargs...)
26+
27+
Return a smoothed version of `img`, using Rudin-Osher-Fatemi (ROF) filtering, more commonly
28+
known as Total Variation (TV) denoising or TV regularization. This algorithm is based on the
29+
primal-dual method.
30+
31+
This function applies to generic N-dimensional colorant array and is also CUDA-compatible.
32+
See also [`solve_ROF_PD!`](@ref) for the in-place version.
33+
34+
# Arguments
35+
36+
- `T`: the output element type. By default it is `float32(eltype(img))`.
37+
- `img`: the input image, usually a noisy image.
38+
- `λ`: the regularization coefficient. Larger `λ` results in more smoothing.
39+
40+
# Parameters
41+
42+
- `num_iters::Int`: The number of iterations before stopping.
43+
44+
# Examples
45+
46+
```julia
47+
using ImageFiltering
48+
using ImageFiltering.Models: solve_ROF_PD
49+
using ImageQualityIndexes
50+
using TestImages
51+
52+
img_ori = float.(testimage("cameraman"))
53+
img_noisy = img_ori .+ 0.1 .* randn(size(img_ori))
54+
assess_psnr(img_noisy, img_ori) # ~20 dB
55+
56+
img_smoothed = solve_ROF_PD(img_noisy, 0.015, 50)
57+
assess_psnr(img_smoothed, img_ori) # ~27 dB
58+
59+
# larger λ produces over-smoothed result
60+
img_smoothed = solve_ROF_PD(img_noisy, 5, 50)
61+
assess_psnr(img_smoothed, img_ori) # ~21 dB
62+
```
63+
64+
# Extended help
65+
66+
Mathematically, this function solves the following ROF model using the primal-dual method:
67+
68+
```math
69+
\\min_u \\lVert u - g \\rVert^2 + \\lambda\\lvert\\nabla u\\rvert
70+
```
71+
72+
# References
73+
74+
- [1] Chambolle, A. (2004). "An algorithm for total variation minimization and applications". _Journal of Mathematical Imaging and Vision_. 20: 89–97
75+
- [2] https://en.wikipedia.org/wiki/Total_variation_denoising
76+
"""
77+
solve_ROF_PD(img::AbstractArray{T}, args...) where T = solve_ROF_PD(float32(T), img, args...)
78+
function solve_ROF_PD(::Type{T}, img::AbstractArray, args...) where T
79+
u = similar(img, T)
80+
buffer = preallocate_solve_ROF_PD(T, img)
81+
solve_ROF_PD!(u, buffer, img, args...)
82+
end
83+
84+
# non-exported helper
85+
preallocate_solve_ROF_PD(img::AbstractArray{T}) where T = preallocate_solve_ROF_PD(float32(T), img)
86+
function preallocate_solve_ROF_PD(::Type{T}, img) where T
87+
div_p = similar(img, T)
88+
p = ntuple(i->similar(img, T), ndims(img))
89+
∇u = ntuple(i->similar(img, T), ndims(img))
90+
∇u_mag = similar(img, eltype(T))
91+
return div_p, p, ∇u, ∇u_mag
92+
end
93+
94+
"""
95+
solve_ROF_PD!(out, buffer, img, λ, num_iters)
96+
97+
The in-place version of [`solve_ROF_PD`](@ref).
98+
99+
It is not uncommon to use ROF solver in a higher-level loop, in which case it makes sense to
100+
preallocate the output and intermediate arrays to make it faster.
101+
102+
!!! note "Buffer"
103+
The content and meaning of `buffer` might change without any notice if the internal
104+
implementation is changed. Use `preallocate_solve_ROF_PD` helper function to avoid
105+
potential changes.
106+
107+
# Examples
108+
109+
```julia
110+
using ImageFiltering.Models: preallocate_solve_ROF_PD
111+
112+
out = similar(img)
113+
buffer = preallocate_solve_ROF_PD(img)
114+
solve_ROF_PD!(out, buffer, img, 0.2, 30)
115+
```
116+
117+
"""
118+
function solve_ROF_PD!(
119+
out::AbstractArray{T},
120+
buffer::Tuple,
121+
img::AbstractArray,
122+
λ::Real,
123+
num_iters::Integer) where T
124+
# seperate a stub method to reduce latency
125+
FT = float32(T)
126+
if FT == T
127+
solve_ROF_PD!(out, buffer, img, Float32(λ), Int(num_iters))
128+
else
129+
solve_ROF_PD!(out, buffer, FT.(img), Float32(λ), Int(num_iters))
130+
end
131+
end
132+
function solve_ROF_PD!(
133+
out::AbstractArray,
134+
(div_p, p, ∇u, ∇u_mag)::Tuple,
135+
img::AbstractArray,
136+
λ::Float32,
137+
num_iters::Int)
138+
# Total Variation regularized image denoising using the primal dual algorithm
139+
# Implement according to reference [1]
140+
τ = 1//4 # see 2nd remark after proof of Theorem 3.1.
141+
142+
# use the same symbol in the paper
143+
u, g = out, img
144+
145+
fgradient!(p, g)
146+
# This iterates Eq. (9) of [1]
147+
# TODO(johnnychen94): set better stop criterion
148+
for _ in 1:num_iters
149+
fdiv!(div_p, p)
150+
# multiply term inside ∇ by -λ. Thm. 3.1 relates this to `u` via Eq. 7.
151+
@. u = g - λ*div_p
152+
fgradient!(∇u, u)
153+
_l2norm_vec!(∇u_mag, ∇u) # |∇(g - λdiv p)|
154+
# Eq. (9): update p
155+
for i in 1:length(p)
156+
@. p[i] = (p[i] -/λ)*∇u[i])/(1 +/λ) * ∇u_mag)
157+
end
158+
end
159+
return u
160+
end
161+
162+
function _l2norm_vec!(out, Vs::Tuple)
163+
all(v->axes(out) == axes(v), Vs) || throw(ArgumentError("All axes of input data should be the same."))
164+
@. out = abs2(Vs[1])
165+
for v in Vs[2:end]
166+
@. out += abs2(v)
167+
end
168+
@. out = sqrt(out)
169+
return out
170+
end
171+
172+
173+
end # module

test/cuda/Project.toml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
4+
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
5+
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
6+
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
7+
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"

test/cuda/models.jl

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using ImageFiltering.Models
2+
3+
@testset "solve_ROF_PD" begin
4+
# This testset is modified from its CPU version
5+
6+
@testset "Numerical" begin
7+
# 2D Gray
8+
img = restrict(testimage("cameraman"))
9+
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img))
10+
img_smoothed = solve_ROF_PD(img_noisy, 0.05, 20)
11+
@test ndims(img_smoothed) == 2
12+
@test eltype(img_smoothed) <: Gray
13+
@test assess_psnr(img_smoothed, img) > 31.67
14+
@test assess_ssim(img_smoothed, img) > 0.90
15+
16+
img_noisy_cu = CuArray(float32.(img_noisy))
17+
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.05, 20)
18+
@test img_smoothed_cu isa CuArray
19+
@test eltype(eltype(img_smoothed_cu)) == Float32
20+
@test Array(img_smoothed_cu) img_smoothed
21+
22+
# 2D RGB
23+
img = restrict(testimage("lighthouse"))
24+
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...)
25+
img_smoothed = solve_ROF_PD(img_noisy, 0.03, 20)
26+
@test ndims(img_smoothed) == 2
27+
@test eltype(img_smoothed) <: RGB
28+
@test assess_psnr(img_smoothed, img) > 32.15
29+
@test assess_ssim(img_smoothed, img) > 0.90
30+
31+
img_noisy_cu = CuArray(float32.(img_noisy))
32+
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.03, 20)
33+
@test img_smoothed_cu isa CuArray
34+
@test eltype(eltype(img_smoothed_cu)) == Float32
35+
@test Array(img_smoothed_cu) img_smoothed
36+
37+
# 3D Gray
38+
img = Gray.(restrict(testimage("mri"), (1, 2)))
39+
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img))
40+
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20)
41+
@test ndims(img_smoothed) == 3
42+
@test eltype(img_smoothed) <: Gray
43+
@test assess_psnr(img_smoothed, img) > 31.78
44+
@test assess_ssim(img_smoothed, img) > 0.85
45+
46+
img_noisy_cu = CuArray(float32.(img_noisy))
47+
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.02, 20)
48+
@test img_smoothed_cu isa CuArray
49+
@test eltype(eltype(img_smoothed_cu)) == Float32
50+
@test Array(img_smoothed_cu) img_smoothed
51+
52+
# 3D RGB
53+
img = RGB.(restrict(testimage("mri"), (1, 2)))
54+
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...)
55+
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20)
56+
@test ndims(img_smoothed) == 3
57+
@test eltype(img_smoothed) <: RGB
58+
@test assess_psnr(img_smoothed, img) > 31.17
59+
@test assess_ssim(img_smoothed, img) > 0.79
60+
61+
img_noisy_cu = CuArray(float32.(img_noisy))
62+
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.02, 20)
63+
@test img_smoothed_cu isa CuArray
64+
@test eltype(eltype(img_smoothed_cu)) == Float32
65+
@test Array(img_smoothed_cu) img_smoothed
66+
end
67+
end

test/cuda/runtests.jl

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# This file is maintained in a way to support CUDA-only test via
2+
# `julia --project=test/cuda -e 'include("runtests.jl")'`
3+
using ImageFiltering
4+
using CUDA
5+
using TestImages
6+
using ImageBase
7+
using ImageQualityIndexes
8+
using Test
9+
using Random
10+
11+
CUDA.allowscalar(false)
12+
13+
@testset "ImageFiltering" begin
14+
if CUDA.functional()
15+
include("models.jl")
16+
end
17+
end

0 commit comments

Comments
 (0)