1
1
module LinearSolveForwardDiffExt
2
2
3
3
using LinearSolve
4
+ using LinearSolve: SciMLLinearSolveAlgorithm
4
5
using LinearAlgebra
5
6
using ForwardDiff
6
7
using ForwardDiff: Dual, Partials
@@ -36,8 +37,14 @@ const DualAbstractLinearProblem = Union{
36
37
LinearSolve. @concrete mutable struct DualLinearCache
37
38
linear_cache
38
39
dual_type
40
+
39
41
partials_A
40
42
partials_b
43
+ partials_u
44
+
45
+ dual_A
46
+ dual_b
47
+ dual_u
41
48
end
42
49
43
50
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
@@ -55,16 +62,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
55
62
56
63
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
57
64
58
- partial_cache = cache. linear_cache
59
- partial_cache. u = dual_u0
60
-
65
+ cache. linear_cache. u = dual_u0
66
+ # We can reuse the linear cache, because the same factorization will work for the partials.
61
67
for i in eachindex (rhs_list)
62
- partial_cache . b = rhs_list[i]
63
- rhs_list[i] = copy (solve! (partial_cache , alg, args... ; kwargs... ). u)
68
+ cache . linear_cache . b = rhs_list[i]
69
+ rhs_list[i] = copy (solve! (cache . linear_cache , alg, args... ; kwargs... ). u)
64
70
end
65
71
66
- # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
67
- partial_cache . b = primal_b
72
+ # Reset to the original `b` and `u` , users will expect that `b` doesn't change if they don't tell it to
73
+ cache . linear_cache . b = primal_b
68
74
69
75
partial_sols = rhs_list
70
76
@@ -96,35 +102,25 @@ function xp_linsolve_rhs(
96
102
b_list
97
103
end
98
104
99
- #=
100
- function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
101
- return solve(prob, nothing, args...; kwargs...)
102
- end
103
-
104
- function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
105
- assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
106
- return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
107
- end
108
-
109
- function SciMLBase.solve(prob::DualAbstractLinearProblem,
110
- alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
111
- solve!(init(prob, alg, args...; kwargs...))
112
- end
113
- =#
114
-
115
105
function linearsolve_dual_solution (
116
106
u:: Number , partials, dual_type)
117
107
return dual_type (u, partials)
118
108
end
119
109
120
- function linearsolve_dual_solution (
121
- u:: AbstractArray , partials, dual_type)
110
+ function linearsolve_dual_solution (u:: Number , partials,
111
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
112
+ # Handle single-level duals
113
+ return dual_type (u, partials)
114
+ end
115
+
116
+ function linearsolve_dual_solution (u:: AbstractArray , partials,
117
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
118
+ # Handle single-level duals for arrays
122
119
partials_list = RecursiveArrayTools. VectorOfArray (partials)
123
120
return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
124
- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
121
+ zip (u, partials_list[i, :] for i in 1 : length (partials_list. u [1 ])))
125
122
end
126
123
127
- #=
128
124
function SciMLBase. init (
129
125
prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
130
126
args... ;
@@ -138,7 +134,6 @@ function SciMLBase.init(
138
134
assumptions = OperatorAssumptions (issquare (prob. A)),
139
135
sensealg = LinearSolveAdjoint (),
140
136
kwargs... )
141
-
142
137
(; A, b, u0, p) = prob
143
138
new_A = nodual_value (A)
144
139
new_b = nodual_value (b)
@@ -147,7 +142,6 @@ function SciMLBase.init(
147
142
∂_A = partial_vals (A)
148
143
∂_b = partial_vals (b)
149
144
150
- #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
151
145
primal_prob = remake (prob; A = new_A, b = new_b, u0 = new_u0)
152
146
153
147
if get_dual_type (prob. A) != = nothing
@@ -156,48 +150,71 @@ function SciMLBase.init(
156
150
dual_type = get_dual_type (prob. b)
157
151
end
158
152
153
+ alg isa LinearSolve. DefaultLinearSolver ? real_alg = LinearSolve. defaultalg (primal_prob. A, primal_prob. b) : real_alg = alg
154
+
159
155
non_partial_cache = init (
160
- primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
156
+ primal_prob, real_alg, assumptions, args... ;
157
+ alias = alias, abstol = abstol, reltol = reltol,
161
158
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
162
159
sensealg = sensealg, u0 = new_u0, kwargs... )
163
- return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
160
+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b)) )
164
161
end
165
162
166
163
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
164
+ solve! (cache, cache. alg, args... ; kwargs... )
165
+ end
166
+
167
+ function SciMLBase. solve! (cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
167
168
sol,
168
169
partials = linearsolve_forwarddiff_solve (
169
170
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
170
-
171
171
dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
172
+
173
+ if cache. dual_u isa AbstractArray
174
+ cache. dual_u[:] = dual_sol
175
+ else
176
+ cache. dual_u = dual_sol
177
+ end
178
+
172
179
return SciMLBase. build_linear_solution (
173
180
cache. alg, dual_sol, sol. resid, cache; sol. retcode, sol. iters, sol. stats
174
181
)
175
182
end
176
- =#
177
183
178
184
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
179
- # Also "forwards" setproperty so that
180
185
function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
181
186
# If the property is A or b, also update it in the LinearCache
182
187
if sym === :A || sym === :b || sym === :u
183
188
setproperty! (dc. linear_cache, sym, nodual_value (val))
189
+ elseif hasfield (DualLinearCache, sym)
190
+ setfield! (dc, sym, val)
184
191
elseif hasfield (LinearSolve. LinearCache, sym)
185
192
setproperty! (dc. linear_cache, sym, val)
186
193
end
187
194
195
+
188
196
# Update the partials if setting A or b
189
197
if sym === :A
198
+ setfield! (dc, :dual_A , val)
190
199
setfield! (dc, :partials_A , partial_vals (val))
191
- elseif sym === :b
200
+ elseif sym === :b
201
+ setfield! (dc, :dual_b , val)
192
202
setfield! (dc, :partials_b , partial_vals (val))
193
- else
194
- setfield! (dc, sym, val)
203
+ elseif sym === :u
204
+ setfield! (dc, :dual_u , val)
205
+ setfield! (dc, :partials_u , partial_vals (val))
195
206
end
196
207
end
197
208
198
209
# "Forwards" getproperty to LinearCache if necessary
199
210
function Base. getproperty (dc:: DualLinearCache , sym:: Symbol )
200
- if hasfield (LinearSolve. LinearCache, sym)
211
+ if sym === :A
212
+ dc. dual_A
213
+ elseif sym === :b
214
+ dc. dual_b
215
+ elseif sym === :u
216
+ dc. dual_u
217
+ elseif hasfield (LinearSolve. LinearCache, sym)
201
218
return getproperty (dc. linear_cache, sym)
202
219
else
203
220
return getfield (dc, sym)
@@ -206,31 +223,36 @@ end
206
223
207
224
208
225
209
- # Helper functions for Dual numbers
210
- get_dual_type (x:: Dual ) = typeof (x)
226
+ # Enhanced helper functions for Dual numbers to handle recursion
227
+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = typeof (x)
228
+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = typeof (x)
211
229
get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
212
230
get_dual_type (x) = nothing
213
231
214
- partial_vals (x:: Dual ) = ForwardDiff. partials (x)
232
+ # Add recursive handling for nested dual partials
233
+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. partials (x)
234
+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = ForwardDiff. partials (x)
215
235
partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
216
236
partial_vals (x) = nothing
217
237
238
+ # Add recursive handling for nested dual values
218
239
nodual_value (x) = x
219
- nodual_value (x:: Dual ) = ForwardDiff. value (x)
220
- nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
240
+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
241
+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
242
+ nodual_value (x:: AbstractArray{<:Dual} ) = map (nodual_value, x)
221
243
222
244
223
- function partials_to_list (partial_matrix:: Vector )
245
+ function partials_to_list (partial_matrix:: AbstractVector{T} ) where {T}
224
246
p = eachindex (first (partial_matrix))
225
247
[[partial[i] for partial in partial_matrix] for i in p]
226
248
end
227
249
228
250
function partials_to_list (partial_matrix)
229
251
p = length (first (partial_matrix))
230
252
m, n = size (partial_matrix)
231
- res_list = fill (zeros (m, n), p)
253
+ res_list = fill (zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n), p)
232
254
for k in 1 : p
233
- res = zeros (m, n)
255
+ res = zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n)
234
256
for i in 1 : m
235
257
for j in 1 : n
236
258
res[i, j] = partial_matrix[i, j][k]
243
265
244
266
245
267
end
268
+
0 commit comments