Skip to content

Commit

Permalink
Improve Rosenbrock finite difference by swapping dt with J
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Mar 5, 2022
1 parent 079ac94 commit c4eb8d7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -645,12 +645,13 @@ end

function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, W_transform)
nlsolver = nothing
calc_tderivative!(integrator, cache, dtd1, repeat_step)

# we need to skip calculating `W` when a step is repeated
new_W = false
if !repeat_step
new_W = calc_W!(cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, W_transform)
end
calc_tderivative!(integrator, cache, dtd1, repeat_step)
return new_W
end

Expand Down
67 changes: 67 additions & 0 deletions test/interface/difftype_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using OrdinaryDiffEq, Test

function fcn(du, u, p, t)
du[1] = u[1]^2 - 1 / t^4
end

function f_jac(J, y, P, t)
#-- numerical Jac by FD
y_thres = P
del = sqrt(eps(1.0))
n = length(y)
f0 = similar(y)
f1 = similar(y)
fcn(f0, y, P, t)
for i = 1:n
del_1 = del * max(abs(y[i]), y_thres)
y1 = copy(y)
y1[i] = y1[i] + del_1
fcn(f1, y1, P, t)
J[:, i] = (f1 - f0) / del_1
end
end

tspan = (1.0, 1.0e3);
M = zeros(1, 1);

println("--- AD ---")

f = ODEFunction(fcn, mass_matrix = M)
problem = ODEProblem(f, [-1.0], tspan);

sol = solve(problem, Rodas4P2(), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD central ---")

sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:central}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward ---")

sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = 1 ---")

y_thres = 1.0;
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = 1.0e-5 ---")

y_thres = 1.0e-5;
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100

println("--- FD forward, y_thres = sqrt(eps) ---")

y_thres = sqrt(eps(1.0));
f = ODEFunction(fcn, mass_matrix = M, jac = f_jac)
problem = ODEProblem(f, [-1.0], tspan, y_thres);
sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12);
@test sol.destats.naccept < 6100
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ end
if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface")
@time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end
@time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end
@time @safetestset "Diff Type Tests" begin include("interface/difftype_tests.jl") end
@time @safetestset "Sparse Diff Tests" begin include("interface/sparsediff_tests.jl") end
@time @safetestset "Enum Tests" begin include("interface/enums.jl") end
@time @safetestset "Mass Matrix Tests" begin include("interface/mass_matrix_tests.jl") end
Expand Down

0 comments on commit c4eb8d7

Please sign in to comment.