diff --git a/docs/src/optimization_packages/ode.md b/docs/src/optimization_packages/ode.md new file mode 100644 index 000000000..0144ac301 --- /dev/null +++ b/docs/src/optimization_packages/ode.md @@ -0,0 +1,55 @@ +# OptimizationODE.jl + +**OptimizationODE.jl** provides ODE-based optimization methods as a solver plugin for [SciML's Optimization.jl](https://github.com/SciML/Optimization.jl). It wraps various ODE solvers to perform gradient-based optimization using continuous-time dynamics. + +## Installation + +```julia +using Pkg +Pkg.add(url="OptimizationODE.jl") +``` + +## Usage + +```julia +using OptimizationODE, Optimization, ADTypes, SciMLBase + +function f(x, p) + return sum(abs2, x) +end + +function g!(g, x, p) + @. g = 2 * x +end + +x0 = [2.0, -3.0] +p = [] + +f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad = g!) +prob_manual = OptimizationProblem(f_manual, x0) + +opt = ODEGradientDescent(dt=0.01) +sol = solve(prob_manual, opt; maxiters=50_000) + +@show sol.u +@show sol.objective +``` + +## Local Gradient-based Optimizers + +All provided optimizers are **gradient-based local optimizers** that solve optimization problems by integrating gradient-based ODEs to convergence: + +* `ODEGradientDescent(dt=...)` — performs basic gradient descent using the explicit Euler method. This is a simple and efficient method suitable for small-scale or well-conditioned problems. + +* `RKChebyshevDescent()` — uses the ROCK2 solver, a stabilized explicit Runge-Kutta method suitable for stiff problems. It allows larger step sizes while maintaining stability. + +* `RKAccelerated()` — leverages the Tsit5 method, a 5th-order Runge-Kutta solver that achieves faster convergence for smooth problems by improving integration accuracy. + +* `HighOrderDescent()` — applies Vern7, a high-order (7th-order) explicit Runge-Kutta method for even more accurate integration. This can be beneficial for problems requiring high precision. + +You can also define a custom optimizer using the generic `ODEOptimizer(solver; dt=nothing)` constructor by supplying any ODE solver supported by [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/). + +## Interface Details + +All optimizers require gradient information (either via automatic differentiation or manually provided `grad!`). The optimization is performed by integrating the ODE defined by the negative gradient until a steady state is reached. + diff --git a/lib/OptimizationODE/Project.toml b/lib/OptimizationODE/Project.toml index df5de47a7..d04840cc4 100644 --- a/lib/OptimizationODE/Project.toml +++ b/lib/OptimizationODE/Project.toml @@ -6,10 +6,14 @@ version = "0.1.0" [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" [compat] ForwardDiff = "0.10, 1" diff --git a/lib/OptimizationODE/src/OptimizationODE.jl b/lib/OptimizationODE/src/OptimizationODE.jl index ffacdc20a..bcfe80338 100644 --- a/lib/OptimizationODE/src/OptimizationODE.jl +++ b/lib/OptimizationODE/src/OptimizationODE.jl @@ -2,58 +2,123 @@ module OptimizationODE using Reexport @reexport using Optimization, SciMLBase -using OrdinaryDiffEq, SteadyStateDiffEq +using LinearAlgebra, ForwardDiff + +using NonlinearSolve +using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent +export DAEOptimizer, DAEMassMatrix, DAEIndexing + +struct ODEOptimizer{T} + solver::T +end -struct ODEOptimizer{T, T2} +ODEGradientDescent() = ODEOptimizer(Euler()) +RKChebyshevDescent() = ODEOptimizer(ROCK2()) +RKAccelerated() = ODEOptimizer(Tsit5()) +HighOrderDescent() = ODEOptimizer(Vern7()) + +struct DAEOptimizer{T} solver::T - dt::T2 end -ODEOptimizer(solver ; dt=nothing) = ODEOptimizer(solver, dt) -# Solver Constructors (users call these) -ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt) -RKChebyshevDescent() = ODEOptimizer(ROCK2()) -RKAccelerated() = ODEOptimizer(Tsit5()) -HighOrderDescent() = ODEOptimizer(Vern7()) +DAEMassMatrix() = DAEOptimizer(Rosenbrock23(autodiff = false)) +DAEIndexing() = DAEOptimizer(IDA()) -SciMLBase.requiresbounds(::ODEOptimizer) = false -SciMLBase.allowsbounds(::ODEOptimizer) = false -SciMLBase.allowscallback(::ODEOptimizer) = true +SciMLBase.requiresbounds(::ODEOptimizer) = false +SciMLBase.allowsbounds(::ODEOptimizer) = false +SciMLBase.allowscallback(::ODEOptimizer) = true SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true -SciMLBase.requiresgradient(::ODEOptimizer) = true -SciMLBase.requireshessian(::ODEOptimizer) = false -SciMLBase.requiresconsjac(::ODEOptimizer) = false -SciMLBase.requiresconshess(::ODEOptimizer) = false +SciMLBase.requiresgradient(::ODEOptimizer) = true +SciMLBase.requireshessian(::ODEOptimizer) = false +SciMLBase.requiresconsjac(::ODEOptimizer) = false +SciMLBase.requiresconshess(::ODEOptimizer) = false + + +SciMLBase.requiresbounds(::DAEOptimizer) = false +SciMLBase.allowsbounds(::DAEOptimizer) = false +SciMLBase.allowsconstraints(::DAEOptimizer) = true +SciMLBase.allowscallback(::DAEOptimizer) = true +SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true +SciMLBase.requiresgradient(::DAEOptimizer) = true +SciMLBase.requireshessian(::DAEOptimizer) = false +SciMLBase.requiresconsjac(::DAEOptimizer) = true +SciMLBase.requiresconshess(::DAEOptimizer) = false function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer; - callback=Optimization.DEFAULT_CALLBACK, progress=false, + callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing, maxiters=nothing, kwargs...) - - return OptimizationCache(prob, opt; callback=callback, progress=progress, + return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt, maxiters=maxiters, kwargs...) end -function SciMLBase.__solve( - cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C} - ) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C} +function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer; + callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing, + maxiters=nothing, differential_vars=nothing, kwargs...) + return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt, + maxiters=maxiters, differential_vars=differential_vars, kwargs...) +end - dt = cache.opt.dt - maxit = get(cache.solver_args, :maxiters, 1000) +function handle_parameters(p) + if p isa SciMLBase.NullParameters + return Float64[] + else + return p + end +end + +function setup_progress_callback(cache, solve_kwargs) + if get(cache.solver_args, :progress, false) + condition = (u, t, integrator) -> true + affect! = (integrator) -> begin + u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u + cache.solver_args[:callback](u_opt, integrator.p, integrator.t) + end + cb = DiscreteCallback(condition, affect!) + solve_kwargs[:callback] = cb + end + return solve_kwargs +end + + +function SciMLBase.__solve( + cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C} + ) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C} + + dt = get(cache.solver_args, :dt, nothing) + maxit = get(cache.solver_args, :maxiters, nothing) + differential_vars = get(cache.solver_args, :differential_vars, nothing) u0 = copy(cache.u0) - p = cache.p + p = handle_parameters(cache.p) # Properly handle NullParameters + if cache.opt isa ODEOptimizer + return solve_ode(cache, dt, maxit, u0, p) + else + if cache.opt.solver == Rosenbrock23(autodiff = false) + return solve_dae_mass_matrix(cache, dt, maxit, u0, p) + else + return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars) + end + end +end + +function solve_ode(cache, dt, maxit, u0, p) if cache.f.grad === nothing error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.") end function f!(du, u, p, t) - cache.f.grad(du, u, p) - @. du = -du + grad_vec = similar(u) + if isempty(p) + cache.f.grad(grad_vec, u) + else + cache.f.grad(grad_vec, u, p) + end + @. du = -grad_vec return nothing end @@ -62,14 +127,11 @@ function SciMLBase.__solve( algorithm = DynamicSS(cache.opt.solver) cb = cache.callback - if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true - function condition(u, t, integrator) - true - end + if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) + function condition(u, t, integrator) true end function affect!(integrator) u_now = integrator.u - state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p)) - Optimization.callback_function(cb, state) + cache.callback(u_now, integrator.p, integrator.t) end cb_struct = DiscreteCallback(condition, affect!) callback = CallbackSet(cb_struct) @@ -86,16 +148,16 @@ function SciMLBase.__solve( end sol = solve(ss_prob, algorithm; solve_kwargs...) -has_destats = hasproperty(sol, :destats) -has_t = hasproperty(sol, :t) && !isempty(sol.t) + has_destats = hasproperty(sol, :destats) + has_t = hasproperty(sol, :t) && !isempty(sol.t) -stats = Optimization.OptimizationStats( - iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10), - time = has_t ? sol.t[end] : 0.0, - fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0, - gevals = has_destats ? get(sol.destats, :iters, 0) : 0, - hevals = 0 -) + stats = Optimization.OptimizationStats( + iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10), + time = has_t ? sol.t[end] : 0.0, + fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0, + gevals = has_destats ? get(sol.destats, :iters, 0) : 0, + hevals = 0 + ) SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p); retcode = ReturnCode.Success, @@ -103,4 +165,121 @@ stats = Optimization.OptimizationStats( ) end +function solve_dae_mass_matrix(cache, dt, maxit, u0, p) + if cache.f.cons === nothing + return solve_ode(cache, dt, maxit, u0, p) + end + x=u0 + cons_vals = cache.f.cons(x, p) + n = length(u0) + m = length(cons_vals) + u0_extended = vcat(u0, zeros(m)) + M = Diagonal(ones(n + m)) + + + function f_mass!(du, u, p_, t) + x = @view u[1:n] + λ = @view u[n+1:end] + grad_f = similar(x) + if cache.f.grad !== nothing + cache.f.grad(grad_f, x, p_) + else + grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x) + end + J = Matrix{eltype(x)}(undef, m, n) + cache.f.cons_j !== nothing && cache.f.cons_j(J, x) + + @. du[1:n] = -grad_f - (J' * λ) + consv = cache.f.cons(x, p_) + @. du[n+1:end] = consv + return nothing + end + + if m == 0 + optf = ODEFunction(f_mass!) + prob = ODEProblem(optf, u0, (0.0, 1.0), p) + return solve(prob, cache.opt.solver; dt=dt, maxiters=maxit) + end + + ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p) + + solve_kwargs = setup_progress_callback(cache, Dict()) + if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end + if dt !== nothing; solve_kwargs[:dt] = dt; end + + sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...) + # if sol.retcode ≠ ReturnCode.Success + # # you may still accept Default or warn + # end + u_ext = sol.u + u_final = u_ext[1:n] + return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p); + retcode = sol.retcode) +end + + +function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars) + if cache.f.cons === nothing + return solve_ode(cache, dt, maxit, u0, p) + end + x=u0 + cons_vals = cache.f.cons(x, p) + n = length(u0) + m = length(cons_vals) + u0_ext = vcat(u0, zeros(m)) + du0_ext = zeros(n + m) + + if differential_vars === nothing + differential_vars = vcat(fill(true, n), fill(false, m)) + else + if length(differential_vars) == n + differential_vars = vcat(differential_vars, fill(false, m)) + elseif length(differential_vars) == n + m + # use as is + else + error("differential_vars length must be number of variables ($n) or extended size ($(n+m))") + end + end + + function dae_residual!(res, du, u, p_, t) + x = @view u[1:n] + λ = @view u[n+1:end] + du_x = @view du[1:n] + grad_f = similar(x) + cache.f.grad(grad_f, x, p_) + J = zeros(m, n) + cache.f.cons_j !== nothing && cache.f.cons_j(J, x) + + @. res[1:n] = du_x + grad_f + J' * λ + consv = cache.f.cons(x, p_) + @. res[n+1:end] = consv + return nothing + end + + if m == 0 + optf = ODEFunction(dae_residual!, differential_vars = differential_vars) + prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p) + return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit) + end + + tspan = (0.0, 10.0) + prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p; + differential_vars = differential_vars) + + solve_kwargs = setup_progress_callback(cache, Dict()) + if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end + if dt !== nothing; solve_kwargs[:dt] = dt; end + if hasfield(typeof(cache.opt.solver), :initializealg) + solve_kwargs[:initializealg] = BrownFullBasicInit() + end + + sol = solve(prob, cache.opt.solver; solve_kwargs...) + u_ext = sol.u + u_final = u_ext[end][1:n] + + return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p); + retcode = sol.retcode) end + + +end diff --git a/lib/OptimizationODE/test/runtests.jl b/lib/OptimizationODE/test/runtests.jl index f3d8a18c0..c0c3aa867 100644 --- a/lib/OptimizationODE/test/runtests.jl +++ b/lib/OptimizationODE/test/runtests.jl @@ -1,44 +1,278 @@ using Test -using OptimizationODE, SciMLBase, ADTypes +using OptimizationODE +using Optimization +using LinearAlgebra, ForwardDiff +using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials -@testset "OptimizationODE Tests" begin +# Test helper functions +function rosenbrock(x, p) + return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 +end + +function rosenbrock_grad!(grad, x, p) + grad[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * (x[2] - x[1]^2) * x[1] + grad[2] = 2.0 * p[2] * (x[2] - x[1]^2) +end + +function quadratic(x, p) + return (x[1] - p[1])^2 + (x[2] - p[2])^2 +end + +function quadratic_grad!(grad, x, p) + grad[1] = 2.0 * (x[1] - p[1]) + grad[2] = 2.0 * (x[2] - p[2]) +end - function f(x, p) - return sum(abs2, x) +# Constrained optimization problem +function constrained_objective(x, p) + return x[1]^2 + x[2]^2 +end + +function constrained_objective_grad!(grad, x, p) + grad[1] = 2.0 * x[1] + grad[2] = 2.0 * x[2] +end + +function constraint_func(res, x, p) + res[1] = x[1] + x[2] - 1.0 # x[1] + x[2] = 1 + return x[1] + x[2] - 1.0 +end + +function constraint_jac!(jac, x, p) + jac[1, 1] = 1.0 + jac[1, 2] = -1.0 +end + +@testset "OptimizationODE.jl Tests" begin + + + @testset "Basic Unconstrained Optimization" begin + @testset "Quadratic Function - ODE Optimizers" begin + x0 = [2.0, 2.0] + p = [1.0, 1.0] # Minimum at (1, 1) + + optf = OptimizationFunction(quadratic, grad=quadratic_grad!) + prob = OptimizationProblem(optf, x0, p) + + optimizers = [ + ("ODEGradientDescent", ODEGradientDescent()), + ("RKChebyshevDescent", RKChebyshevDescent()), + ("RKAccelerated", RKAccelerated()), + ("HighOrderDescent", HighOrderDescent()) + ] + + for (name, opt) in optimizers + @testset "$name" begin + sol = solve(prob, opt, dt=0.001, maxiters=1000000) + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u, p, atol=1e-1) + @test sol.objective < 1e-2 + end + end + end + + @testset "Rosenbrock Function - Selected Optimizers" begin + x0 = [1.5, 2.0] + p = [1.0, 100.0] # Classic Rosenbrock parameters + + optf = OptimizationFunction(rosenbrock, grad=rosenbrock_grad!) + prob = OptimizationProblem(optf, x0, p) + + # Test with more robust optimizers for Rosenbrock + optimizers = [ + ("RKAccelerated", RKAccelerated()), + ("HighOrderDescent", HighOrderDescent()) + ] + + for (name, opt) in optimizers + @testset "$name" begin + sol = solve(prob, opt, dt=0.001, maxiters=1000000) + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + # Rosenbrock is harder, so we use looser tolerances + @test isapprox(sol.u[1], 1.0, atol=1e-1) + @test isapprox(sol.u[2], 1.0, atol=1e-1) + @test sol.objective < 1.0 + end + end + end end + + @testset "Constrained Optimization - DAE Optimizers" begin + @testset "Equality Constrained Optimization" begin + # Minimize f(x) = x₁² + x₂² + # Subject to x₁ - x₂ = 1 - function g!(g, x, p) - @. g = 2 * x + function constrained_objective(x, p) + return x[1]^2 + x[2]^2 end - x0 = [2.0, -3.0] - p = [5.0] + function constrained_objective_grad!(g, x, p) + g .= 2 .* x .* p[1] + return nothing + end - f_autodiff = OptimizationFunction(f, ADTypes.AutoForwardDiff()) - prob_auto = OptimizationProblem(f_autodiff, x0, p) + # Constraint: x₁ - x₂ - p[1] = 0 (p[1] = 1 → x₁ - x₂ = 1) + function constraint_func(x, p) + return x[1] - x[2] - p[1] + end - for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent()) - sol = solve(prob_auto, opt; maxiters=50_000) - @test sol.u ≈ [0.0, 0.0] atol=1e-2 - @test sol.objective ≈ 0.0 atol=1e-2 - @test sol.retcode == ReturnCode.Success + function constraint_jac!(J, x) + J[1, 1] = 1.0 + J[1, 2] = -1.0 + return nothing end - f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad=g!) - prob_manual = OptimizationProblem(f_manual, x0) + x0 = [1.0, 0.0] # reasonable initial guess + p = [1.0] # enforce x₁ - x₂ = 1 - for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent()) - sol = solve(prob_manual, opt; maxiters=50_000) - @test sol.u ≈ [0.0, 0.0] atol=1e-2 - @test sol.objective ≈ 0.0 atol=1e-2 - @test sol.retcode == ReturnCode.Success - end + optf = OptimizationFunction(constrained_objective; + grad = constrained_objective_grad!, + cons = constraint_func, + cons_j = constraint_jac!) - f_fail = OptimizationFunction(f, SciMLBase.NoAD()) - prob_fail = OptimizationProblem(f_fail, x0) + @testset "Equality Constrained - Mass Matrix Method" begin + prob = OptimizationProblem(optf, x0, p) + opt = DAEMassMatrix() + sol = solve(prob, opt; dt=0.01, maxiters=1_000_000) - for opt in (ODEGradientDescent(dt=0.001), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent()) - @test_throws ErrorException solve(prob_fail, opt; maxiters=20_000) + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u[1] - sol.u[2], 1.0; atol = 1e-2) + @test isapprox(sol.u, [0.5, -0.5]; atol = 1e-2) end + @testset "Equality Constrained - Index Method" begin + prob = OptimizationProblem(optf, x0, p) + opt = DAEIndexing() + differential_vars = [true, true, false] # x vars = differential, λ = algebraic + sol = solve(prob, opt; dt=0.01, maxiters=1_000_000, + differential_vars = differential_vars) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u[1] - sol.u[2], 1.0; atol = 1e-2) + @test isapprox(sol.u, [0.5, -0.5]; atol = 1e-2) + end +end + end + + @testset "Parameter Handling" begin + @testset "NullParameters Handling" begin + x0 = [0.0, 0.0] + p=Float64[] # No parameters provided + # Create a problem with NullParameters + optf = OptimizationFunction((x, p) -> sum(x.^2), + grad=(grad, x, p) -> (grad .= 2.0 .* x)) + prob = OptimizationProblem(optf, x0,p) # No parameters provided + + opt = ODEGradientDescent() + sol = solve(prob, opt, dt=0.01, maxiters=100000) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u, [0.0, 0.0], atol=1e-2) + end + + @testset "Regular Parameters" begin + x0 = [0.5, 1.5] + p = [1.0, 1.0] + + optf = OptimizationFunction(quadratic, grad=quadratic_grad!) + prob = OptimizationProblem(optf, x0, p) + + opt = RKAccelerated() + sol = solve(prob, opt; dt=0.001, maxiters=1000000) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u, p, atol=1e-1) + end + end + + @testset "Solver Options and Keywords" begin + @testset "Custom dt and maxiters" begin + x0 = [0.0, 0.0] + p = [1.0, 1.0] + + optf = OptimizationFunction(quadratic, grad=quadratic_grad!) + prob = OptimizationProblem(optf, x0, p) + + opt = RKAccelerated() + + # Test with custom dt + sol1 = solve(prob, opt; dt=0.001, maxiters=100000) + @test sol1.retcode == ReturnCode.Success || sol1.retcode == ReturnCode.Default + + # Test with smaller dt (should be more accurate) + sol2 = solve(prob, opt; dt=0.001, maxiters=100000) + @test sol2.retcode == ReturnCode.Success || sol2.retcode == ReturnCode.Default + @test sol2.objective <= sol1.objective # Should be at least as good + end + end + + @testset "Callback Functionality" begin + @testset "Progress Callback" begin + x0 = [0.0, 0.0] + p = [1.0, 1.0] + + callback_called = Ref(false) + callback_values = Vector{Vector{Float64}}() + + function test_callback(x, p, t) + return false + end + + optf = OptimizationFunction(quadratic; grad=quadratic_grad!) + prob = OptimizationProblem(optf, x0, p) + + opt = RKAccelerated() + sol = solve(prob, opt, dt=0.1, maxiters=100000, callback=test_callback, progress=true) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + end + end + + @testset "Finite Difference Jacobian" begin + @testset "Jacobian Computation" begin + x = [1.0, 2.0] + f(x) = [x[1]^2 + x[2], x[1] * x[2]] + + J = ForwardDiff.jacobian(f, x) + + expected_J = [2.0 1.0; 2.0 1.0] + + @test isapprox(J, expected_J, atol=1e-6) + end + end + + @testset "Error Handling and Edge Cases" begin + @testset "Empty Constraints" begin + x0 = [1.5, 0.5] + p = Float64[] + + # Problem without constraints should fall back to ODE method + optf = OptimizationFunction(constrained_objective, + grad=constrained_objective_grad!) + prob = OptimizationProblem(optf, x0, p) + + opt = DAEMassMatrix() + sol = solve(prob, opt; dt=0.001, maxiters=50000) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u, [0.0, 0.0], atol=1e-1) + end + + @testset "Single Variable Optimization" begin + x0 = [0.5] + p = [1.0] + + single_var_func(x, p) = (x[1] - p[1])^2 + single_var_grad!(grad, x, p) = (grad[1] = 2.0 * (x[1] - p[1])) + + optf = OptimizationFunction(single_var_func; grad=single_var_grad!) + prob = OptimizationProblem(optf, x0, p) + + opt = RKAccelerated() + sol = solve(prob, opt; dt=0.001, maxiters=10000) + + @test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default + @test isapprox(sol.u[1], p[1], atol=1e-1) + end + end end