Skip to content

Commit fbcfdaf

Browse files
fix: handle case when prob.u0 === nothing in linearization_function
1 parent 31073eb commit fbcfdaf

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed

src/linearization.jl

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,30 +89,43 @@ function linearization_function(sys::AbstractSystem, inputs,
8989
t0 = current_time(prob)
9090
inputvals = [p[idx] for idx in input_idxs]
9191

92-
uf_fun = let fun = prob.f
93-
function uff(du, u, p, t)
94-
SciMLBase.UJacobianWrapper(fun, t, p)(du, u)
95-
end
96-
end
97-
uf_jac = PreparedJacobian{true}(uf_fun, similar(prob.u0), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
98-
# observed function is a `GeneratedFunctionWrapper` with iip component
99-
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
100-
pf_fun = let fun = prob.f, setter = setp_oop(sys, input_idxs)
101-
function pff(du, input, u, p, t)
102-
p = setter(p, input)
103-
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
104-
end
105-
end
106-
pf_jac = PreparedJacobian{true}(pf_fun, similar(prob.u0), autodiff, inputvals, DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
10792
hp_fun = let fun = h, setter = setp_oop(sys, input_idxs)
10893
function hpf(du, input, u, p, t)
10994
p = setter(p, input)
11095
fun(du, u, p, t)
11196
return du
11297
end
11398
end
114-
hp_jac = PreparedJacobian{true}(hp_fun, similar(prob.u0, size(outputs)), autodiff, inputvals, DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
115-
99+
if u0 === nothing
100+
uf_jac = h_jac = pf_jac = nothing
101+
T = p isa MTKParameters ? eltype(p.tunable) : eltype(p)
102+
hp_jac = PreparedJacobian{true}(
103+
hp_fun, zeros(T, size(outputs)), autodiff, inputvals,
104+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
105+
else
106+
uf_fun = let fun = prob.f
107+
function uff(du, u, p, t)
108+
SciMLBase.UJacobianWrapper(fun, t, p)(du, u)
109+
end
110+
end
111+
uf_jac = PreparedJacobian{true}(
112+
uf_fun, similar(prob.u0), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
113+
# observed function is a `GeneratedFunctionWrapper` with iip component
114+
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff,
115+
prob.u0, DI.Constant(p), DI.Constant(t0))
116+
pf_fun = let fun = prob.f, setter = setp_oop(ssimilarys, input_idxs)
117+
function pff(du, input, u, p, t)
118+
p = setter(p, input)
119+
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
120+
end
121+
end
122+
pf_jac = PreparedJacobian{true}(pf_fun, similar(prob.u0), autodiff, inputvals,
123+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
124+
hp_jac = PreparedJacobian{true}(
125+
hp_fun, similar(prob.u0, size(outputs)), autodiff, inputvals,
126+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
127+
end
128+
116129
lin_fun = LinearizationFunction(
117130
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)),
118131
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
@@ -151,12 +164,14 @@ end
151164

152165
function PreparedJacobian{true}(f, buf, autodiff, args...)
153166
prep = DI.prepare_jacobian(f, buf, autodiff, args...)
154-
return PreparedJacobian{true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)}(prep, f, buf, autodiff)
167+
return PreparedJacobian{true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)}(
168+
prep, f, buf, autodiff)
155169
end
156170

157171
function PreparedJacobian{false}(f, autodiff, args...)
158172
prep = DI.prepare_jacobian(f, autodiff, args...)
159-
return PreparedJacobian{true, typeof(prep), typeof(f), Nothing, typeof(autodiff)}(prep, f, nothing)
173+
return PreparedJacobian{true, typeof(prep), typeof(f), Nothing, typeof(autodiff)}(
174+
prep, f, nothing)
160175
end
161176

162177
function (pj::PreparedJacobian{true})(args...)
@@ -279,14 +294,16 @@ function (linfun::LinearizationFunction)(u, p, t)
279294
end
280295
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
281296
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
282-
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs], DI.Constant(u), DI.Constant(p), DI.Constant(t))
297+
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs],
298+
DI.Constant(u), DI.Constant(p), DI.Constant(t))
283299
else
284300
linfun.num_states == 0 ||
285301
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
286302
fg_xz = zeros(0, 0)
287303
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
288304
end
289-
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs], DI.Constant(u), DI.Constant(p), DI.Constant(t))
305+
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs],
306+
DI.Constant(u), DI.Constant(p), DI.Constant(t))
290307
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
291308
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
292309
g_x = fg_xz[linfun.alge_idxs, linfun.diff_idxs],

0 commit comments

Comments
 (0)