Skip to content

feat: use SII to implement ImperativeAffect, support nested NamedTuples #3829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDelayDiffEq = "1.10"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.39"
SymbolicIndexingInterface = "0.3.42"
SymbolicUtils = "3.26.1"
Symbolics = "6.40"
URIs = "1"
Expand Down
5 changes: 4 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,10 @@ function _all_ts_idxs!(ts_idxs, ::ScalarSymbolic, sys, sym::Symbol)
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
end
end
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray)
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::NamedTuple)
_all_ts_idxs!(ts_idxs, NotSymbolic(), sys, values(sym))
end
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::Union{AbstractArray, Tuple})
for s in sym
_all_ts_idxs!(ts_idxs, sys, s)
end
Expand Down
230 changes: 88 additions & 142 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
"""
struct ImperativeAffect
f::Any
obs::Vector
obs_syms::Vector{Symbol}
modified::Vector
mod_syms::Vector{Symbol}
observed::NamedTuple
modified::NamedTuple
ctx::Any
skip_checks::Bool
end
Expand All @@ -43,10 +41,7 @@ function ImperativeAffect(f;
modified::NamedTuple = NamedTuple{()}(()),
ctx = nothing,
skip_checks = false)
ImperativeAffect(f,
collect(values(observed)), collect(keys(observed)),
collect(values(modified)), collect(keys(modified)),
ctx, skip_checks)
ImperativeAffect(f, observed, modified, ctx, skip_checks)
end
function ImperativeAffect(f, modified::NamedTuple;
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
Expand All @@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
end

function Base.show(io::IO, mfa::ImperativeAffect)
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
obs = mfa.observed
mod = mfa.modified
affect = mfa.f
print(io,
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
"ImperativeAffect(observed: [$(obs)], modified: [$(mod)], affect:$affect)")
end
func(f::ImperativeAffect) = f.f
context(a::ImperativeAffect) = a.ctx
observed(a::ImperativeAffect) = a.obs
observed_syms(a::ImperativeAffect) = a.obs_syms
function discretes(a::ImperativeAffect)
Iterators.filter(ModelingToolkit.isparameter,
Iterators.flatten(Iterators.map(
x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x],
a.modified)))
end
modified(a::ImperativeAffect) = a.modified
modified_syms(a::ImperativeAffect) = a.mod_syms

function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
isequal(a1.f, a2.f) && isequal(a1.observed, a2.observed) &&
isequal(a1.modified, a2.modified) &&
isequal(a1.ctx, a2.ctx)
end

function Base.hash(a::ImperativeAffect, s::UInt)
s = hash(a.f, s)
s = hash(a.obs, s)
s = hash(a.obs_syms, s)
s = hash(a.observed, s)
s = hash(a.modified, s)
s = hash(a.mod_syms, s)
hash(a.ctx, s)
end

namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
function namespace_affect(affect::ImperativeAffect, s)
rmn = []
for modded in modified(affect)
if symbolic_type(modded) == NotSymbolic() && modded isa AbstractArray
res = []
for m in modded
push!(res, renamespace(s, m))
end
push!(rmn, res)

function _namespace_nt(nt::NamedTuple, s::AbstractSystem)
return NamedTuple{keys(nt)}(_namespace_nt(values(nt), s))
end

function _namespace_nt(nt::Union{AbstractArray, Tuple}, s::AbstractSystem)
return map(nt) do v
if symbolic_type(v) == NotSymbolic()
_namespace_nt(v, s)
else
push!(rmn, renamespace(s, modded))
renamespace(s, v)
end
end
ImperativeAffect(func(affect),
namespace_expr.(observed(affect), (s,)),
observed_syms(affect),
rmn,
modified_syms(affect),
context(affect),
affect.skip_checks)
end

function namespace_affect(affect::ImperativeAffect, s)
obs = _namespace_nt(affect.observed, s)
mod = _namespace_nt(affect.modified, s)
ImperativeAffect(affect.f, obs, mod, affect.ctx, affect.skip_checks)
end

function invalid_variables(sys, expr)
Expand All @@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
x -> !any(isequal(x), assignable_syms), written)
end

@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple},
values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
setter_exprs = []
for name in NS2
if !(name in NS1)
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
error(missing_name)
end
push!(setter_exprs, :(setters.$name(integ, values.$name)))
end
return :(begin
$(setter_exprs...)
end)
end

function check_assignable(sys, sym)
if symbolic_type(sym) == ScalarSymbolic()
is_variable(sys, sym) || is_parameter(sys, sym)
Expand All @@ -167,6 +140,42 @@ function check_assignable(sys, sym)
end
end

function _nt_check_valid(nt::NamedTuple, s::AbstractSystem, isobserved::Bool)
_nt_check_valid(values(nt), s, isobserved)
end

function _nt_check_valid(
nt::Union{Tuple, AbstractArray}, s::AbstractSystem, isobserved::Bool)
for v in nt
if symbolic_type(v) == NotSymbolic()
_nt_check_valid(v, s, isobserved)
continue
end
if !isobserved && !check_assignable(s, v)
error("""
Expression $v cannot be assigned to; currently only unknowns and parameters may \
be updated by an affect.
""")
end
invalid = invalid_variables(s, v)
isempty(invalid) && continue
name = isobserved ? "Observed" : "Modified"
error("""
$name expression $(v) in affect refers to missing variable(s) $(invalid); \
the variables may not have been added (e.g. if a component is missing).
""")
end
end

function _nt_check_overlap(nta::NamedTuple, ntb::NamedTuple)
common = intersect(keys(nta), keys(ntb))
isempty(common) && return
@warn """
The symbols $common are declared as both observed and modified; this is a code smell \
because it becomes easy to confuse them and assign/not assign a value.
"""
end

function compile_functional_affect(
affect::ImperativeAffect, sys; reset_jumps = false, kwargs...)
#=
Expand All @@ -176,93 +185,27 @@ function compile_functional_affect(
call the affect method
unpack and apply the resulting values
=#
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
seen = Set{Symbol}()
syms_dedup = []
exprs_dedup = []
for (sym, exp) in Iterators.zip(syms, exprs)
if !in(sym, seen)
push!(syms_dedup, sym)
push!(exprs_dedup, exp)
push!(seen, sym)
elseif !affect.skip_checks
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
end
end
return (syms_dedup, exprs_dedup)
end

dvs = unknowns(sys)
ps = parameters(sys)

obs_exprs = observed(affect)
if !affect.skip_checks
for oexpr in obs_exprs
invalid_vars = invalid_variables(sys, oexpr)
if length(invalid_vars) > 0
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
end
end
end
obs_syms = observed_syms(affect)
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)

mod_exprs = modified(affect)
if !affect.skip_checks
for mexpr in mod_exprs
if !check_assignable(sys, mexpr)
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
end
invalid_vars = unassignable_variables(sys, mexpr)
if length(invalid_vars) > 0
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
end
end
end
mod_syms = modified_syms(affect)
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)

overlapping_syms = intersect(mod_syms, obs_syms)
if length(overlapping_syms) > 0 && !affect.skip_checks
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
_nt_check_valid(affect.observed, sys, true)
_nt_check_valid(affect.modified, sys, false)
_nt_check_overlap(affect.observed, affect.modified)
end

# sanity checks done! now build the data and update function for observed values
mkzero(sz) =
if sz === ()
0.0
else
zeros(sz)
end
obs_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(obs_exprs);
mkarray = (es, _) -> MakeTuple(es))
obs_sym_tuple = (obs_syms...,)

# okay so now to generate the stuff to assign it back into the system
mod_pairs = mod_exprs .=> mod_syms
mod_names = (mod_syms...,)
mod_og_val_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(first.(mod_pairs));
mkarray = (es, _) -> MakeTuple(es))
let user_affect = func(affect), ctx = context(affect),
obs_getter = isempty(affect.observed) ? Returns((;)) : getsym(sys, affect.observed),
mod_getter = isempty(affect.modified) ? Returns((;)) : getsym(sys, affect.modified),
mod_setter = isempty(affect.modified) ? Returns((;)) : setsym(sys, affect.modified),
reset_jumps = reset_jumps

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

let user_affect = func(affect), ctx = context(affect), reset_jumps = reset_jumps
@inline function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
upd_component_array = NamedTuple{mod_names}(modvals)

# update the observed values
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
integ.u, integ.p, integ.t))
mod = mod_getter(integ)
obs = obs_getter(integ)

# let the user do their thing
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)
upd_vals = user_affect(mod, obs, ctx, integ)
mod_setter(integ, upd_vals)

reset_jumps && reset_aggregated_jumps!(integ)
end
Expand All @@ -271,19 +214,22 @@ end

scalarize_affects(affects::ImperativeAffect) = affects

function vars!(vars, aff::ImperativeAffect; op = Differential)
for var in Iterators.flatten((observed(aff), modified(aff)))
if symbolic_type(var) == NotSymbolic()
if var isa AbstractArray
for v in var
v = unwrap(v)
vars!(vars, v)
end
end
else
var = unwrap(var)
vars!(vars, var)
function _vars_nt!(vars, nt::NamedTuple, op)
_vars_nt!(vars, values(nt), op)
end

function _vars_nt!(vars, nt::Union{AbstractArray, Tuple}, op)
for v in nt
if symbolic_type(v) == NotSymbolic()
_vars_nt!(vars, v, op)
continue
end
vars!(vars, v; op)
end
end

function vars!(vars, aff::ImperativeAffect; op = Differential)
_vars_nt!(vars, aff.observed, op)
_vars_nt!(vars, aff.modified, op)
return vars
end
Loading
Loading