@@ -5,10 +5,10 @@ using Reexport
5
5
using LinearAlgebra, ForwardDiff
6
6
7
7
using NonlinearSolve
8
- using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
8
+ using OrdinaryDiffEq, SteadyStateDiffEq, Sundials
9
9
10
10
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
11
- export DAEOptimizer, DAEMassMatrix, DAEIndexing
11
+ export DAEOptimizer, DAEMassMatrix
12
12
13
13
struct ODEOptimizer{T}
14
14
solver:: T
@@ -23,8 +23,7 @@ struct DAEOptimizer{T}
23
23
solver:: T
24
24
end
25
25
26
- DAEMassMatrix () = DAEOptimizer (Rosenbrock23 (autodiff = false ))
27
- DAEIndexing () = DAEOptimizer (IDA ())
26
+ DAEMassMatrix () = DAEOptimizer (Rodas5P (autodiff = false ))
28
27
29
28
30
29
SciMLBase. requiresbounds (:: ODEOptimizer ) = false
@@ -62,29 +61,6 @@ function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
62
61
maxiters= maxiters, differential_vars= differential_vars, kwargs... )
63
62
end
64
63
65
-
66
- function handle_parameters (p)
67
- if p isa SciMLBase. NullParameters
68
- return Float64[]
69
- else
70
- return p
71
- end
72
- end
73
-
74
- function setup_progress_callback (cache, solve_kwargs)
75
- if get (cache. solver_args, :progress , false )
76
- condition = (u, t, integrator) -> true
77
- affect! = (integrator) -> begin
78
- u_opt = integrator. u isa AbstractArray ? integrator. u : integrator. u. u
79
- cache. solver_args[:callback ](u_opt, integrator. p, integrator. t)
80
- end
81
- cb = DiscreteCallback (condition, affect!)
82
- solve_kwargs[:callback ] = cb
83
- end
84
- return solve_kwargs
85
- end
86
-
87
-
88
64
function SciMLBase. __solve (
89
65
cache:: OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
90
66
) where {F,RC,LB,UB,LC,UC,S,O<: Union{ODEOptimizer,DAEOptimizer} ,D,P,C}
@@ -93,15 +69,15 @@ function SciMLBase.__solve(
93
69
maxit = get (cache. solver_args, :maxiters , nothing )
94
70
differential_vars = get (cache. solver_args, :differential_vars , nothing )
95
71
u0 = copy (cache. u0)
96
- p = handle_parameters ( cache. p) # Properly handle NullParameters
72
+ p = cache. p # Properly handle NullParameters
97
73
98
74
if cache. opt isa ODEOptimizer
99
75
return solve_ode (cache, dt, maxit, u0, p)
100
76
else
101
- if cache. opt. solver == Rosenbrock23 (autodiff = false )
102
- return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
77
+ if cache. opt. solver isa SciMLBase . AbstractDAEAlgorithm
78
+ return solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars )
103
79
else
104
- return solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars )
80
+ return solve_dae_mass_matrix (cache, dt, maxit, u0, p)
105
81
end
106
82
end
107
83
end
@@ -112,41 +88,37 @@ function solve_ode(cache, dt, maxit, u0, p)
112
88
end
113
89
114
90
function f! (du, u, p, t)
115
- grad_vec = similar (u)
116
- if isempty (p)
117
- cache. f. grad (grad_vec, u)
118
- else
119
- cache. f. grad (grad_vec, u, p)
120
- end
121
- @. du = - grad_vec
91
+ cache. f. grad (du, u, p)
92
+ @. du = - du
122
93
return nothing
123
94
end
124
95
125
96
ss_prob = SteadyStateProblem (f!, u0, p)
126
97
127
98
algorithm = DynamicSS (cache. opt. solver)
128
99
129
- cb = cache. callback
130
- if cb != Optimization . DEFAULT_CALLBACK || get (cache . solver_args, :progress , false )
131
- function condition (u, t, integrator) true end
132
- function affect! ( integrator)
133
- u_now = integrator. u
134
- cache. callback (u_now, integrator. p, integrator . t )
100
+ if cache. callback != = Optimization . DEFAULT_CALLBACK
101
+ condition = (u, t, integrator) -> true
102
+ affect! = ( integrator) -> begin
103
+ u_opt = integrator. u isa AbstractArray ? integrator . u : integrator . u . u
104
+ l = cache . f ( integrator. u, integrator . p)
105
+ cache. callback (integrator. u, l )
135
106
end
136
- cb_struct = DiscreteCallback (condition, affect!)
137
- callback = CallbackSet (cb_struct )
107
+ cb = DiscreteCallback (condition, affect!)
108
+ solve_kwargs = Dict {Symbol, Any} ( :callback => cb )
138
109
else
139
- callback = nothing
110
+ solve_kwargs = Dict {Symbol, Any} ()
140
111
end
141
-
142
- solve_kwargs = Dict {Symbol, Any} (:callback => callback)
112
+
143
113
if ! isnothing (maxit)
144
114
solve_kwargs[:maxiters ] = maxit
145
115
end
146
116
if dt != = nothing
147
117
solve_kwargs[:dt ] = dt
148
118
end
149
119
120
+ solve_kwargs[:progress ] = cache. progress
121
+
150
122
sol = solve (ss_prob, algorithm; solve_kwargs... )
151
123
has_destats = hasproperty (sol, :destats )
152
124
has_t = hasproperty (sol, :t ) && ! isempty (sol. t)
@@ -218,7 +190,7 @@ function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
218
190
end
219
191
220
192
221
- function solve_dae_indexing (cache, dt, maxit, u0, p, differential_vars)
193
+ function solve_dae_implicit (cache, dt, maxit, u0, p, differential_vars)
222
194
if cache. f. cons === nothing
223
195
return solve_ode (cache, dt, maxit, u0, p)
224
196
end
0 commit comments