Skip to content

Commit

Permalink
Merge pull request #2291 from oscardssmith/os/fix-OOP-WOperator
Browse files Browse the repository at this point in the history
fix out of place `WOperator`
  • Loading branch information
ChrisRackauckas committed Jul 26, 2024
2 parents 17bd407 + 9c667c4 commit 7aae96e
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 88 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -114,4 +115,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"]
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ using NonlinearSolve

# Required by temporary fix in not in-place methods with 12+ broadcasts
# `MVector` is used by Nordsieck forms
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix

# Integrator Interface
import DiffEqBase: resize!, deleteat!, addat!, full_cache, user_cache, u_cache, du_cache,
Expand Down
130 changes: 44 additions & 86 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,12 @@ function calc_J(integrator, cache, next_step::Bool = false)
J = jacobian(uf, uprev, integrator)
end

integrator.stats.njacs += 1

if alg isa CompositeAlgorithm
integrator.eigen_est = constvalue(opnorm(J, Inf))
end
end

integrator.stats.njacs += 1
J
end

Expand Down Expand Up @@ -144,12 +143,11 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
end
end

integrator.stats.njacs += 1

if alg isa CompositeAlgorithm
integrator.eigen_est = constvalue(opnorm(J, Inf))
end

integrator.stats.njacs += 1
return nothing
end

Expand Down Expand Up @@ -604,21 +602,21 @@ function jacobian2W!(W::Matrix, mass_matrix, dtgamma::Number, J::Matrix,
return nothing
end

function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix,
W_transform::Bool)::Nothing where {MT}
function jacobian2W(mass_matrix, dtgamma::Number, J::AbstractMatrix,
W_transform::Bool)
# check size and dimension
mass_matrix isa UniformScaling ||
@boundscheck axes(mass_matrix) == axes(J) || _throwJMerror(J, mass_matrix)
@inbounds if W_transform
invdtgamma = inv(dtgamma)
if MT <: UniformScaling
if mass_matrix isa UniformScaling
λ = -mass_matrix.λ
W = J +* invdtgamma) * I
else
W = muladd(-mass_matrix, invdtgamma, J)
end
else
if MT <: UniformScaling
if mass_matrix isa UniformScaling
λ = -mass_matrix.λ
W = dtgamma * J + λ * I
else
Expand Down Expand Up @@ -738,67 +736,33 @@ end
islin, isode = islinearfunction(integrator)
!isdae && update_coefficients!(mass_matrix, uprev, p, t)

if cache.W isa WOperator
W = cache.W
if isnewton(nlsolver)
# we will call `update_coefficients!` for u/p/t in NLNewton
update_coefficients!(W; transform = W_transform, dtgamma)
if cache.W isa StaticWOperator
integrator.stats.nw += 1
J = calc_J(integrator, cache, next_step)
W = StaticWOperator(W_transform ? J - mass_matrix * inv(dtgamma) : dtgamma * J - mass_matrix)
elseif cache.W isa WOperator
integrator.stats.nw += 1
J = if islin
isode ? f.f : f.f1.f
else
update_coefficients!(W, uprev, p, t; transform = W_transform, dtgamma)
calc_J(integrator, cache, next_step)
end
if W.J !== nothing && !(W.J isa AbstractSciMLOperator)
islin, isode = islinearfunction(integrator)
J = islin ? (isode ? f.f : f.f1.f) : calc_J(integrator, cache, next_step)
!isdae &&
jacobian2W!(W._concrete_form, mass_matrix, dtgamma, J, W_transform)
end
elseif cache.W isa AbstractSciMLOperator && !(cache.W isa StaticWOperator)
J = update_coefficients(cache.J, uprev, p, t)
W = WOperator{false}(mass_matrix, dtgamma, J, uprev, cache.W.jacvec; transform = W_transform)
elseif cache.W isa AbstractSciMLOperator
W = update_coefficients(cache.W, uprev, p, t; dtgamma, transform = W_transform)
elseif islin
J = isode ? f.f : f.f1.f # unwrap the Jacobian accordingly
W = WOperator{false}(mass_matrix, dtgamma, J, uprev; transform = W_transform)
elseif DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
if J isa StaticArray &&
integrator.alg isa
Union{
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2,
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
W = W_transform ? J - mass_matrix * inv(dtgamma) :
dtgamma * J - mass_matrix
else
if !isa(J, AbstractSciMLOperator) && (!isnewton(nlsolver) ||
nlsolver.cache.W.J isa AbstractSciMLOperator)
J = MatrixOperator(J)
end
W = WOperator{false}(mass_matrix, dtgamma, J, uprev, cache.W.jacvec;
transform = W_transform)
end
integrator.stats.nw += 1
else
integrator.stats.nw += 1
J = calc_J(integrator, cache, next_step)
J = islin ? isode ? f.f : f.f1.f : calc_J(integrator, cache, next_step)
if isdae
W = J
else
W_full = W_transform ? J - mass_matrix * inv(dtgamma) :
W = W_transform ? J - mass_matrix * inv(dtgamma) :
dtgamma * J - mass_matrix
len = StaticArrayInterface.known_length(typeof(W_full))
W = if W_full isa Number
W_full
elseif len !== nothing &&
integrator.alg isa
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(W_full)
else
DiffEqBase.default_factorize(W_full)
if !isa(W, Number)
W = DiffEqBase.default_factorize(W)
end
end
end
(W isa WOperator && unwrap_alg(integrator, true) isa NewtonAlgorithm) &&
(W = update_coefficients!(W, uprev, p, t)) # we will call `update_coefficients!` in NLNewton
is_compos && (integrator.eigen_est = isarray ? constvalue(opnorm(J, Inf)) :
integrator.opts.internalnorm(J, t))
return W
Expand Down Expand Up @@ -876,10 +840,11 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
elseif f.jac_prototype isa AbstractSciMLOperator
W = WOperator{IIP}(f, u, dt)
J = W.J
elseif islin
J = isode ? f.f : f.f1.f # unwrap the Jacobian accordingly
W = WOperator{IIP}(f.mass_matrix, dt, J, u)
elseif IIP && f.jac_prototype !== nothing && concrete_jac(alg) === nothing &&
(alg.linsolve === nothing ||
alg.linsolve !== nothing &&
LinearSolve.needs_concrete_A(alg.linsolve))
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))

