@@ -17,48 +17,64 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
17
17
LinearSolve. LinearAlgebra. LU (ForwardDiff. value .(cache. cacheval. factors), cache. cacheval. ipiv, cache. cacheval. info)
18
18
end : cache. cacheval
19
19
cache2 = remake (cache; A, b, u, reltol, abstol, cacheval)
20
- res = LinearSolve. solve! (cache2, alg, kwargs... )
20
+ res = LinearSolve. solve! (cache2, alg, kwargs... ) |> deepcopy
21
21
dresus = reduce (hcat, map (dAs, dbs) do dA, db
22
22
cache2. b = db - dA * res. u
23
23
dres = LinearSolve. solve! (cache2, alg, kwargs... )
24
24
deepcopy (dres. u)
25
25
end )
26
- # display(dresus)
27
26
d = Dual {T} .(res. u, Tuple .(eachrow (dresus)))
28
27
LinearSolve. SciMLBase. build_linear_solution (alg, d, nothing , cache; retcode= res. retcode, iters= res. iters, stats= res. stats)
29
28
end
30
29
30
+
31
31
function LinearSolve. solve! (
32
- cache:: LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}} ,
32
+ cache:: LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat} } ,
33
33
alg:: LinearSolve.AbstractFactorization ;
34
34
kwargs...
35
35
) where {T, V, P}
36
36
@info " using solve! df/dA"
37
37
dAs = begin
38
- dAs_ = ForwardDiff. partials .(cache. A)
39
- dAs_ = collect .(dAs_)
40
- dAs_ = [getindex .(dAs_, i) for i in 1 : length (first (dAs_))]
38
+ t = collect .(ForwardDiff. partials .(cache. A))
39
+ [getindex .(t, i) for i in 1 : P]
41
40
end
42
41
dbs = [zero (cache. b) for _= 1 : P]
43
42
A = ForwardDiff. value .(cache. A)
44
43
b = cache. b
45
44
_solve! (cache, alg, dAs, dbs, A, b, T; kwargs... )
46
45
end
47
46
function LinearSolve. solve! (
48
- cache:: LinearSolve.LinearCache{A_ ,<:AbstractArray{<:Dual{T,V,P}}} ,
47
+ cache:: LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat} ,<:AbstractArray{<:Dual{T,V,P}}} ,
49
48
alg:: LinearSolve.AbstractFactorization ;
50
49
kwargs...
51
50
) where {T, V, P, A_}
52
51
@info " using solve! df/db"
53
52
dAs = [zero (cache. A) for _= 1 : P]
54
53
dbs = begin
55
- dbs_ = ForwardDiff. partials .(cache. b)
56
- dbs_ = collect .(dbs_)
57
- dbs_ = [getindex .(dbs_, i) for i in 1 : length (first (dbs_))]
54
+ t = collect .(ForwardDiff. partials .(cache. b))
55
+ [getindex .(t, i) for i in 1 : P]
58
56
end
59
57
A = cache. A
60
58
b = ForwardDiff. value .(cache. b)
61
59
_solve! (cache, alg, dAs, dbs, A, b, T; kwargs... )
62
60
end
61
+ function LinearSolve. solve! (
62
+ cache:: LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}} ,
63
+ alg:: LinearSolve.AbstractFactorization ;
64
+ kwargs...
65
+ ) where {T, V, P}
66
+ @info " using solve! df/dAb"
67
+ dAs = begin
68
+ t = collect .(ForwardDiff. partials .(cache. A))
69
+ [getindex .(t, i) for i in 1 : P]
70
+ end
71
+ dbs = begin
72
+ t = collect .(ForwardDiff. partials .(cache. b))
73
+ [getindex .(t, i) for i in 1 : P]
74
+ end
75
+ A = ForwardDiff. value .(cache. A)
76
+ b = ForwardDiff. value .(cache. b)
77
+ _solve! (cache, alg, dAs, dbs, A, b, T; kwargs... )
78
+ end
63
79
64
80
end # module
0 commit comments