diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 09b59c4ed6..5db5b3992b 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -69,7 +69,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables, NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, initial_state, transition, activeState, entry, hasnode, ticksInState, timeInState, fixpoint_sub, fast_substitute, - CallWithMetadata, CallWithParent + CallWithMetadata, CallWithParent, Transition, InitialState, + StateMachineOperator const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 8213b8f241..724194fb94 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1390,6 +1390,26 @@ function namespace_expr( O end end + +function namespace_expr( + O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys)) + return Transition( + O.from === nothing ? O.from : renamespace(sys, O.from), + O.to === nothing ? O.to : renamespace(sys, O.to), + O.cond === nothing ? O.cond : namespace_expr(O.cond, sys), + O.immediate, O.reset, O.synchronize, O.priority + ) +end + +function namespace_expr( + O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys)) + return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s)) +end + +function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...) + error("Unhandled state machine operator") +end + _nonum(@nospecialize x) = x isa Num ? x.val : x """ diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 611f8e2fae..66454e0785 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -93,7 +93,7 @@ function infer_clocks!(ci::ClockInference) c = BitSet(c′) idxs = intersect(c, inferred) isempty(idxs) && continue - if !allequal(var_domain[i] for i in idxs) + if !allequal(iscontinuous(var_domain[i]) for i in idxs) display(fullvars[c′]) throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) end @@ -144,6 +144,9 @@ function split_system(ci::ClockInference{S}) where {S} var_to_cid = Vector{Int}(undef, ndsts(graph)) cid_to_var = Vector{Int}[] cid_counter = Ref(0) + + # populates clock_to_id and id_to_clock + # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, continuous_id = continuous_id @@ -161,9 +164,13 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_eq, i, cid) end continuous_id = continuous_id[] + # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) inputs = map(_ -> Any[], 1:cid_counter[]) + # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) + # put variables into the right clock partition + # keep track of inputs to each partition for i in 1:nvv d = var_domain[i] cid = get(clock_to_id, d, 0) @@ -177,6 +184,7 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_var, i, cid) end + # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) for (id, ieqs) in enumerate(cid_to_eq) ts_i = system_subset(ts, ieqs) @@ -186,6 +194,7 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + # put the continous system at the back if continuous_id != 0 tss[continuous_id], tss[end] = tss[end], tss[continuous_id] inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 4c9ff3d248..56c3721317 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -211,7 +211,9 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + end for idx in save_idxs SciMLBase.save_discretes!(integ, idx) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index f8630f2d20..924be54b38 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -31,7 +31,7 @@ function structural_simplify( kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) newsys′ = __structural_simplify(sys, io; simplify, - allow_symbolic, allow_parameter, conservative, fully_determined, + allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -82,12 +82,13 @@ end function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false, kwargs...) + sys, statemachines = extract_top_level_statemachines(sys) sys = expand_connections(sys) state = TearingState(sys) + append!(state.statemachines, statemachines) @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure - eqs = equations(state) brown_vars = Int[] new_idxs = zeros(Int, length(var_types)) idx = 0 @@ -104,7 +105,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal Is = Int[] Js = Int[] vals = Num[] - new_eqs = copy(eqs) + make_eqs_zero_equals!(state) + new_eqs = copy(equations(state)) dvar2eq = Dict{Any, Int}() for (v, dv) in enumerate(var_to_diff) dv === nothing && continue @@ -169,3 +171,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end + +""" +Mark whether an extra pass `p` can support compiling discrete systems. +""" +discrete_compile_pass(p) = false diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1bdc11f06a..6f51311cbb 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,5 +1,5 @@ using DataStructures -using Symbolics: linear_expansion, unwrap, Connection +using Symbolics: linear_expansion, unwrap, Connection, Transition, InitialState using SymbolicUtils: iscall, operation, arguments, Symbolic using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit @@ -198,16 +198,35 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} sys::T + original_eqs::Vector{Equation} fullvars::Vector structure::SystemStructure extra_eqs::Vector + statemachines::Vector{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) + if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) + names = Symbol[] + for eq in get_eqs(ts.sys) + if eq.lhs isa Transition + push!(names, first(namespace_hierarchy(nameof(eq.rhs.from)))) + push!(names, first(namespace_hierarchy(nameof(eq.rhs.to)))) + elseif eq.lhs isa InitialState + push!(names, first(namespace_hierarchy(nameof(eq.rhs.s)))) + else + error("Unhandled state machine operator") + end + end + @set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines) + else + @set! ts.statemachines = eltype(ts.statemachines)[] + end ts end @@ -247,12 +266,56 @@ function Base.push!(ev::EquationsView, eq) push!(ev.ts.extra_eqs, eq) end +""" + $(TYPEDSIGNATURES) + +Descend through the system hierarchy and look for statemachines. Remove equations from +the inner statemachine systems. Return the new `sys` and an array of top-level +statemachines. +""" +function extract_top_level_statemachines(sys::AbstractSystem) + eqs = get_eqs(sys) + + if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + # top-level statemachine + with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) + return with_removed, [sys] + elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + # error: can't mix + error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") + else + # descend + subsystems = get_systems(sys) + newsubsystems = eltype(subsystems)[] + statemachines = eltype(subsystems)[] + for subsys in subsystems + newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) + push!(newsubsystems, newsubsys) + append!(statemachines, sub_statemachines) + end + @set! sys.systems = newsubsystems + return sys, statemachines + end +end + +""" + $(TYPEDSIGNATURES) + +Return `sys` with all equations (including those in subsystems) removed. +""" +function remove_child_equations(sys::AbstractSystem) + @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.systems = map(remove_child_equations, get_systems(sys)) + return sys +end + function TearingState(sys; quick_cancel = false, check = true) sys = flatten(sys) ivs = independent_variables(sys) iv = length(ivs) == 1 ? ivs[1] : nothing # scalarize array equations, without scalarizing arguments to registered functions - eqs = flatten_equations(copy(equations(sys))) + original_eqs = flatten_equations(copy(equations(sys))) + eqs = copy(original_eqs) neqs = length(eqs) dervaridxs = OrderedSet{Int}() var2idx = Dict{Any, Int}() @@ -275,7 +338,12 @@ function TearingState(sys; quick_cancel = false, check = true) check ? error("$(nameof(sys)) has unexpanded `connect` statements") : return nothing end - if _iszero(eq′.lhs) + is_statemachine_equation = false + if eq′.lhs isa StateMachineOperator + is_statemachine_equation = true + eq = eq′ + rhs = eq.rhs + elseif _iszero(eq′.lhs) rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs eq = eq′ else @@ -340,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true) empty!(unknownvars) empty!(vars) empty!(varsvec) - if isalgeq + if isalgeq || is_statemachine_equation eqs[i] = eq else eqs[i] = eqs[i].lhs ~ rhs @@ -428,10 +496,10 @@ function TearingState(sys; quick_cancel = false, check = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, fullvars, + ts = TearingState(sys, original_eqs, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, sys isa DiscreteSystem), - Any[]) + Any[], typeof(sys)[]) if sys isa DiscreteSystem ts = shift_discrete_system(ts) end @@ -622,44 +690,69 @@ function merge_io(io, inputs) return io end +function make_eqs_zero_equals!(ts::TearingState) + neweqs = map(enumerate(get_eqs(ts.sys))) do kvp + i, eq = kvp + isalgeq = true + for j in 𝑠neighbors(ts.structure.graph, i) + isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing + end + if isalgeq + return 0 ~ eq.rhs - eq.lhs + else + return eq + end + end + copyto!(get_eqs(ts.sys), neweqs) +end + function structural_simplify!(state::TearingState, io = nothing; simplify = false, check_consistency = true, fully_determined = true, warn_initialize_determined = true, kwargs...) if state.sys isa ODESystem + # split_system returns one or two systems and the inputs for each + # mod clock inference to be binary + # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) + if continuous_id == 0 + # do a trait check here - handle fully discrete system + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && + any(discrete_compile_pass, additional_passes) + # take the first discrete compilation pass given for now + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + return discrete_compile(tss, inputs, ci) + else + # error goes here! this is a purely discrete system + throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem.")) + end + end + make_eqs_zero_equals!(tss[continuous_id]) + # puts the ios passed in to the call into the continous system cont_io = merge_io(io, inputs[continuous_id]) + # simplify as normal sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify, check_consistency, fully_determined, kwargs...) if length(tss) > 1 - if continuous_id > 0 + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && + any(discrete_compile_pass, additional_passes) + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems + # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed + sys = discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], inputs, ci) + else throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/")) end - # TODO: rename it to something else - discrete_subsystems = Vector{ODESystem}(undef, length(tss)) - # Note that the appended_parameters must agree with - # `generate_discrete_affect`! - appended_parameters = parameters(sys) - for (i, state) in enumerate(tss) - if i == continuous_id - discrete_subsystems[i] = sys - continue - end - dist_io = merge_io(io, inputs[i]) - ss, = _structural_simplify!(state, dist_io; simplify, check_consistency, - fully_determined, kwargs...) - append!(appended_parameters, inputs[i], unknowns(ss)) - discrete_subsystems[i] = ss - end - @set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id, - id_to_clock - @set! sys.ps = appended_parameters - @set! sys.defaults = merge(ModelingToolkit.defaults(sys), - Dict(v => 0.0 for v in Iterators.flatten(inputs))) end ps = [sym isa CallWithMetadata ? sym : setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous())) diff --git a/src/utils.jl b/src/utils.jl index cf49d9f445..3b4ccd3be9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -401,6 +401,25 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end +function vars!(vars, O::AbstractSystem; op = Differential) + for eq in equations(O) + vars!(vars, eq; op) + end + return vars +end +function vars!(vars, O::Transition; op = Differential) + vars!(vars, O.from) + vars!(vars, O.to) + vars!(vars, O.cond; op) + return vars +end +function vars!(vars, O::InitialState; op = Differential) + vars!(vars, O.s; op) + return vars +end +function vars!(vars, O::StateMachineOperator; op = Differential) + error("Unhandled state machine operator") +end function vars!(vars, O; op = Differential) if isvariable(O) if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O)))