
Commit 6a31496

Merge pull request #932 from ParasPuneetSingh/daenew
DAE optimizers added to OptimizationODE
2 parents 59380b1 + e841c56

4 files changed: +490 −71 lines changed
Lines changed: 67 additions & 0 deletions (new file)
# OptimizationODE.jl

**OptimizationODE.jl** provides ODE-based optimization methods as a solver plugin for [SciML's Optimization.jl](https://github.com/SciML/Optimization.jl). It wraps various ODE solvers to perform gradient-based optimization using continuous-time dynamics.

## Installation

```julia
using Pkg
Pkg.add("OptimizationODE")
```

## Usage

```julia
using OptimizationODE, Optimization, ADTypes, SciMLBase

function f(x, p)
    return sum(abs2, x)
end

function g!(g, x, p)
    @. g = 2 * x
end

x0 = [2.0, -3.0]
p = []

f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad = g!)
prob_manual = OptimizationProblem(f_manual, x0)

opt = ODEGradientDescent()
sol = solve(prob_manual, opt; dt = 0.01, maxiters = 50_000)

@show sol.u
@show sol.objective
```

## Local Gradient-based Optimizers

All provided optimizers are **gradient-based local optimizers** that solve optimization problems by integrating gradient-based ODEs to convergence:

* `ODEGradientDescent()` — performs basic gradient descent using the explicit Euler method; pass the step size via the `dt` keyword of `solve`. This is a simple and efficient method suitable for small-scale or well-conditioned problems.

* `RKChebyshevDescent()` — uses the ROCK2 solver, a stabilized explicit Runge-Kutta method suitable for stiff problems. It allows larger step sizes while maintaining stability.

* `RKAccelerated()` — leverages the Tsit5 method, a 5th-order Runge-Kutta solver that achieves faster convergence for smooth problems by improving integration accuracy.

* `HighOrderDescent()` — applies Vern7, a high-order (7th-order) explicit Runge-Kutta method for even more accurate integration. This can be beneficial for problems requiring high precision.

You can also define a custom optimizer using the generic `ODEOptimizer(solver)` constructor by supplying any ODE solver supported by [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/).
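
As a sketch of how these optimizers interchange on one problem (assuming the API shown in the Usage section; `BS3` is just an arbitrary OrdinaryDiffEq solver picked for illustration):

```julia
using OptimizationODE, Optimization, OrdinaryDiffEq, SciMLBase

# Same quadratic test problem as in the Usage section.
f(x, p) = sum(abs2, x)
g!(g, x, p) = (@. g = 2 * x; nothing)

fun = OptimizationFunction(f, SciMLBase.NoAD(); grad = g!)
prob = OptimizationProblem(fun, [2.0, -3.0])

# Built-in accelerated method (adaptive Tsit5 under the hood).
sol_rk = solve(prob, RKAccelerated(); maxiters = 10_000)

# Custom optimizer wrapping any OrdinaryDiffEq solver.
sol_custom = solve(prob, ODEOptimizer(BS3()); maxiters = 10_000)
```

Because the adaptive methods choose their own step sizes, no `dt` is needed here; only the Euler-based `ODEGradientDescent()` requires one.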

## DAE-based Optimizers

!!! warning

    DAE-based optimizers are still experimental and a research project. Use with caution.

In addition to ODE-based optimizers, OptimizationODE.jl provides optimizers for differential-algebraic equation (DAE) constrained problems:

* `DAEMassMatrix()` — uses the Rodas5P solver (from OrdinaryDiffEq.jl) for DAE problems with a mass matrix formulation.

* `DAEOptimizer(IDA())` — uses the IDA solver (from Sundials.jl) for DAE problems with index variable support (requires `using Sundials`).

You can also define a custom optimizer using the generic `ODEOptimizer(solver)` or `DAEOptimizer(solver)` constructor by supplying any ODE or DAE solver supported by [OrdinaryDiffEq.jl](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/) or [Sundials.jl](https://github.com/SciML/Sundials.jl).
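
A minimal constrained sketch, assuming the `cons`/`lcons`/`ucons` workflow of Optimization.jl carries over to these experimental solvers (the exact keyword handling may differ):

```julia
using OptimizationODE, Optimization, SciMLBase

obj(x, p) = sum(abs2, x)
g!(g, x, p) = (@. g = 2 * x; nothing)
# One equality constraint: x₁ + x₂ - 1 = 0.
cons!(res, x, p) = (res[1] = x[1] + x[2] - 1.0; nothing)

fun = OptimizationFunction(obj, SciMLBase.NoAD(); grad = g!, cons = cons!)
prob = OptimizationProblem(fun, [0.0, 0.0]; lcons = [0.0], ucons = [0.0])

sol = solve(prob, DAEMassMatrix(); maxiters = 100_000)
# The analytic constrained minimizer of this problem is [0.5, 0.5].
```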

## Interface Details

All optimizers require gradient information, either via automatic differentiation or a manually provided `grad` function. The optimization is performed by integrating the ODE defined by the negative gradient until a steady state is reached.
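
For the automatic-differentiation route, a sketch using ForwardDiff through ADTypes (assuming the standard Optimization.jl AD interface, so no hand-written gradient is needed):

```julia
using OptimizationODE, Optimization, ADTypes, SciMLBase

obj(x, p) = sum(abs2, x)

# No grad function needed: ForwardDiff supplies the gradient.
fun = OptimizationFunction(obj, ADTypes.AutoForwardDiff())
prob = OptimizationProblem(fun, [2.0, -3.0])

sol = solve(prob, HighOrderDescent(); maxiters = 10_000)
```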

lib/OptimizationODE/Project.toml

Lines changed: 5 additions & 1 deletion

```diff
@@ -4,12 +4,15 @@ authors = ["Paras Puneet Singh <[email protected]>"]
 version = "0.1.1"

 [deps]
+DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"

 [compat]
 ForwardDiff = "0.10, 1"
@@ -22,7 +25,8 @@ julia = "1.10"

 [extras]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["ADTypes", "Test"]
+test = ["ADTypes", "Sundials", "Test"]
```

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 168 additions & 39 deletions

```diff
@@ -2,53 +2,87 @@ module OptimizationODE

 using Reexport
 @reexport using Optimization, SciMLBase
+using LinearAlgebra, ForwardDiff
+using DiffEqBase
+
+using NonlinearSolve
 using OrdinaryDiffEq, SteadyStateDiffEq

 export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
+export DAEOptimizer, DAEMassMatrix

-struct ODEOptimizer{T, T2}
+struct ODEOptimizer{T}
     solver::T
-    dt::T2
 end
-ODEOptimizer(solver; dt = nothing) = ODEOptimizer(solver, dt)

-# Solver Constructors (users call these)
-ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
+ODEGradientDescent() = ODEOptimizer(Euler())
 RKChebyshevDescent() = ODEOptimizer(ROCK2())
 RKAccelerated() = ODEOptimizer(Tsit5())
 HighOrderDescent() = ODEOptimizer(Vern7())

+struct DAEOptimizer{T}
+    solver::T
+end
+
+DAEMassMatrix() = DAEOptimizer(Rodas5P(autodiff = false))
+
+
 SciMLBase.requiresbounds(::ODEOptimizer) = false
 SciMLBase.allowsbounds(::ODEOptimizer) = false
 SciMLBase.allowscallback(::ODEOptimizer) = true
-@static if isdefined(SciMLBase, :supports_opt_cache_interface)
-    SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
-end
-@static if isdefined(OptimizationBase, :supports_opt_cache_interface)
-    OptimizationBase.supports_opt_cache_interface(::ODEOptimizer) = true
-end
+SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
 SciMLBase.requiresgradient(::ODEOptimizer) = true
 SciMLBase.requireshessian(::ODEOptimizer) = false
 SciMLBase.requiresconsjac(::ODEOptimizer) = false
 SciMLBase.requiresconshess(::ODEOptimizer) = false
+
+SciMLBase.requiresbounds(::DAEOptimizer) = false
+SciMLBase.allowsbounds(::DAEOptimizer) = false
+SciMLBase.allowsconstraints(::DAEOptimizer) = true
+SciMLBase.allowscallback(::DAEOptimizer) = true
+SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true
+SciMLBase.requiresgradient(::DAEOptimizer) = true
+SciMLBase.requireshessian(::DAEOptimizer) = false
+SciMLBase.requiresconsjac(::DAEOptimizer) = true
+SciMLBase.requiresconshess(::DAEOptimizer) = false
+
+
 function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
-        callback = Optimization.DEFAULT_CALLBACK, progress = false,
-        maxiters = nothing, kwargs...)
-    return OptimizationCache(prob, opt; callback = callback, progress = progress,
-        maxiters = maxiters, kwargs...)
+        callback = Optimization.DEFAULT_CALLBACK, progress = false, dt = nothing,
+        maxiters = nothing, kwargs...)
+    return OptimizationCache(prob, opt; callback = callback, progress = progress, dt = dt,
+        maxiters = maxiters, kwargs...)
+end
+
+function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
+        callback = Optimization.DEFAULT_CALLBACK, progress = false, dt = nothing,
+        maxiters = nothing, kwargs...)
+    return OptimizationCache(prob, opt; callback = callback, progress = progress, dt = dt,
+        maxiters = maxiters, kwargs...)
 end

 function SciMLBase.__solve(
-        cache::OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P,
-            C}
-) where {F, RC, LB, UB, LC, UC, S, O <: ODEOptimizer, D, P, C}
-    dt = cache.opt.dt
-    maxit = get(cache.solver_args, :maxiters, 1000)
+        cache::OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C}
+) where {F, RC, LB, UB, LC, UC, S, O <: Union{ODEOptimizer, DAEOptimizer}, D, P, C}

+    dt = get(cache.solver_args, :dt, nothing)
+    maxit = get(cache.solver_args, :maxiters, nothing)
     u0 = copy(cache.u0)
-    p = cache.p
+    p = cache.p  # properly handle NullParameters

+    if cache.opt isa ODEOptimizer
+        return solve_ode(cache, dt, maxit, u0, p)
+    else
+        if cache.opt.solver isa SciMLBase.AbstractDAEAlgorithm
+            return solve_dae_implicit(cache, dt, maxit, u0, p)
+        else
+            return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
+        end
+    end
+end
+
+function solve_ode(cache, dt, maxit, u0, p)
     if cache.f.grad === nothing
         error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
     end
@@ -63,39 +97,34 @@ function SciMLBase.__solve(

     algorithm = DynamicSS(cache.opt.solver)

-    cb = cache.callback
-    if cb != Optimization.DEFAULT_CALLBACK ||
-       get(cache.solver_args, :progress, false) === true
-        function condition(u, t, integrator)
-            true
-        end
-        function affect!(integrator)
-            u_now = integrator.u
-            state = Optimization.OptimizationState(u = u_now, p = integrator.p,
-                objective = cache.f(integrator.u, integrator.p))
-            Optimization.callback_function(cb, state)
+    if cache.callback !== Optimization.DEFAULT_CALLBACK
+        condition = (u, t, integrator) -> true
+        affect! = (integrator) -> begin
+            u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
+            l = cache.f(integrator.u, integrator.p)
+            cache.callback(integrator.u, l)
         end
-        cb_struct = DiscreteCallback(condition, affect!)
-        callback = CallbackSet(cb_struct)
+        cb = DiscreteCallback(condition, affect!)
+        solve_kwargs = Dict{Symbol, Any}(:callback => cb)
     else
-        callback = nothing
+        solve_kwargs = Dict{Symbol, Any}()
     end
-
-    solve_kwargs = Dict{Symbol, Any}(:callback => callback)
+
     if !isnothing(maxit)
         solve_kwargs[:maxiters] = maxit
     end
     if dt !== nothing
         solve_kwargs[:dt] = dt
     end

+    solve_kwargs[:progress] = cache.progress
+
     sol = solve(ss_prob, algorithm; solve_kwargs...)
     has_destats = hasproperty(sol, :destats)
     has_t = hasproperty(sol, :t) && !isempty(sol.t)

     stats = Optimization.OptimizationStats(
-        iterations = has_destats ? get(sol.destats, :iters, 10) :
-                     (has_t ? length(sol.t) - 1 : 10),
+        iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
         time = has_t ? sol.t[end] : 0.0,
         fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
         gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
@@ -108,4 +137,104 @@ function SciMLBase.__solve(
     )
 end

+function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
+    if cache.f.cons === nothing
+        error("DAEOptimizer requires constraints. Please provide a function with `cons` defined.")
+    end
+    n = length(u0)
+    m = length(cache.ucons)
+
+    if m > n
+        error("DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables.")
+    end
+    M = Diagonal([ones(n - m); zeros(m)])
+    function f_mass!(du, u, p_, t)
+        cache.f.grad(du, u, p)
+        @. du = -du
+        consout = @view du[(n - m) + 1:end]
+        cache.f.cons(consout, u)
+        return nothing
+    end
+
+    ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0, p)
+
+    if cache.callback !== Optimization.DEFAULT_CALLBACK
+        condition = (u, t, integrator) -> true
+        affect! = (integrator) -> begin
+            u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
+            l = cache.f(integrator.u, integrator.p)
+            cache.callback(integrator.u, l)
+        end
+        cb = DiscreteCallback(condition, affect!)
+        solve_kwargs = Dict{Symbol, Any}(:callback => cb)
+    else
+        solve_kwargs = Dict{Symbol, Any}()
+    end
+
+    solve_kwargs[:progress] = cache.progress
+    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
+    if dt !== nothing; solve_kwargs[:dt] = dt; end
+
+    sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...)
+    # if sol.retcode ≠ ReturnCode.Success
+    #     # you may still accept Default or warn
+    # end
+    u_ext = sol.u
+    u_final = u_ext[1:n]
+    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
+        retcode = sol.retcode)
+end
+
+function solve_dae_implicit(cache, dt, maxit, u0, p)
+    if cache.f.cons === nothing
+        error("DAEOptimizer requires constraints. Please provide a function with `cons` defined.")
+    end
+
+    n = length(u0)
+    m = length(cache.ucons)
+
+    if m > n
+        error("DAEOptimizer with mass matrix method requires the number of constraints to be less than or equal to the number of variables.")
+    end
+
+    function dae_residual!(res, du, u, p_, t)
+        cache.f.grad(res, u, p)
+        @. res = du - res
+        consout = @view res[(n - m) + 1:end]
+        cache.f.cons(consout, u)
+        return nothing
+    end
+
+    tspan = (0.0, 10.0)
+    du0 = zero(u0)
+    prob = DAEProblem(dae_residual!, du0, u0, tspan, p)
+
+    if cache.callback !== Optimization.DEFAULT_CALLBACK
+        condition = (u, t, integrator) -> true
+        affect! = (integrator) -> begin
+            u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
+            l = cache.f(integrator.u, integrator.p)
+            cache.callback(integrator.u, l)
+        end
+        cb = DiscreteCallback(condition, affect!)
+        solve_kwargs = Dict{Symbol, Any}(:callback => cb)
+    else
+        solve_kwargs = Dict{Symbol, Any}()
+    end
+
+    solve_kwargs[:progress] = cache.progress
+
+    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
+    if dt !== nothing; solve_kwargs[:dt] = dt; end
+    solve_kwargs[:initializealg] = DiffEqBase.ShampineCollocationInit()
+
+    sol = solve(prob, cache.opt.solver; solve_kwargs...)
+    u_ext = sol.u
+    u_final = u_ext[end][1:n]
+
+    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
+        retcode = sol.retcode)
 end
+
+
+end
```
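
The core idea in the diff above — integrating `u' = -∇f(u)` until a steady state is reached — can be sketched in plain Julia with no packages (illustrative only; the package itself delegates to `DynamicSS` with the chosen integrator):

```julia
# Gradient flow on f(x) = Σ xᵢ²: explicit Euler steps on u' = -∇f(u),
# which is the scheme ODEGradientDescent uses via Euler().
f(x) = sum(abs2, x)
gradf(x) = 2 .* x

function gradient_flow(x0; dt = 0.01, steps = 2000)
    x = copy(x0)
    for _ in 1:steps
        x .-= dt .* gradf(x)  # one Euler step along the negative gradient
    end
    return x
end

x = gradient_flow([2.0, -3.0])
# x contracts by a factor (1 - 2dt) per step, so it converges toward [0.0, 0.0]
```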
