Skip to content

Commit 47eeb04

Browse files
Merge pull request #629 from jClugstor/forwarddiff_overloads
ForwardDiff Overload Fixes
2 parents 4cdaa15 + 62127b8 commit 47eeb04

File tree

2 files changed

+183
-49
lines changed

2 files changed

+183
-49
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4+
using LinearSolve: SciMLLinearSolveAlgorithm
45
using LinearAlgebra
56
using ForwardDiff
67
using ForwardDiff: Dual, Partials
@@ -36,8 +37,14 @@ const DualAbstractLinearProblem = Union{
3637
LinearSolve.@concrete mutable struct DualLinearCache
3738
linear_cache
3839
dual_type
40+
3941
partials_A
4042
partials_b
43+
partials_u
44+
45+
dual_A
46+
dual_b
47+
dual_u
4148
end
4249

4350
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
@@ -55,16 +62,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5562

5663
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5764

58-
partial_cache = cache.linear_cache
59-
partial_cache.u = dual_u0
60-
65+
cache.linear_cache.u = dual_u0
66+
# We can reuse the linear cache, because the same factorization will work for the partials.
6167
for i in eachindex(rhs_list)
62-
partial_cache.b = rhs_list[i]
63-
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
68+
cache.linear_cache.b = rhs_list[i]
69+
rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
6470
end
6571

66-
# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
67-
partial_cache.b = primal_b
72+
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
73+
cache.linear_cache.b = primal_b
6874

6975
partial_sols = rhs_list
7076

@@ -96,35 +102,25 @@ function xp_linsolve_rhs(
96102
b_list
97103
end
98104

99-
#=
100-
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
101-
return solve(prob, nothing, args...; kwargs...)
102-
end
103-
104-
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
105-
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
106-
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
107-
end
108-
109-
function SciMLBase.solve(prob::DualAbstractLinearProblem,
110-
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
111-
solve!(init(prob, alg, args...; kwargs...))
112-
end
113-
=#
114-
115105
function linearsolve_dual_solution(
116106
u::Number, partials, dual_type)
117107
return dual_type(u, partials)
118108
end
119109

120-
function linearsolve_dual_solution(
121-
u::AbstractArray, partials, dual_type)
110+
function linearsolve_dual_solution(u::Number, partials,
111+
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
112+
# Handle single-level duals
113+
return dual_type(u, partials)
114+
end
115+
116+
function linearsolve_dual_solution(u::AbstractArray, partials,
117+
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
118+
# Handle single-level duals for arrays
122119
partials_list = RecursiveArrayTools.VectorOfArray(partials)
123120
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
124-
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
121+
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
125122
end
126123

127-
#=
128124
function SciMLBase.init(
129125
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
130126
args...;
@@ -138,7 +134,6 @@ function SciMLBase.init(
138134
assumptions = OperatorAssumptions(issquare(prob.A)),
139135
sensealg = LinearSolveAdjoint(),
140136
kwargs...)
141-
142137
(; A, b, u0, p) = prob
143138
new_A = nodual_value(A)
144139
new_b = nodual_value(b)
@@ -147,7 +142,6 @@ function SciMLBase.init(
147142
∂_A = partial_vals(A)
148143
∂_b = partial_vals(b)
149144

150-
#primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
151145
primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0)
152146

153147
if get_dual_type(prob.A) !== nothing
@@ -156,48 +150,71 @@ function SciMLBase.init(
156150
dual_type = get_dual_type(prob.b)
157151
end
158152

153+
alg isa LinearSolve.DefaultLinearSolver ? real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg
154+
159155
non_partial_cache = init(
160-
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
156+
primal_prob, real_alg, assumptions, args...;
157+
alias = alias, abstol = abstol, reltol = reltol,
161158
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
162159
sensealg = sensealg, u0 = new_u0, kwargs...)
163-
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
160+
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
164161
end
165162

166163
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
164+
solve!(cache, cache.alg, args...; kwargs...)
165+
end
166+
167+
function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
167168
sol,
168169
partials = linearsolve_forwarddiff_solve(
169170
cache::DualLinearCache, cache.alg, args...; kwargs...)
170-
171171
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
172+
173+
if cache.dual_u isa AbstractArray
174+
cache.dual_u[:] = dual_sol
175+
else
176+
cache.dual_u = dual_sol
177+
end
178+
172179
return SciMLBase.build_linear_solution(
173180
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
174181
)
175182
end
176-
=#
177183

178184
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
179-
# Also "forwards" setproperty so that
180185
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
181186
# If the property is A or b, also update it in the LinearCache
182187
if sym === :A || sym === :b || sym === :u
183188
setproperty!(dc.linear_cache, sym, nodual_value(val))
189+
elseif hasfield(DualLinearCache, sym)
190+
setfield!(dc, sym, val)
184191
elseif hasfield(LinearSolve.LinearCache, sym)
185192
setproperty!(dc.linear_cache, sym, val)
186193
end
187194

195+
188196
# Update the partials if setting A or b
189197
if sym === :A
198+
setfield!(dc, :dual_A, val)
190199
setfield!(dc, :partials_A, partial_vals(val))
191-
elseif sym === :b
200+
elseif sym === :b
201+
setfield!(dc, :dual_b, val)
192202
setfield!(dc, :partials_b, partial_vals(val))
193-
else
194-
setfield!(dc, sym, val)
203+
elseif sym === :u
204+
setfield!(dc, :dual_u, val)
205+
setfield!(dc, :partials_u, partial_vals(val))
195206
end
196207
end
197208

198209
# "Forwards" getproperty to LinearCache if necessary
199210
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
200-
if hasfield(LinearSolve.LinearCache, sym)
211+
if sym === :A
212+
dc.dual_A
213+
elseif sym === :b
214+
dc.dual_b
215+
elseif sym === :u
216+
dc.dual_u
217+
elseif hasfield(LinearSolve.LinearCache, sym)
201218
return getproperty(dc.linear_cache, sym)
202219
else
203220
return getfield(dc, sym)
@@ -206,31 +223,36 @@ end
206223

207224

208225

209-
# Helper functions for Dual numbers
210-
get_dual_type(x::Dual) = typeof(x)
226+
# Enhanced helper functions for Dual numbers to handle recursion
227+
get_dual_type(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = typeof(x)
228+
get_dual_type(x::Dual{T, V, P}) where {T, V <: Dual, P} = typeof(x)
211229
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
212230
get_dual_type(x) = nothing
213231

214-
partial_vals(x::Dual) = ForwardDiff.partials(x)
232+
# Add recursive handling for nested dual partials
233+
partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.partials(x)
234+
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
215235
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
216236
partial_vals(x) = nothing
217237

238+
# Add recursive handling for nested dual values
218239
nodual_value(x) = x
219-
nodual_value(x::Dual) = ForwardDiff.value(x)
220-
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
240+
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
241+
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
242+
nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x)
221243

222244

223-
function partials_to_list(partial_matrix::Vector)
245+
function partials_to_list(partial_matrix::AbstractVector{T}) where {T}
224246
p = eachindex(first(partial_matrix))
225247
[[partial[i] for partial in partial_matrix] for i in p]
226248
end
227249

228250
function partials_to_list(partial_matrix)
229251
p = length(first(partial_matrix))
230252
m, n = size(partial_matrix)
231-
res_list = fill(zeros(m, n), p)
253+
res_list = fill(zeros(typeof(partial_matrix[1, 1][1]), m, n), p)
232254
for k in 1:p
233-
res = zeros(m, n)
255+
res = zeros(typeof(partial_matrix[1, 1][1]), m, n)
234256
for i in 1:m
235257
for j in 1:n
236258
res[i, j] = partial_matrix[i, j][k]
@@ -243,3 +265,4 @@ end
243265

244266

245267
end
268+

test/forwarddiff_overloads.jl

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using LinearSolve
22
using ForwardDiff
33
using Test
4+
using SparseArrays
45

56
function h(p)
67
(A = [p[1] p[2]+1 p[2]^3;
@@ -23,12 +24,11 @@ krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES())
2324

2425
@test (krylov_u0_sol, backslash_x_p, rtol = 1e-9)
2526

26-
2727
A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
2828
backslash_x_p = A \ [6.0, 10.0, 25.0]
2929
prob = LinearProblem(A, [6.0, 10.0, 25.0])
3030

31-
@test (solve(prob).u, backslash_x_p, rtol = 1e-9)
31+
@test (solve(prob).u, backslash_x_p, rtol = 1e-9)
3232
@test (solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9)
3333

3434
_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
@@ -48,6 +48,9 @@ new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.
4848
cache.A = new_A
4949
cache.b = new_b
5050

51+
@test cache.A == new_A
52+
@test cache.b == new_b
53+
5154
x_p = solve!(cache)
5255
backslash_x_p = new_A \ new_b
5356

@@ -61,6 +64,7 @@ cache = init(prob)
6164

6265
new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
6366
cache.A = new_A
67+
@test cache.A == new_A
6468

6569
x_p = solve!(cache)
6670
backslash_x_p = new_A \ b
@@ -75,8 +79,115 @@ cache = init(prob)
7579

7680
_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
7781
cache.b = new_b
82+
@test cache.b == new_b
7883

7984
x_p = solve!(cache)
8085
backslash_x_p = A \ new_b
8186

82-
@test (x_p, backslash_x_p, rtol = 1e-9)
87+
@test (x_p, backslash_x_p, rtol = 1e-9)
88+
89+
# Nested Duals
90+
function h(p)
91+
(A = [p[1] p[2]+1 p[2]^3;
92+
3*p[1] p[1]+5 p[2] * p[1]-4;
93+
p[2]^2 9*p[1] p[2]],
94+
b = [p[1] + 1, p[2] * 2, p[1]^2])
95+
end
96+
97+
A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0),
98+
ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)])
99+
100+
prob = LinearProblem(A, b)
101+
overload_x_p = solve(prob)
102+
103+
original_x_p = A \ b
104+
105+
@test (overload_x_p, original_x_p, rtol = 1e-9)
106+
107+
prob = LinearProblem(A, b)
108+
cache = init(prob)
109+
110+
new_A, new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0),
111+
ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 0.0, 1.0)])
112+
113+
cache.A = new_A
114+
cache.b = new_b
115+
116+
@test cache.A == new_A
117+
@test cache.b == new_b
118+
119+
function linprob_f(p)
120+
A, b = h(p)
121+
prob = LinearProblem(A, b)
122+
solve(prob)
123+
end
124+
125+
function slash_f(p)
126+
A, b = h(p)
127+
A \ b
128+
end
129+
130+
@test (
131+
ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0]))
132+
133+
@test (ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]),
134+
ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0]))
135+
136+
function g(p)
137+
(A = [p[1] p[1]+1 p[1]^3;
138+
3*p[1] p[1]+5 p[1] * p[1]-4;
139+
p[1]^2 9*p[1] p[1]],
140+
b = [p[1] + 1, p[1] * 2, p[1]^2])
141+
end
142+
143+
function slash_f_hes(p)
144+
A, b = g(p)
145+
x = A \ b
146+
sum(x)
147+
end
148+
149+
function linprob_f_hes(p)
150+
A, b = g(p)
151+
prob = LinearProblem(A, b)
152+
x = solve(prob)
153+
sum(x)
154+
end
155+
156+
@test (ForwardDiff.hessian(slash_f_hes, [5.0]),
157+
ForwardDiff.hessian(linprob_f_hes, [5.0]))
158+
159+
# Test aliasing
160+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
161+
162+
prob = LinearProblem(A, b)
163+
cache = init(prob)
164+
165+
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
166+
cache.A = new_A
167+
cache.b = new_b
168+
169+
linu = [ForwardDiff.Dual(0.0, 0.0, 0.0), ForwardDiff.Dual(0.0, 0.0, 0.0),
170+
ForwardDiff.Dual(0.0, 0.0, 0.0)]
171+
cache.u = linu
172+
x_p = solve!(cache)
173+
backslash_x_p = new_A \ new_b
174+
175+
@test linu == cache.u
176+
177+
# Test Float Only solvers
178+
179+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
180+
181+
prob = LinearProblem(sparse(A), sparse(b))
182+
overload_x_p = solve(prob, KLUFactorization())
183+
backslash_x_p = A \ b
184+
185+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
186+
187+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
188+
189+
prob = LinearProblem(sparse(A), sparse(b))
190+
overload_x_p = solve(prob, UMFPACKFactorization())
191+
backslash_x_p = A \ b
192+
193+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)

0 commit comments

Comments
 (0)