From c4eb8d738c2d0fbb0f57114887a43079cc3c03f4 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 4 Mar 2022 09:52:34 -0500 Subject: [PATCH] Improve Rosenbrock finite difference by swapping dt with J Fixes https://github.com/SciML/DifferentialEquations.jl/issues/773 --- src/derivative_utils.jl | 3 +- test/interface/difftype_tests.jl | 67 ++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 test/interface/difftype_tests.jl diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 9849c3012d..eaa10b2eb5 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -645,12 +645,13 @@ end function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, W_transform) nlsolver = nothing + calc_tderivative!(integrator, cache, dtd1, repeat_step) + # we need to skip calculating `W` when a step is repeated new_W = false if !repeat_step new_W = calc_W!(cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, W_transform) end - calc_tderivative!(integrator, cache, dtd1, repeat_step) return new_W end diff --git a/test/interface/difftype_tests.jl b/test/interface/difftype_tests.jl new file mode 100644 index 0000000000..87b6f1568a --- /dev/null +++ b/test/interface/difftype_tests.jl @@ -0,0 +1,67 @@ +using OrdinaryDiffEq, Test + +function fcn(du, u, p, t) + du[1] = u[1]^2 - 1 / t^4 +end + +function f_jac(J, y, P, t) + #-- numerical Jac by FD + y_thres = P + del = sqrt(eps(1.0)) + n = length(y) + f0 = similar(y) + f1 = similar(y) + fcn(f0, y, P, t) + for i = 1:n + del_1 = del * max(abs(y[i]), y_thres) + y1 = copy(y) + y1[i] = y1[i] + del_1 + fcn(f1, y1, P, t) + J[:, i] = (f1 - f0) / del_1 + end +end + +tspan = (1.0, 1.0e3); +M = zeros(1, 1); + +println("--- AD ---") + +f = ODEFunction(fcn, mass_matrix = M) +problem = ODEProblem(f, [-1.0], tspan); + +sol = solve(problem, Rodas4P2(), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 + +println("--- FD central ---") + +sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:central}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 + +println("--- FD forward ---") + +sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 + +println("--- FD forward, y_thres = 1 ---") + +y_thres = 1.0; +f = ODEFunction(fcn, mass_matrix = M, jac = f_jac) +problem = ODEProblem(f, [-1.0], tspan, y_thres); +sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 + +println("--- FD forward, y_thres = 1.0e-5 ---") + +y_thres = 1.0e-5; +f = ODEFunction(fcn, mass_matrix = M, jac = f_jac) +problem = ODEProblem(f, [-1.0], tspan, y_thres); +sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 + +println("--- FD forward, y_thres = sqrt(eps) ---") + +y_thres = sqrt(eps(1.0)); +f = ODEFunction(fcn, mass_matrix = M, jac = f_jac) +problem = ODEProblem(f, [-1.0], tspan, y_thres); +sol = solve(problem, Rodas4P2(autodiff = false, diff_type = Val{:forward}), maxiters = Int(1e7), reltol = 1.0e-12, abstol = 1.0e-12); +@test sol.destats.naccept < 6100 diff --git a/test/runtests.jl b/test/runtests.jl index ec697fcfdb..1a26b1b1e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,7 @@ end if !is_APPVEYOR && (GROUP == "All" || GROUP == "InterfaceII" || GROUP == "Interface") @time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end @time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end + @time @safetestset "Diff Type Tests" begin include("interface/difftype_tests.jl") end @time @safetestset "Sparse Diff Tests" begin include("interface/sparsediff_tests.jl") end @time @safetestset "Enum Tests" begin include("interface/enums.jl") end @time @safetestset "Mass Matrix Tests" begin include("interface/mass_matrix_tests.jl") end