Skip to content

Commit

Permalink
Improve Rosenbrock finite difference by swapping dt with J
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 17, 2023
1 parent 68962b7 commit 8c0398a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
67 changes: 67 additions & 0 deletions test/interface/difftype_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using OrdinaryDiffEq, Test

function fcn(du, u, p, t)
du[1] = u[1]^2 - 1 / t^4
end

function f_jac(J, y, P, t)
#-- numerical Jac by FD
y_thres = P
del = sqrt(eps(1.0))
n = length(y)
f0 = similar(y)
f1 = similar(y)
fcn(f0, y, P, t)
for i = 1:n
del_1 = del * max(abs(y[i]), y_thres)
y1 = copy(y)
y1[i] = y1[i] + del_1
fcn(f1, y1, P, t)
J[:, i] = (f1 - f0) / del_1
end
end

tspan = (1.0, 1.0e3);
M = zeros(1, 1);

println("--- AD ---")

f = ODEFunction(fcn, mass_matrix = M)
problem = ODEProblem(f, [-1.0], tspan);

sol = solve(problem, Rodas4P2(), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD central ---")

sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:central}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward ---")

sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = 1 ---")

y_thres = 1.0;
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = 1.0e-5 ---")

y_thres = 1.0e-5;
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = sqrt(eps) ---")

y_thres = sqrt(eps(1.0));
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ end
@time @safetestset "No Recompile Tests" begin include("interface/norecompile.jl") end
@time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end
@time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end
@time @safetestset "Diff Type Tests" begin include("interface/difftype_tests.jl") end
@time @safetestset "Linear Solver Split ODE Tests" begin include("interface/linear_solver_split_ode_test.jl") end
@time @safetestset "Sparse Diff Tests" begin include("interface/sparsediff_tests.jl") end
@time @safetestset "Enum Tests" begin include("interface/enums.jl") end
Expand Down

0 comments on commit 8c0398a

Please sign in to comment.