Add ForwardDiff rules #434

Closed · wants to merge 9 commits
3 changes: 3 additions & 0 deletions Project.toml
@@ -35,6 +35,7 @@
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -48,6 +49,7 @@
LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveForwardDiff = "ForwardDiff"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -66,6 +68,7 @@
DocStringExtensions = "0.9"
EnumX = "1"
EnzymeCore = "0.6"
FastLapackInterface = "2"
ForwardDiff = "0.10"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
88 changes: 88 additions & 0 deletions ext/LinearSolveForwardDiff.jl
@@ -0,0 +1,88 @@
module LinearSolveForwardDiff

using LinearSolve
using InteractiveUtils
isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)
Comment on lines +5 to +7

Member: Only 1.9+ is supported now

Contributor Author: I am not sure I understand. What do you mean?

Member: Basically, you don't need to do this anymore; just the first import line works.


function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
    @assert !(eltype(first(dAs)) <: Dual)
    @assert !(eltype(first(dbs)) <: Dual)
    @assert !(eltype(A) <: Dual)
    @assert !(eltype(b) <: Dual)
    reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol
    abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol
    u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
    cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval
    cacheval = eltype(cacheval.factors) <: Dual ? begin
        LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info)
    end : cacheval
    cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval

    cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
Contributor Author: Being forced to remake the cache in order to solve the non-dual version. Is there some other way we can replace the Dual array with a regular array?

Member: I think you want to hook into init. In theory, in init you can un-dual the user inputs that are dual, but tag the cache in such a way that in solve! you end up doing two (or chunk size + 1) solves and reconstructing the resulting dual numbers in the output.

Member: Or rather, it's just one solve! call, but in a batched form.
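Concretely, the rule behind both suggestions: differentiating A·x = b along a direction (dA, db) gives A·dx = db − dA·x, so each partial direction costs one extra solve against the already-factored A. A small self-contained check of that identity against a finite difference (the matrix, step, and perturbations here are arbitrary illustration choices, not from the PR):

```julia
using LinearAlgebra

# Identity implemented by the extension: if x = A \ b, then the directional
# derivative of x along (dA, db) is dx = A \ (db - dA * x).
n = 4
A = [4.0 1.0 0.0 0.0; 1.0 4.0 1.0 0.0; 0.0 1.0 4.0 1.0; 0.0 0.0 1.0 4.0]
b = ones(n)
dA = fill(0.1, n, n)
db = fill(0.5, n)

x = A \ b
dx = A \ (db - dA * x)        # analytic directional derivative

h = 1e-7                      # forward finite-difference step
dx_fd = ((A + h * dA) \ (b + h * db) - x) ./ h
@assert maximum(abs.(dx - dx_fd)) < 1e-4
```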

    res = LinearSolve.solve!(cache2, alg; kwargs...) |> deepcopy
    dresus = reduce(hcat, map(dAs, dbs) do dA, db
        cache2.b = db - dA * res.u
        dres = LinearSolve.solve!(cache2, alg; kwargs...)
        deepcopy(dres.u)
    end)
Comment on lines +24 to +29

Contributor Author (@sharanry, Nov 13, 2023): Needing to deepcopy the results of the solves, as they are overwritten by subsequent solves when reusing the cache.

Member: I think if you hook into init and do a single batched solve then this is handled.

Contributor Author: Is there any documentation on how to do batched solves? I am unable to find how to do this anywhere. The closest thing I could find was https://discourse.julialang.org/t/batched-lu-solves-or-factorizations-with-sparse-matrices/106019/2, but I couldn't find the right function call.

Member: It's just A\B with a matrix instead of A\b with a vector.

Contributor Author (@sharanry, Dec 29, 2023): I am not entirely sure what you mean in the context of LinearSolve.jl.

n = 4
A = rand(n, n)
B = rand(n, n)

A \ B  # works

mapreduce(hcat, eachcol(B)) do b
    A \ b
end # works

mapreduce(hcat, eachcol(B)) do b
    prob = LinearProblem(A, b)
    sol = solve(prob)
    sol.u
end # works

begin
    prob = LinearProblem(A, B)
    sol = solve(prob)  # errors
    sol.u
end

Error:

ERROR: MethodError: no method matching ldiv!(::Vector{Float64}, ::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, ::Matrix{Float64})

Closest candidates are:
  ldiv!(::Any, ::Sparspak.SpkSparseSolver.SparseSolver{IT, FT}, ::Any) where {IT, FT}
   @ Sparspak ~/.julia/packages/Sparspak/oqBYl/src/SparseCSCInterface/SparseCSCInterface.jl:263
  ldiv!(::Any, ::LinearSolve.InvPreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:30
  ldiv!(::Any, ::LinearSolve.ComposePreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:17
  ...

Stacktrace:
 [1] _ldiv!(x::Vector{Float64}, A::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, b::Matrix{Float64})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/factorization.jl:11
 [2] macro expansion
   @ ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:135 [inlined]
 [3] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [4] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [5] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:218
 [6] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:217
 [7] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:214
 [8] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:211
 [9] top-level scope
   @ REPL[24]:3

Member: @avik-pal I thought you handled something with this?

Contributor Author: @avik-pal A ping on this. Is there another way to do this if we do not yet have batch dispatch?

Member: Not for this case, but for a case where A and b are both batched. Here you will have to see how Base handles it; there are special LAPACK routines for these.
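For the plain multi-RHS case in the example above, stdlib LinearAlgebra already accepts a matrix right-hand side: factor once, then solve all columns in one call (the MethodError above appears to come from LinearSolve's internal _ldiv! receiving a Vector u alongside a Matrix b, not from Base itself). A minimal sketch using only LinearAlgebra:

```julia
using LinearAlgebra

# Factor A once, then solve against a matrix of right-hand sides in one call;
# for an LU factorization this uses the multi-RHS LAPACK triangular solves.
n = 4
A = [4.0 1.0 0.0 0.0; 1.0 4.0 1.0 0.0; 0.0 1.0 4.0 1.0; 0.0 0.0 1.0 4.0]
B = [1.0 0.0; 0.0 1.0; 1.0 1.0; 0.0 0.0]

F = lu(A)      # one factorization
X = F \ B      # solves every column of B at once
@assert A * X ≈ B
```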

    d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
    LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)
end


for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization)
    @eval begin
        function LinearSolve.solve!(
                cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T, V, P}}, B},
                alg::$ALG;
                kwargs...
        ) where {T, V, P, B}
            # @info "using solve! df/dA"
            dAs = begin
                t = collect.(ForwardDiff.partials.(cache.A))
                [getindex.(t, i) for i in 1:P]
            end
            dbs = [zero(cache.b) for _ in 1:P]
            A = ForwardDiff.value.(cache.A)
            b = cache.b
            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
        end
        function LinearSolve.solve!(
                cache::LinearSolve.LinearCache{A_, <:AbstractArray{<:Dual{T, V, P}}},
                alg::$ALG;
                kwargs...
        ) where {T, V, P, A_}
            # @info "using solve! df/db"
            dAs = [zero(cache.A) for _ in 1:P]
            dbs = begin
                t = collect.(ForwardDiff.partials.(cache.b))
                [getindex.(t, i) for i in 1:P]
            end
            A = cache.A
            b = ForwardDiff.value.(cache.b)
            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
        end
        function LinearSolve.solve!(
                cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T, V, P}}, <:AbstractArray{<:Dual{T, V, P}}},
                alg::$ALG;
                kwargs...
        ) where {T, V, P}
            # @info "using solve! df/dAb"
            dAs = begin
                t = collect.(ForwardDiff.partials.(cache.A))
                [getindex.(t, i) for i in 1:P]
            end
            dbs = begin
                t = collect.(ForwardDiff.partials.(cache.b))
                [getindex.(t, i) for i in 1:P]
            end
            A = ForwardDiff.value.(cache.A)
            b = ForwardDiff.value.(cache.b)
            _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
        end
    end
end

end # module
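The dAs/dbs extraction shared by the three dispatches splits a Dual array with P partials per entry into P ordinary arrays, one per partial direction. A standalone illustration of that pattern (the Nothing tag and values are example choices, not from the PR):

```julia
using ForwardDiff
using ForwardDiff: Dual

# A 1x2 Dual matrix with P = 2 partials per entry.
A = [Dual{Nothing}(1.0, 2.0, 3.0)  Dual{Nothing}(4.0, 5.0, 6.0)]

# Same pattern as in the dispatches above: collect each entry's partials,
# then peel off the i-th partial of every entry into its own array.
t = collect.(ForwardDiff.partials.(A))
dAs = [getindex.(t, i) for i in 1:2]

@assert ForwardDiff.value.(A) == [1.0 4.0]
@assert dAs[1] == [2.0 5.0]
@assert dAs[2] == [3.0 6.0]
```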
9 changes: 9 additions & 0 deletions src/common.jl
@@ -82,6 +82,15 @@
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
    assumptions::OperatorAssumptions{issq}
end

function SciMLBase.remake(cache::LinearCache;
        A::TA = cache.A, b::TB = cache.b, u::TU = cache.u, p::TP = cache.p, alg::Talg = cache.alg,
        cacheval::Tc = cache.cacheval, isfresh::Bool = cache.isfresh, Pl::Tl = cache.Pl, Pr::Tr = cache.Pr,
        abstol::Ttol = cache.abstol, reltol::Ttol = cache.reltol, maxiters::Int = cache.maxiters,
        verbose::Bool = cache.verbose, assumptions::OperatorAssumptions{issq} = cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}
    LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr,
        abstol, reltol, maxiters, verbose, assumptions)
end

Comment on lines +85 to +93

Contributor Author: Need to check if there is a way to avoid redefining this by providing a better constructor for LinearCache.

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
    if name === :A
        setfield!(cache, :isfresh, true)
74 changes: 74 additions & 0 deletions test/forwarddiff.jl
@@ -0,0 +1,74 @@
using Test
using ForwardDiff
using LinearSolve
using FiniteDiff
using Enzyme
using Random
Random.seed!(1234)

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
        LUFactorization(),
        RFLUFactorization(),
        # KrylovJL_GMRES(),  # dispatch fails
    )
    alg_str = string(alg)
    @show alg_str

    function fb(b)
        prob = LinearProblem(A, b)
        sol1 = solve(prob, alg)
        sum(sol1.u)
    end
    fb(b1)

    fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
    @show fid_jac

    fod_jac = ForwardDiff.gradient(fb, b1) |> vec
    @show fod_jac

    @test fod_jac ≈ fid_jac rtol=1e-6

    function fA(A)
        prob = LinearProblem(A, b1)
        sol1 = solve(prob, alg)
        sum(sol1.u)
    end
    fA(A)

    fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
    @show fid_jac

    fod_jac = ForwardDiff.gradient(fA, A) |> vec
    @show fod_jac

    @test fod_jac ≈ fid_jac rtol=1e-6

    function fAb(Ab)
        A = Ab[:, 1:n]
        b1 = Ab[:, n + 1]
        prob = LinearProblem(A, b1)
        sol1 = solve(prob, alg)
        sum(sol1.u)
    end
    fAb(hcat(A, b1))

    fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec
    @show fid_jac

    fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec
    @show fod_jac

    @test fod_jac ≈ fid_jac rtol=1e-6
end