Skip to content

Commit 0a9134e

Browse files
Merge pull request #3569 from AayushSabharwal/as/nosplit-array-initials
fix: scalarize `Initial` parameters for `split = false` systems
2 parents e574fb3 + 3dea97c commit 0a9134e

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

src/systems/abstractsystem.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ end
696696

697697
supports_initialization(sys::AbstractSystem) = true
698698

699-
function add_initialization_parameters(sys::AbstractSystem)
699+
function add_initialization_parameters(sys::AbstractSystem; split = true)
700700
@assert !has_systems(sys) || isempty(get_systems(sys))
701701
supports_initialization(sys) || return sys
702702
is_initializesystem(sys) && return sys
@@ -711,7 +711,7 @@ function add_initialization_parameters(sys::AbstractSystem)
711711
obs, eqs = unhack_observed(observed(sys), eqs)
712712
for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs)))
713713
x = unwrap(x)
714-
if iscall(x) && operation(x) == getindex
714+
if iscall(x) && operation(x) == getindex && split
715715
push!(all_initialvars, arguments(x)[1])
716716
else
717717
push!(all_initialvars, x)
@@ -788,7 +788,7 @@ function complete(
788788
end
789789
sys = newsys
790790
if add_initial_parameters
791-
sys = add_initialization_parameters(sys)
791+
sys = add_initialization_parameters(sys; split)
792792
end
793793
end
794794
if split && has_index_cache(sys)
@@ -1465,7 +1465,11 @@ function parameters(sys::AbstractSystem; initial_parameters = false)
14651465
result = unique(isempty(systems) ? ps :
14661466
[ps; reduce(vcat, namespace_parameters.(systems))])
14671467
if !initial_parameters && !is_initializesystem(sys)
1468-
filter!(x -> !iscall(x) || !isa(operation(x), Initial), result)
1468+
filter!(result) do sym
1469+
return !(isoperator(sym, Initial) ||
1470+
iscall(sym) && operation(sym) == getindex &&
1471+
isoperator(arguments(sym)[1], Initial))
1472+
end
14691473
end
14701474
return result
14711475
end

test/initial_values.jl

+16
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,19 @@ end
265265
@test eltype(oprob.u0) == Float32
266266
@test eltype(eltype(sol.u)) == Float32
267267
end
268+
269+
@testset "Array initials and scalar parameters with `split = false`" begin
270+
@variables x(t)[1:2]
271+
@parameters p
272+
@mtkbuild sys=ODESystem([D(x[1]) ~ x[1], D(x[2]) ~ x[2] + p], t) split=false
273+
ps = Set(parameters(sys; initial_parameters = true))
274+
@test length(ps) == 5
275+
for i in 1:2
276+
@test Initial(x[i]) in ps
277+
@test Initial(D(x[i])) in ps
278+
end
279+
@test p in ps
280+
prob = ODEProblem(sys, [x => ones(2)], (0.0, 1.0), [p => 1.0])
281+
@test prob.p isa Vector{Float64}
282+
@test length(prob.p) == 5
283+
end

test/odesystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,7 @@ end
13051305
ps = zeros(8)
13061306
setp(sys2, x)(ps, 2ones(2))
13071307
setp(sys2, p)(ps, 2ones(2, 2))
1308-
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
1308+
@test_nowarn fn2(ones(4), 2ones(14), 4.0)
13091309
end
13101310

13111311
# https://github.com/SciML/ModelingToolkit.jl/issues/2969
@@ -1416,7 +1416,7 @@ end
14161416
obsfn = ModelingToolkit.build_explicit_observed_function(
14171417
sys1, u + x + p[1:2]; inputs = [x...])
14181418

1419-
@test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2)
1419+
@test obsfn(ones(2), 2ones(2), 3ones(12), 4.0) == 6ones(2)
14201420
end
14211421

14221422
@testset "Passing `nothing` to `u0`" begin

0 commit comments

Comments
 (0)