From 9102a79e40845c698cd41b8fd1bf27ad56063533 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Jul 2025 00:42:12 +0530 Subject: [PATCH 1/5] refactor: use SII for ImperativeAffect --- src/systems/imperative_affect.jl | 230 ++++++++++++------------------- 1 file changed, 88 insertions(+), 142 deletions(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 7b1a9fb286..9c3971c08a 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated. """ struct ImperativeAffect f::Any - obs::Vector - obs_syms::Vector{Symbol} - modified::Vector - mod_syms::Vector{Symbol} + observed::NamedTuple + modified::NamedTuple ctx::Any skip_checks::Bool end @@ -43,10 +41,7 @@ function ImperativeAffect(f; modified::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) - ImperativeAffect(f, - collect(values(observed)), collect(keys(observed)), - collect(values(modified)), collect(keys(modified)), - ctx, skip_checks) + ImperativeAffect(f, observed, modified, ctx, skip_checks) end function ImperativeAffect(f, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) @@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...) end function Base.show(io::IO, mfa::ImperativeAffect) - obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") - mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") + obs = mfa.observed + mod = mfa.modified affect = mfa.f print(io, - "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") + "ImperativeAffect(observed: [$(obs)], modified: [$(mod)], affect:$affect)") end func(f::ImperativeAffect) = f.f context(a::ImperativeAffect) = a.ctx -observed(a::ImperativeAffect) = a.obs -observed_syms(a::ImperativeAffect) = a.obs_syms function discretes(a::ImperativeAffect) Iterators.filter(ModelingToolkit.isparameter, Iterators.flatten(Iterators.map( x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x], a.modified))) end -modified(a::ImperativeAffect) = a.modified -modified_syms(a::ImperativeAffect) = a.mod_syms function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect) - isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && - isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) && + isequal(a1.f, a2.f) && isequal(a1.observed, a2.observed) && + isequal(a1.modified, a2.modified) && isequal(a1.ctx, a2.ctx) end function Base.hash(a::ImperativeAffect, s::UInt) s = hash(a.f, s) - s = hash(a.obs, s) - s = hash(a.obs_syms, s) + s = hash(a.observed, s) s = hash(a.modified, s) - s = hash(a.mod_syms, s) hash(a.ctx, s) end namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s) -function namespace_affect(affect::ImperativeAffect, s) - rmn = [] - for modded in modified(affect) - if symbolic_type(modded) == NotSymbolic() && modded isa AbstractArray - res = [] - for m in modded - push!(res, renamespace(s, m)) - end - push!(rmn, res) + +function _namespace_nt(nt::NamedTuple, s::AbstractSystem) + return NamedTuple{keys(nt)}(_namespace_nt(values(nt), s)) +end + +function _namespace_nt(nt::Union{AbstractArray, Tuple}, s::AbstractSystem) + return map(nt) do v + if symbolic_type(v) == NotSymbolic() + _namespace_nt(v, s) else - push!(rmn, renamespace(s, modded)) + renamespace(s, v) end end - ImperativeAffect(func(affect), - namespace_expr.(observed(affect), (s,)), - observed_syms(affect), - rmn, - modified_syms(affect), - context(affect), - affect.skip_checks) +end + +function namespace_affect(affect::ImperativeAffect, s) + obs = _namespace_nt(affect.observed, s) + mod = _namespace_nt(affect.modified, s) + ImperativeAffect(affect.f, obs, mod, affect.ctx, affect.skip_checks) end function invalid_variables(sys, expr) @@ -139,21 +127,6 @@ function unassignable_variables(sys, expr) x -> !any(isequal(x), assignable_syms), written) end -@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple}, - values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2} - setter_exprs = [] - for name in NS2 - if !(name in NS1) - missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to." - error(missing_name) - end - push!(setter_exprs, :(setters.$name(integ, values.$name))) - end - return :(begin - $(setter_exprs...) - end) -end - function check_assignable(sys, sym) if symbolic_type(sym) == ScalarSymbolic() is_variable(sys, sym) || is_parameter(sys, sym) @@ -167,6 +140,42 @@ function check_assignable(sys, sym) end end +function _nt_check_valid(nt::NamedTuple, s::AbstractSystem, isobserved::Bool) + _nt_check_valid(values(nt), s, isobserved) +end + +function _nt_check_valid( + nt::Union{Tuple, AbstractArray}, s::AbstractSystem, isobserved::Bool) + for v in nt + if symbolic_type(v) == NotSymbolic() + _nt_check_valid(v, s, isobserved) + continue + end + if !isobserved && !check_assignable(s, v) + error(""" + Expression $v cannot be assigned to; currently only unknowns and parameters may \ + be updated by an affect. + """) + end + invalid = invalid_variables(s, v) + isempty(invalid) && continue + name = isobserved ? "Observed" : "Modified" + error(""" + $name expression $(v) in affect refers to missing variable(s) $(invalid); \ + the variables may not have been added (e.g. if a component is missing). + """) + end +end + +function _nt_check_overlap(nta::NamedTuple, ntb::NamedTuple) + common = intersect(keys(nta), keys(ntb)) + isempty(common) && return + @warn """ + The symbols $common are declared as both observed and modified; this is a code smell \ + because it becomes easy to confuse them and assign/not assign a value. + """ +end + function compile_functional_affect( affect::ImperativeAffect, sys; reset_jumps = false, kwargs...) #= @@ -176,93 +185,27 @@ function compile_functional_affect( call the affect method unpack and apply the resulting values =# - function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) - seen = Set{Symbol}() - syms_dedup = [] - exprs_dedup = [] - for (sym, exp) in Iterators.zip(syms, exprs) - if !in(sym, seen) - push!(syms_dedup, sym) - push!(exprs_dedup, exp) - push!(seen, sym) - elseif !affect.skip_checks - @warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used." - end - end - return (syms_dedup, exprs_dedup) - end - dvs = unknowns(sys) - ps = parameters(sys) - - obs_exprs = observed(affect) - if !affect.skip_checks - for oexpr in obs_exprs - invalid_vars = invalid_variables(sys, oexpr) - if length(invalid_vars) > 0 - error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") - end - end - end - obs_syms = observed_syms(affect) - obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) - - mod_exprs = modified(affect) if !affect.skip_checks - for mexpr in mod_exprs - if !check_assignable(sys, mexpr) - @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") - end - invalid_vars = unassignable_variables(sys, mexpr) - if length(invalid_vars) > 0 - error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") - end - end - end - mod_syms = modified_syms(affect) - mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) - - overlapping_syms = intersect(mod_syms, obs_syms) - if length(overlapping_syms) > 0 && !affect.skip_checks - @warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." + _nt_check_valid(affect.observed, sys, true) + _nt_check_valid(affect.modified, sys, false) + _nt_check_overlap(affect.observed, affect.modified) end # sanity checks done! now build the data and update function for observed values - mkzero(sz) = - if sz === () - 0.0 - else - zeros(sz) - end - obs_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(obs_exprs); - mkarray = (es, _) -> MakeTuple(es)) - obs_sym_tuple = (obs_syms...,) - - # okay so now to generate the stuff to assign it back into the system - mod_pairs = mod_exprs .=> mod_syms - mod_names = (mod_syms...,) - mod_og_val_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(first.(mod_pairs)); - mkarray = (es, _) -> MakeTuple(es)) + let user_affect = func(affect), ctx = context(affect), + obs_getter = isempty(affect.observed) ? Returns((;)) : getsym(sys, affect.observed), + mod_getter = isempty(affect.modified) ? Returns((;)) : getsym(sys, affect.modified), + mod_setter = isempty(affect.modified) ? Returns((;)) : setsym(sys, affect.modified), + reset_jumps = reset_jumps - upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) - - let user_affect = func(affect), ctx = context(affect), reset_jumps = reset_jumps @inline function (integ) - # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - modvals = mod_og_val_fun(integ.u, integ.p, integ.t) - upd_component_array = NamedTuple{mod_names}(modvals) - - # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( - integ.u, integ.p, integ.t)) + mod = mod_getter(integ) + obs = obs_getter(integ) # let the user do their thing - upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) - - # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + upd_vals = user_affect(mod, obs, ctx, integ) + mod_setter(integ, upd_vals) reset_jumps && reset_aggregated_jumps!(integ) end @@ -271,19 +214,22 @@ end scalarize_affects(affects::ImperativeAffect) = affects -function vars!(vars, aff::ImperativeAffect; op = Differential) - for var in Iterators.flatten((observed(aff), modified(aff))) - if symbolic_type(var) == NotSymbolic() - if var isa AbstractArray - for v in var - v = unwrap(v) - vars!(vars, v) - end - end - else - var = unwrap(var) - vars!(vars, var) +function _vars_nt!(vars, nt::NamedTuple, op) + _vars_nt!(vars, values(nt), op) +end + +function _vars_nt!(vars, nt::Union{AbstractArray, Tuple}, op) + for v in nt + if symbolic_type(v) == NotSymbolic() + _vars_nt!(vars, v, op) + continue end + vars!(vars, v; op) end +end + +function vars!(vars, aff::ImperativeAffect; op = Differential) + _vars_nt!(vars, aff.observed, op) + _vars_nt!(vars, aff.modified, op) return vars end From 7b47da1de32a68e5613f16abd2b30dcc5a860ab4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Jul 2025 00:54:14 +0530 Subject: [PATCH 2/5] feat: support `Tuple` and `NamedTuple` in `SII.get_all_timeseries_indexes` --- src/systems/abstractsystem.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 8193b795a2..3834e544d9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -351,7 +351,10 @@ function _all_ts_idxs!(ts_idxs, ::ScalarSymbolic, sys, sym::Symbol) push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) end end -function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::NamedTuple) + _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, values(sym)) +end +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::Union{AbstractArray, Tuple}) for s in sym _all_ts_idxs!(ts_idxs, sys, s) end From ad0e4c480b42ac1a62df23218ce7dce1366e5dab Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Jul 2025 00:54:35 +0530 Subject: [PATCH 3/5] test: test nested namedtuples in ImperativeAffect --- test/symbolic_events.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 886813a37f..c3aafa4363 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1410,3 +1410,21 @@ end sol = solve(prob, FBDF()) @test SciMLBase.successful_retcode(sol) end + +@testset "Nested NamedTuple ImperativeAffect" begin + @variables x(t) y(t) z(t) w(t) + obs = (; a = (; b = (x, y))) + mod = (; p = (; q = y, r = z)) + affect = ModelingToolkit.ImperativeAffect(mod, obs) do mod, obs, ctx, integ + @test integ[x] ≈ obs.a.b[1] + @test integ[y] ≈ obs.a.b[2] + @test integ[y] ≈ mod.p.q + @test integ[z] ≈ mod.p.r + return (; p = (; q = obs.a.b[1])) + end + event = 1.0 => affect + @mtkcompile sys = System([D(x) ~ x, D(y) ~ y, D(z) ~ z], t; discrete_events = [event]) + prob = ODEProblem(sys, [x => 1.0, y => 2.0, z => 3.0], (0.0, 5.0)) + sol = solve(prob, Tsit5()) + @test SciMLBase.successful_retcode(sol) +end From 085fabaebd8acda16273711180fb92cab1dd6588 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Jul 2025 15:31:13 +0530 Subject: [PATCH 4/5] build: bump SII compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 98c7837212..62dae1992b 100644 --- a/Project.toml +++ b/Project.toml @@ -160,7 +160,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.10" StochasticDiffEq = "6.72.1" -SymbolicIndexingInterface = "0.3.39" +SymbolicIndexingInterface = "0.3.42" SymbolicUtils = "3.26.1" Symbolics = "6.40" URIs = "1" From d0b419e8d6a6d1e87edcc9ea1e7105239e3059f2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 18 Jul 2025 19:31:09 +0530 Subject: [PATCH 5/5] test: update tests to new ImperativeAffect internals --- test/symbolic_events.jl | 77 ++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index c3aafa4363..359ec54c73 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -96,111 +96,87 @@ end m = ModelingToolkit.ImperativeAffect(fmfa) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test m.obs == [] - @test m.obs_syms == [] - @test m.modified == [] - @test m.mod_syms == [] + @test m.observed == (;) + @test m.modified == (;) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa, (;)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test m.obs == [] - @test m.obs_syms == [] - @test m.modified == [] - @test m.mod_syms == [] + @test m.observed == (;) + @test m.modified == (;) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa, (; x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, []) - @test m.obs_syms == [] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:x] + @test m.observed == (;) + @test m.modified == (; x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa, (; y = x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, []) - @test m.obs_syms == [] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:y] + @test m.observed == (;) + @test m.modified == (; y = x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa; observed = (; y = x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:y] - @test m.modified == [] - @test m.mod_syms == [] + @test m.observed == (; y = x) + @test m.modified == (;) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa; modified = (; x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, []) - @test m.obs_syms == [] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:x] + @test m.observed == (;) + @test m.modified == (; x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa; modified = (; y = x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, []) - @test m.obs_syms == [] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:y] + @test m.observed == (;) + @test m.modified == (; y = x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa, (; x), (; x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:x] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:x] + @test m.observed == (; x) + @test m.modified == (; x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect(fmfa, (; y = x), (; y = x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:y] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:y] + @test m.observed == (; y = x) + @test m.modified == (; y = x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect( fmfa; modified = (; y = x), observed = (; y = x)) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:y] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:y] + @test m.observed == (; y = x) + @test m.modified == (; y = x) @test m.ctx === nothing m = ModelingToolkit.ImperativeAffect( fmfa; modified = (; y = x), observed = (; y = x), ctx = 3) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:y] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:y] + @test m.observed == (; y = x) + @test m.modified == (; y = x) @test m.ctx === 3 m = ModelingToolkit.ImperativeAffect(fmfa, (; x), (; x), 3) @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:x] - @test isequal(m.modified, [x]) - @test m.mod_syms == [:x] + @test m.observed == (; x) + @test m.modified == (; x) @test m.ctx === 3 end @@ -966,8 +942,7 @@ end end) @named sys = System(eqs, t, [temp], params; continuous_events = [furnace_off]) ss = mtkcompile(sys) - @test_logs (:warn, - "The symbols Any[:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value.") prob=ODEProblem( + @test_warn "The symbols [:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." prob=ODEProblem( ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) @variables tempsq(t) # trivially eliminated @@ -1010,7 +985,7 @@ end ss = mtkcompile(sys) prob = ODEProblem( ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - @test_throws "Tried to write back to" solve(prob, Tsit5()) + @test_throws "Invalid name" solve(prob, Tsit5()) end @testset "Quadrature" begin