Skip to content

Commit b07f208

Browse files
committed
fix: fix collect_var
1 parent 4e32749 commit b07f208

File tree

3 files changed

+36
-29
lines changed

3 files changed

+36
-29
lines changed

src/systems/callbacks.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,9 @@ function make_affect(affect::Vector{Equation}; discrete_parameters::AbstractVect
301301
# get accessed parameters p from Pre(p) in the callback parameters
302302
accessed_params = filter(isparameter, map(unPre, collect(pre_params)))
303303
union!(accessed_params, sys_params)
304-
# add unknowns to the map
304+
305+
# add scalarized unknowns to the map.
306+
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
305307
for u in _dvs
306308
aff_map[u] = u
307309
end
@@ -616,7 +618,8 @@ function compile_condition(
616618
end
617619

618620
if !is_discrete(cbs)
619-
condit = [cond.lhs - cond.rhs for cond in condit]
621+
condit = reduce(vcat, flatten_equations(condit))
622+
condit = condit isa AbstractVector ? [c.lhs - c.rhs for c in condit] : [condit.lhs - condit.rhs]
620623
end
621624

622625
fs = build_function_wrapper(sys,

src/utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ end
595595

596596
function collect_var!(unknowns, parameters, var, iv; depth = 0)
597597
isequal(var, iv) && return nothing
598+
var = unwrap(var)
598599
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
599600
if iscalledparameter(var)
600601
callable = getcalledparameter(var)

test/odesystem.jl

+30-27
Original file line numberDiff line numberDiff line change
@@ -1032,24 +1032,26 @@ prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0))
10321032
@test_nowarn solve(prob, Tsit5())
10331033

10341034
# Issue#2383
1035-
@variables x(t)[1:3]
1036-
@parameters p[1:3, 1:3]
1037-
eqs = [
1038-
D(x) ~ p * x
1039-
]
1040-
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[norm(x) ~ 3.0] => [x ~ ones(3)]])
1041-
# array affect equations used to not work
1042-
prob1 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1043-
sol1 = @test_nowarn solve(prob1, Tsit5())
1044-
1045-
# array condition equations also used to not work
1046-
@mtkbuild sys = ODESystem(
1047-
eqs, t; continuous_events = [[x ~ sqrt(3) * ones(3)] => [x ~ ones(3)]])
1048-
# array affect equations used to not work
1049-
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1050-
sol2 = @test_nowarn solve(prob2, Tsit5())
1051-
1052-
@test sol1 sol2
1035+
@testset "Arrays in affect/condition equations" begin
1036+
@variables x(t)[1:3]
1037+
@parameters p[1:3, 1:3]
1038+
eqs = [
1039+
D(x) ~ p * x
1040+
]
1041+
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[norm(x) ~ 3.0] => [x ~ ones(3)]])
1042+
# array affect equations used to not work
1043+
prob1 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1044+
sol1 = @test_nowarn solve(prob1, Tsit5())
1045+
1046+
# array condition equations also used to not work
1047+
@mtkbuild sys = ODESystem(
1048+
eqs, t; continuous_events = [[x ~ sqrt(3) * ones(3)] => [x ~ ones(3)]])
1049+
# array affect equations used to not work
1050+
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1051+
sol2 = @test_nowarn solve(prob2, Tsit5())
1052+
1053+
@test sol1.u sol2.u[2:end]
1054+
end
10531055

10541056
# Requires fix in symbolics for `linear_expansion(p * x, D(y))`
10551057
@test_skip begin
@@ -1196,10 +1198,12 @@ end
11961198
end
11971199

11981200
# Namespacing of array variables
1199-
@variables x(t)[1:2]
1200-
@named sys = ODESystem(Equation[], t)
1201-
@test getname(unknowns(sys, x)) == :sys₊x
1202-
@test size(unknowns(sys, x)) == size(x)
1201+
@testset "Namespacing of array variables" begin
1202+
@variables x(t)[1:2]
1203+
@named sys = ODESystem(Equation[], t)
1204+
@test getname(unknowns(sys, x)) == :sys₊x
1205+
@test size(unknowns(sys, x)) == size(x)
1206+
end
12031207

12041208
# Issue#2667 and Issue#2953
12051209
@testset "ForwardDiff through ODEProblem constructor" begin
@@ -1537,8 +1541,7 @@ end
15371541
@testset "Observed variables dependent on discrete parameters" begin
15381542
@variables x(t) obs(t)
15391543
@parameters c(t)
1540-
@mtkbuild sys = ODESystem(
1541-
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
1544+
@mtkbuild sys = ODESystem([D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [SymbolicDiscreteCallback(1.0 => [c ~ Pre(c) + 1], discrete_parameters = [c])])
15421545
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
15431546
sol = solve(prob, Tsit5())
15441547
@test sol[obs] 1:7
@@ -1598,15 +1601,15 @@ end
15981601
# Test `isequal`
15991602
@testset "`isequal`" begin
16001603
@variables X(t)
1601-
@parameters p d
1604+
@parameters p d(t)
16021605
eq = D(X) ~ p - d * X
16031606

16041607
osys1 = complete(ODESystem([eq], t; name = :osys))
16051608
osys2 = complete(ODESystem([eq], t; name = :osys))
16061609
@test osys1 == osys2 # true
16071610

1608-
continuous_events = [[X ~ 1.0] => [X ~ X + 5.0]]
1609-
discrete_events = [5.0 => [d ~ d / 2.0]]
1611+
continuous_events = [[X ~ 1.0] => [X ~ Pre(X) + 5.0]]
1612+
discrete_events = [SymbolicDiscreteCallback(5.0 => [d ~ d / 2.0], discrete_parameters = [d])]
16101613

16111614
osys1 = complete(ODESystem([eq], t; name = :osys, continuous_events))
16121615
osys2 = complete(ODESystem([eq], t; name = :osys))

0 commit comments

Comments
 (0)