Skip to content

Commit 980e11d

Browse files
start fixing
1 parent 760c847 commit 980e11d

File tree

3 files changed

+40
-57
lines changed

3 files changed

+40
-57
lines changed

lib/OptimizationODE/Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1212
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1313
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
1414
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
15-
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
16-
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
1715

1816
[compat]
1917
ForwardDiff = "0.10, 1"
@@ -26,7 +24,8 @@ julia = "1.10"
2624

2725
[extras]
2826
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
27+
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
2928
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3029

3130
[targets]
32-
test = ["ADTypes", "Test"]
31+
test = ["ADTypes", "Sundials", "Test"]

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ using Reexport
55
using LinearAlgebra, ForwardDiff
66

77
using NonlinearSolve
8-
using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
8+
using OrdinaryDiffEq, SteadyStateDiffEq, Sundials
99

1010
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
11-
export DAEOptimizer, DAEMassMatrix, DAEIndexing
11+
export DAEOptimizer, DAEMassMatrix
1212

1313
struct ODEOptimizer{T}
1414
solver::T
@@ -23,8 +23,7 @@ struct DAEOptimizer{T}
2323
solver::T
2424
end
2525

26-
DAEMassMatrix() = DAEOptimizer(Rosenbrock23(autodiff = false))
27-
DAEIndexing() = DAEOptimizer(IDA())
26+
DAEMassMatrix() = DAEOptimizer(Rodas5P(autodiff = false))
2827

2928

3029
SciMLBase.requiresbounds(::ODEOptimizer) = false
@@ -62,29 +61,6 @@ function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
6261
maxiters=maxiters, differential_vars=differential_vars, kwargs...)
6362
end
6463

65-
66-
function handle_parameters(p)
67-
if p isa SciMLBase.NullParameters
68-
return Float64[]
69-
else
70-
return p
71-
end
72-
end
73-
74-
function setup_progress_callback(cache, solve_kwargs)
75-
if get(cache.solver_args, :progress, false)
76-
condition = (u, t, integrator) -> true
77-
affect! = (integrator) -> begin
78-
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
79-
cache.solver_args[:callback](u_opt, integrator.p, integrator.t)
80-
end
81-
cb = DiscreteCallback(condition, affect!)
82-
solve_kwargs[:callback] = cb
83-
end
84-
return solve_kwargs
85-
end
86-
87-
8864
function SciMLBase.__solve(
8965
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
9066
) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C}
@@ -93,15 +69,15 @@ function SciMLBase.__solve(
9369
maxit = get(cache.solver_args, :maxiters, nothing)
9470
differential_vars = get(cache.solver_args, :differential_vars, nothing)
9571
u0 = copy(cache.u0)
96-
p = handle_parameters(cache.p) # Properly handle NullParameters
72+
p = cache.p # Properly handle NullParameters
9773

9874
if cache.opt isa ODEOptimizer
9975
return solve_ode(cache, dt, maxit, u0, p)
10076
else
101-
if cache.opt.solver == Rosenbrock23(autodiff = false)
102-
return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
77+
if cache.opt.solver isa SciMLBase.AbstractDAEAlgorithm
78+
return solve_dae_implicit(cache, dt, maxit, u0, p, differential_vars)
10379
else
104-
return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
80+
return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
10581
end
10682
end
10783
end
@@ -112,41 +88,37 @@ function solve_ode(cache, dt, maxit, u0, p)
11288
end
11389

11490
function f!(du, u, p, t)
115-
grad_vec = similar(u)
116-
if isempty(p)
117-
cache.f.grad(grad_vec, u)
118-
else
119-
cache.f.grad(grad_vec, u, p)
120-
end
121-
@. du = -grad_vec
91+
cache.f.grad(du, u, p)
92+
@. du = -du
12293
return nothing
12394
end
12495

12596
ss_prob = SteadyStateProblem(f!, u0, p)
12697

12798
algorithm = DynamicSS(cache.opt.solver)
12899

129-
cb = cache.callback
130-
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false)
131-
function condition(u, t, integrator) true end
132-
function affect!(integrator)
133-
u_now = integrator.u
134-
cache.callback(u_now, integrator.p, integrator.t)
100+
if cache.callback !== Optimization.DEFAULT_CALLBACK
101+
condition = (u, t, integrator) -> true
102+
affect! = (integrator) -> begin
103+
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
104+
l = cache.f(integrator.u, integrator.p)
105+
cache.callback(integrator.u, l)
135106
end
136-
cb_struct = DiscreteCallback(condition, affect!)
137-
callback = CallbackSet(cb_struct)
107+
cb = DiscreteCallback(condition, affect!)
108+
solve_kwargs = Dict{Symbol, Any}(:callback => cb)
138109
else
139-
callback = nothing
110+
solve_kwargs = Dict{Symbol, Any}()
140111
end
141-
142-
solve_kwargs = Dict{Symbol, Any}(:callback => callback)
112+
143113
if !isnothing(maxit)
144114
solve_kwargs[:maxiters] = maxit
145115
end
146116
if dt !== nothing
147117
solve_kwargs[:dt] = dt
148118
end
149119

120+
solve_kwargs[:progress] = cache.progress
121+
150122
sol = solve(ss_prob, algorithm; solve_kwargs...)
151123
has_destats = hasproperty(sol, :destats)
152124
has_t = hasproperty(sol, :t) && !isempty(sol.t)
@@ -218,7 +190,7 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
218190
end
219191

220192

221-
function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
193+
function solve_dae_implicit(cache, dt, maxit, u0, p, differential_vars)
222194
if cache.f.cons === nothing
223195
return solve_ode(cache, dt, maxit, u0, p)
224196
end

lib/OptimizationODE/test/runtests.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ end
142142

143143
@testset "Equality Constrained - Index Method" begin
144144
prob = OptimizationProblem(optf, x0, p)
145-
opt = DAEIndexing()
145+
opt = DAEOptimizer(IDA())
146146
differential_vars = [true, true, false] # x vars = differential, λ = algebraic
147147
sol = solve(prob, opt; dt=0.01, maxiters=1_000_000,
148148
differential_vars = differential_vars)
@@ -207,22 +207,34 @@ end
207207
end
208208

209209
@testset "Callback Functionality" begin
210-
@testset "Progress Callback" begin
210+
@testset "Progress Callback" begin
211211
x0 = [0.0, 0.0]
212212
p = [1.0, 1.0]
213213

214214
callback_called = Ref(false)
215215
callback_values = Vector{Vector{Float64}}()
216216

217-
function test_callback(x, p, t)
217+
function test_callback(state, l)
218218
return false
219219
end
220220

221221
optf = OptimizationFunction(quadratic; grad=quadratic_grad!)
222222
prob = OptimizationProblem(optf, x0, p)
223223

224224
opt = RKAccelerated()
225-
sol = solve(prob, opt, dt=0.1, maxiters=100000, callback=test_callback, progress=true)
225+
sol = solve(prob, opt, dt=0.1, maxiters=100000, callback=test_callback)
226+
227+
@test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default
228+
end
229+
@testset "Progress Bar" begin
230+
x0 = [0.0, 0.0]
231+
p = [1.0, 1.0]
232+
233+
optf = OptimizationFunction(quadratic; grad=quadratic_grad!)
234+
prob = OptimizationProblem(optf, x0, p)
235+
236+
opt = RKAccelerated()
237+
sol = solve(prob, opt, dt=0.1, maxiters=100000, progress=true)
226238

227239
@test sol.retcode == ReturnCode.Success || sol.retcode == ReturnCode.Default
228240
end

0 commit comments

Comments
 (0)