diff --git a/Project.toml b/Project.toml index fb1f3a5ed..017bced67 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -97,6 +98,7 @@ ReverseDiff = "1" SciMLBase = "2.115.0" SciMLOperators = "1" SciMLStructures = "1.5" +SimpleNonlinearSolve = "2.7" Setfield = "1" SparseArrays = "1.9" Static = "1" diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index c5c0b4e26..f2ad064cf 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -1,13 +1,13 @@ module DiffEqBaseForwardDiffExt using DiffEqBase, ForwardDiff +using SimpleNonlinearSolve: ITP using DiffEqBase.ArrayInterface using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution, RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, - promote_tspan, ODE_DEFAULT_NORM, - InternalITP, nextfloat_tdir + promote_tspan, ODE_DEFAULT_NORM import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} @@ -153,7 +153,7 @@ end # Differentiation of internal solver -function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) +function scalar_nlsolve_ad(prob, alg::ITP, args...; kwargs...) f = prob.f p = value(prob.p) @@ -186,7 +186,7 @@ end function SciMLBase.solve( prob::IntervalNonlinearProblem{uType, iip, <:ForwardDiff.Dual{T, V, P}}, - alg::InternalITP, args...; + alg::ITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), @@ -202,7 +202,7 @@ function SciMLBase.solve( V, P}, }}, - alg::InternalITP, args...; + alg::ITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 1a0466ef3..645b04922 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -44,6 +44,8 @@ using SciMLBase using SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator +using SimpleNonlinearSolve: ITP + using SciMLBase: @def, DEIntegrator, AbstractDEProblem, AbstractDiffEqInterpolation, DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, @@ -140,7 +142,6 @@ include("utils.jl") include("stats.jl") include("calculate_residuals.jl") include("tableaus.jl") -include("internal_itp.jl") include("callbacks.jl") include("common_defaults.jl") diff --git a/src/callbacks.jl b/src/callbacks.jl index a84d11e4c..a714cccac 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -358,17 +358,17 @@ end # rough implementation, needs multiple type handling # always ensures that if r = bisection(f, (x0, x1)) # then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0 -# note: not really using bisection - uses the ITP method +# note: not really using bisection - uses the ITP method function bisection( f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol; maxiters = 1000) if rootfind == SciMLBase.LeftRootFind solve(IntervalNonlinearProblem{false}(f, tup), - InternalITP(), abstol = abstol, + ITP(), abstol = abstol, reltol = reltol).left else solve(IntervalNonlinearProblem{false}(f, tup), - InternalITP(), abstol = abstol, + ITP(), abstol = abstol, reltol = reltol).right end end diff --git a/src/internal_itp.jl b/src/internal_itp.jl deleted file mode 100644 index 06f685045..000000000 --- a/src/internal_itp.jl +++ /dev/null @@ -1,87 +0,0 @@ -""" - prevfloat_tdir(x, x0, x1) - -Move `x` one floating point towards x0. -""" -function prevfloat_tdir(x, x0, x1) - x1 > x0 ? prevfloat(x) : nextfloat(x) -end - -function nextfloat_tdir(x, x0, x1) - x1 > x0 ? nextfloat(x) : prevfloat(x) -end - -function max_tdir(a, b, x0, x1) - x1 > x0 ? max(a, b) : min(a, b) -end - -""" -`InternalITP`: A non-allocating ITP method, internal to DiffEqBase for -simpler dependencies. -""" -struct InternalITP - scaled_k1::Float64 - n0::Int -end - -InternalITP() = InternalITP(0.2, 10) - -function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T, T}}, alg::InternalITP, - args...; - maxiters = 1000, kwargs...) where {IP, T} - f = Base.Fix2(prob.f, prob.p) - left, right = minmax(prob.tspan...) # a and b - fl, fr = f(left), f(right) - ϵ = eps(T) - if iszero(fl) - return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.ExactSolutionLeft, left, right) - elseif iszero(fr) - return SciMLBase.build_solution(prob, alg, right, fr; - retcode = ReturnCode.ExactSolutionRight, left, right) - end - span = right - left - k1 = T(alg.scaled_k1) / span - n0 = T(alg.n0) - n_h = exponent(span / (2 * ϵ)) - ϵ_s = ϵ * exp2(n_h + n0) - T0 = zero(fl) - - i = 1 - while i ≤ maxiters - span = right - left - mid = (left + right) / 2 - r = ϵ_s - (span / 2) - - x_f = left + span * (fl / (fl - fr)) # Interpolation Step - - δ = max(k1 * span^2, eps(x_f)) - diff = mid - x_f - - xt = ifelse(δ ≤ abs(diff), x_f + copysign(δ, diff), mid) # Truncation Step - - xp = ifelse(abs(xt - mid) ≤ r, xt, mid - copysign(r, diff)) # Projection Step - yp = f(xp) - yps = yp * sign(fr) - if yps > T0 - right, fr = xp, yp - elseif yps < T0 - left, fl = xp, yp - else - return SciMLBase.build_solution( - prob, alg, xp, yps; retcode = ReturnCode.Success, left = xp, right = xp - ) - end - - i += 1 - ϵ_s /= 2 - - if nextfloat_tdir(left, left, right) == right - return SciMLBase.build_solution( - prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right - ) - end - end - return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, - left = left, right = right) -end diff --git a/test/internal_rootfinder.jl b/test/internal_rootfinder.jl index 82b0632e7..29947cecd 100644 --- a/test/internal_rootfinder.jl +++ b/test/internal_rootfinder.jl @@ -1,8 +1,9 @@ using DiffEqBase -using DiffEqBase: InternalITP, IntervalNonlinearProblem +using DiffEqBase: IntervalNonlinearProblem +using SimpleNonlinearSolve: ITP using ForwardDiff -for Rootfinder in (InternalITP,) +for Rootfinder in (ITP,) rf = Rootfinder() # From SimpleNonlinearSolve f = (u, p) -> u * u - p