Skip to content

Commit 6e8cf5d

Browse files
committed
Refactor code and add dispatch where both A and b are dual
1 parent be01fdc commit 6e8cf5d

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

Diff for: ext/LinearSolveForwardDiff.jl

+26-10
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,64 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
1717
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info)
1818
end : cache.cacheval
1919
cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
20-
res = LinearSolve.solve!(cache2, alg, kwargs...)
20+
res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy
2121
dresus = reduce(hcat, map(dAs, dbs) do dA, db
2222
cache2.b = db - dA * res.u
2323
dres = LinearSolve.solve!(cache2, alg, kwargs...)
2424
deepcopy(dres.u)
2525
end)
26-
# display(dresus)
2726
d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
2827
LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)
2928
end
3029

30+
3131
function LinearSolve.solve!(
32-
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}},
32+
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}},
3333
alg::LinearSolve.AbstractFactorization;
3434
kwargs...
3535
) where {T, V, P}
3636
@info "using solve! df/dA"
3737
dAs = begin
38-
dAs_ = ForwardDiff.partials.(cache.A)
39-
dAs_ = collect.(dAs_)
40-
dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))]
38+
t = collect.(ForwardDiff.partials.(cache.A))
39+
[getindex.(t, i) for i in 1:P]
4140
end
4241
dbs = [zero(cache.b) for _=1:P]
4342
A = ForwardDiff.value.(cache.A)
4443
b = cache.b
4544
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
4645
end
4746
function LinearSolve.solve!(
48-
cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
47+
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}},
4948
alg::LinearSolve.AbstractFactorization;
5049
kwargs...
5150
) where {T, V, P, A_}
5251
@info "using solve! df/db"
5352
dAs = [zero(cache.A) for _=1:P]
5453
dbs = begin
55-
dbs_ = ForwardDiff.partials.(cache.b)
56-
dbs_ = collect.(dbs_)
57-
dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))]
54+
t = collect.(ForwardDiff.partials.(cache.b))
55+
[getindex.(t, i) for i in 1:P]
5856
end
5957
A = cache.A
6058
b = ForwardDiff.value.(cache.b)
6159
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
6260
end
61+
function LinearSolve.solve!(
62+
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
63+
alg::LinearSolve.AbstractFactorization;
64+
kwargs...
65+
) where {T, V, P}
66+
@info "using solve! df/dAb"
67+
dAs = begin
68+
t = collect.(ForwardDiff.partials.(cache.A))
69+
[getindex.(t, i) for i in 1:P]
70+
end
71+
dbs = begin
72+
t = collect.(ForwardDiff.partials.(cache.b))
73+
[getindex.(t, i) for i in 1:P]
74+
end
75+
A = ForwardDiff.value.(cache.A)
76+
b = ForwardDiff.value.(cache.b)
77+
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
78+
end
6379

6480
end # module

Diff for: test/forwarddiff.jl

+40-14
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ n = 4
1010
A = rand(n, n);
1111
dA = zeros(n, n);
1212
b1 = rand(n);
13-
alg = LUFactorization()
14-
# for alg in (
15-
# LUFactorization(),
16-
# # RFLUFactorization(),
17-
# # KrylovJL_GMRES(),
18-
# )
13+
# alg = LUFactorization()
14+
for alg in (
15+
LUFactorization(),
16+
RFLUFactorization(),
17+
KrylovJL_GMRES(),
18+
)
1919
alg_str = string(alg)
2020
@show alg_str
2121
function fb(b)
@@ -51,19 +51,45 @@ alg = LUFactorization()
5151
sum(sol1.u)
5252
end
5353
fA(A)
54-
db = zero(b1)
55-
manual_jac = map(onehot(A)) do dA
56-
y = A \ b1
57-
sum(inv(A) * (db - dA*y))
58-
end |> collect
59-
display(reduce(hcat, manual_jac))
54+
# db = zero(b1)
55+
# manual_jac = map(onehot(A)) do dA
56+
# y = A \ b1
57+
# t = inv(A) * (db - dA*y)
58+
# end |> collect
59+
# display(reduce(hcat, manual_jac))
6060

6161
fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
6262
@show fid_jac
6363

64-
# @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec
6564
fod_jac = ForwardDiff.gradient(fA, A) |> vec
6665
@show fod_jac
6766

6867
@test fod_jac fid_jac rtol=1e-6
69-
# end
68+
69+
70+
# function fAb(Ab)
71+
# A = Ab[:, 1:n]
72+
# b1 = Ab[:, n+1]
73+
# prob = LinearProblem(A, b1)
74+
75+
# sol1 = solve(prob, alg)
76+
77+
# sum(sol1.u)
78+
# end
79+
# fAb(hcat(A, b1))
80+
# # db = zero(b1)
81+
# # manual_jac = map(onehot(A)) do dA
82+
# # y = A \ b1
83+
# # t = inv(A) * (db - dA*y)
84+
# # end |> collect
85+
# # display(reduce(hcat, manual_jac))
86+
87+
# fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec
88+
# @show fid_jac
89+
90+
# fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec
91+
# @show fod_jac
92+
93+
# @test fod_jac ≈ fid_jac rtol=1e-6
94+
95+
end

0 commit comments

Comments
 (0)