Skip to content

Commit

Permalink
fix tests and nonlinear precs
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Aug 22, 2024
1 parent 1a76866 commit cfa4d58
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 137 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end
end
const TryAgain = SlowConvergence

DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
DEFAULT_PRECS(W, p) = nothing, nothing
isdiscretecache(cache) = false

include("doc_utils.jl")
Expand Down
1 change: 0 additions & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Expand Down
32 changes: 8 additions & 24 deletions lib/OrdinaryDiffEqExtrapolation/src/extrapolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,14 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
linsolve_tmps[i] = zero(rate_prototype)
end

linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
linsolve[1] = linsolve1
for i in 2:Threads.nthreads()
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
end

res = uEltypeNoUnits.(zero(u))
Expand Down Expand Up @@ -1150,18 +1146,14 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
linsolve_tmps[i] = zero(rate_prototype)
end

linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
linsolve[1] = linsolve1
for i in 2:Threads.nthreads()
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
end
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)
Expand Down Expand Up @@ -1478,18 +1470,14 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
linsolve_tmps[i] = zero(rate_prototype)
end

linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
linsolve[1] = linsolve1
for i in 2:Threads.nthreads()
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
end
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)
Expand Down Expand Up @@ -1674,18 +1662,14 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
linsolve_tmps[i] = zero(rate_prototype)
end

linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
linsolve[1] = linsolve1
for i in 2:Threads.nthreads()
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
end
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)
Expand Down
24 changes: 6 additions & 18 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
jac_config = jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw12)

linprob = LinearProblem(W1, _vec(cubuff); u0 = _vec(dw12))
linprob = LinearProblem(W1, _vec(cubuff), (nothing,u,p,t); u0 = _vec(dw12))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)
Expand Down Expand Up @@ -252,16 +250,12 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linprob = LinearProblem(W1, _vec(ubuff), (nothing,u,p,t); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
linprob = LinearProblem(W2, _vec(cubuff); u0 = _vec(dw23))
linprob = LinearProblem(W2, _vec(cubuff), (nothing,u,p,t); u0 = _vec(dw23))
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)
Expand Down Expand Up @@ -441,21 +435,15 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linprob = LinearProblem(W1, _vec(ubuff), (nothing,u,p,t); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
linprob = LinearProblem(W2, _vec(cubuff1); u0 = _vec(dw23))
linprob = LinearProblem(W2, _vec(cubuff1), (nothing,u,p,t); u0 = _vec(dw23))
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
linprob = LinearProblem(W3, _vec(cubuff2); u0 = _vec(dw45))
linprob = LinearProblem(W3, _vec(cubuff2), (nothing,u,p,t); u0 = _vec(dw45))
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)
Expand Down
5 changes: 0 additions & 5 deletions lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,7 @@ function build_nlsolver(
jac_config = build_jac_config(alg, nf, uf, du1, uprev, u, ztmp, dz)
end
linprob = LinearProblem(W, _vec(k), (isdae ? du1 : nothing,u,p,t); u0 = _vec(dz))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)...,
weight, dz)
linsolve = init(linprob, alg.linsolve, (isdae ? du1 : nothing,u,p,t); alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))

tType = typeof(t)
Expand Down
69 changes: 29 additions & 40 deletions test/interface/linear_nonlinear_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,22 @@ end
u0 = rand(3)
prob = ODEProblem(rn, u0, (0, 50.0))

function precsl(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pl = lu(convert(AbstractMatrix, W), check = false)
else
Pl = Plprev
end
Pl, nothing
function precsl(W, p)
Pl = lu(convert(AbstractMatrix, W), check = false)
Pl, IdentityOperator(size(W, 1))
end

function precsr(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pr = lu(convert(AbstractMatrix, W), check = false)
else
Pr = Prprev
end
nothing, Pr
function precsr(W, p)
Pr = lu(convert(AbstractMatrix, W), check = false)
IdentityOperator(size(W, 1)), Pr
end

function precslr(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pr = lu(convert(AbstractMatrix, W), check = false)
else
Pr = Prprev
end
function precslr(W, p)
Pr = lu(convert(AbstractMatrix, W), check = false)
Pr, Pr
end


sol = @test_nowarn solve(prob, TRBDF2(autodiff = false));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
Expand All @@ -45,29 +34,29 @@ solref = @test_nowarn solve(prob,
smooth_est = false));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precsl, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsl),
smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precsr, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsr),
smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precslr, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr)
, smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
QNDF(autodiff = false, linsolve = KrylovJL_GMRES(),
concrete_jac = true));
@test length(sol.t) < 25
sol = @test_nowarn solve(prob,
Rosenbrock23(autodiff = false,
linsolve = KrylovJL_GMRES(),
precs = precslr, concrete_jac = true));
linsolve = KrylovJL_GMRES(precs = precslr),
concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precslr, concrete_jac = true));
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
concrete_jac = true));
@test length(sol.t) < 20

sol = @test_nowarn solve(prob, TRBDF2(autodiff = false));
Expand All @@ -79,26 +68,26 @@ sol = @test_nowarn solve(prob,
smooth_est = false));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precsl, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsl),
smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precsr, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsr),
smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precslr, smooth_est = false, concrete_jac = true));
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
smooth_est = false, concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
QNDF(autodiff = false, linsolve = KrylovJL_GMRES(),
concrete_jac = true));
@test length(sol.t) < 25
sol = @test_nowarn solve(prob,
Rosenbrock23(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precslr, concrete_jac = true));
Rosenbrock23(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
concrete_jac = true));
@test length(sol.t) < 20
sol = @test_nowarn solve(prob,
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(),
precs = precslr, concrete_jac = true));
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
concrete_jac = true));
@test length(sol.t) < 20
4 changes: 2 additions & 2 deletions test/interface/linear_solver_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ refsol = solve(probiip, FBDF(), abstol = 1e-12, reltol = 1e-12)
@testset "$solname" for (solname, solver) in pairs(solvers)
sol = solve(prob, solver, abstol = 1e-12, reltol = 1e-12, maxiters = 2e4)
@test sol.retcode == ReturnCode.Success
@test isapprox(sol.u[end], refsol.u[end], rtol = 1e-8, atol = 1e-10)
@test isapprox(sol.u[end], refsol.u[end], rtol = 2e-8, atol = 1e-10)
end
end
end
Expand All @@ -207,7 +207,7 @@ end
@testset "$solname" for (solname, solver) in pairs(solvers)
sol = solve(prob, solver, maxiters = 2e4)
@test sol.retcode == ReturnCode.Success
@test isapprox(sol.u[end], refsol.u[end], rtol = 2e-3, atol = 1e-6)
@test isapprox(sol.u[end], refsol.u[end], rtol = 5e-3, atol = 1e-6)
end
end
end
Loading

0 comments on commit cfa4d58

Please sign in to comment.