# If factorization, then just use the jac_prototype
J = similar(f.jac_prototype)
Expand All @@ -896,7 +861,6 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
J = jacvec
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

elseif alg.linsolve !== nothing && !LinearSolve.needs_concrete_A(alg.linsolve) ||
concrete_jac(alg) !== nothing && concrete_jac(alg)
# The linear solver does not need a concrete Jacobian, but the user has
Expand All @@ -908,42 +872,36 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
else
deepcopy(f.jac_prototype)
end
__f = if IIP
(du, u, p, t) -> _f(du, u, p, t)
W = if J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
StaticWOperator(J, false)
elseif J isa StaticMatrix
ArrayInterface.lu_instance(J)
else
(u, p, t) -> _f(u, p, t)
end
jacvec = JacVec(__f, copy(u), p, t;
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)

elseif islin || (!IIP && DiffEqBase.has_jac(f))
J = islin ? (isode ? f.f : f.f1.f) : f.jac(uprev, p, t) # unwrap the Jacobian accordingly
if !isa(J, AbstractSciMLOperator)
J = MatrixOperator(J)
__f = if IIP
(du, u, p, t) -> _f(du, u, p, t)
else
(u, p, t) -> _f(u, p, t)
end
jacvec = JacVec(__f, copy(u), p, t;
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
end
W = WOperator{IIP}(f.mass_matrix, dt, J, u)
else
J = if f.jac_prototype === nothing
J = if !IIP && DiffEqBase.has_jac(f)
f.jac(uprev, p, t)
elseif f.jac_prototype === nothing
ArrayInterface.undefmatrix(u)
else
deepcopy(f.jac_prototype)
end
isdae = alg isa DAEAlgorithm
W = if isdae
W = if alg isa DAEAlgorithm
J
elseif IIP
similar(J)
elseif J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
StaticWOperator(J, false)
else
len = StaticArrayInterface.known_length(typeof(J))
if len !== nothing &&
alg isa
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(J, false)
else
ArrayInterface.lu_instance(J)
end
ArrayInterface.lu_instance(J)
end
end
return J, W
Expand Down
40 changes: 40 additions & 0 deletions test/interface/linear_solver_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,43 @@ end
atol = 1e-1, rtol = 1e-1)
@test isapprox(exp.(p), g_helper(p; alg = KenCarp47(linsolve = KrylovJL_GMRES()));
atol = 1e-1, rtol = 1e-1)

using OrdinaryDiffEq, StaticArrays, LinearSolve, ParameterizedFunctions

hires = @ode_def Hires begin
dy1 = -1.71*y1 + 0.43*y2 + 8.32*y3 + 0.0007
dy2 = 1.71*y1 - 8.75*y2
dy3 = -10.03*y3 + 0.43*y4 + 0.035*y5
dy4 = 8.32*y2 + 1.71*y3 - 1.12*y4
dy5 = -1.745*y5 + 0.43*y6 + 0.43*y7
dy6 = -280.0*y6*y8 + 0.69*y4 + 1.71*y5 - 0.43*y6 + 0.69*y7
dy7 = 280.0*y6*y8 - 1.81*y7
dy8 = -280.0*y6*y8 + 1.81*y7
end

u0 = zeros(8)
u0[1] = 1
u0[8] = 0.0057

probiip = ODEProblem{true}(hires, u0, (0.0,10.0))
proboop = ODEProblem{false}(hires, u0, (0.0,10.0))
probstatic = ODEProblem{false}(hires, SVector{8}(u0), (0.0,10.0))
probs = (;probiip, proboop, probstatic)
qndf = QNDF()
krylov_qndf = QNDF(linsolve=KrylovJL_GMRES())
fbdf = FBDF()
krylov_fbdf = FBDF(linsolve=KrylovJL_GMRES())
rodas = Rodas5P()
krylov_rodas = Rodas5P(linsolve=KrylovJL_GMRES())
solvers = (;qndf, krylov_qndf, rodas, krylov_rodas, fbdf, krylov_fbdf, )

refsol = solve(probiip, FBDF(), abstol=1e-12, reltol=1e-12)
@testset "Hires calc_W tests" begin
@testset "$probname" for (probname, prob) in pairs(probs)
@testset "$solname" for (solname, solver) in pairs(solvers)
sol = solve(prob, solver, abstol=1e-12, reltol=1e-12, maxiters=2e4)
@test sol.retcode == ReturnCode.Success
@test isapprox(sol.u[end], refsol.u[end], rtol=1e-8, atol=1e-10)
end
end
end

0 comments on commit 7aae96e

Please sign in to comment.