From d2c10eaec2eb5dae7729292db5f5b325dd5839de Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Sat, 13 Sep 2025 08:32:20 -0400 Subject: [PATCH 01/13] initial setup --- src/ModelingToolkit.jl | 1 + src/inputs.jl | 94 ++++++++++++++++++++++++++++++++++++++++++ test/inputs.jl | 64 ++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+) create mode 100644 src/inputs.jl create mode 100644 test/inputs.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 4f29c5f428..662588032f 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -224,6 +224,7 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") +include("inputs.jl") include("adjoints.jl") include("deprecations.jl") diff --git a/src/inputs.jl b/src/inputs.jl new file mode 100644 index 0000000000..f895acac7d --- /dev/null +++ b/src/inputs.jl @@ -0,0 +1,94 @@ +using SymbolicIndexingInterface +using Setfield +using StaticArrays + +struct Input + var::Num + data::SVector + time::SVector +end + +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Union{Input, Vector{Input}}, args...; input_funs, kwargs...) + + set_input!, finalize! = input_funs + + tstops = Float64[] + callbacks = DiscreteCallback[] + if !isa(inputs, Vector) + inputs = [inputs] + end + + for input::Input in inputs + tstops = union(tstops, input.time) + + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + i = findfirst(integrator.t .== input.time) + set_input!(integrator, input.var, input.data[i]) + end + callback = DiscreteCallback(condition, affect!) + + push!(callbacks, callback) + end + + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> finalize!(integrator) + callback = DiscreteCallback(condition, affect!) + + push!(callbacks, callback) + push!(tstops, t_end) + + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end + +function setup_inputs(sys) + + inputs = ModelingToolkit.unbound_inputs(sys) + setters = Dict{Num, Function}() + + if !isempty(inputs) + sdcs = ModelingToolkit.SymbolicDiscreteCallback[] + for x in inputs + affect = ModelingToolkit.ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = ModelingToolkit.SymbolicDiscreteCallback(Inf, affect) + + push!(sdcs, sdc) + end + + @set! sys.discrete_events = sdcs + sys = complete(sys) # @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + + for (i,x) in enumerate(inputs) + setter = SymbolicIndexingInterface.setsym(sys, x) + sdc = sdcs[i] + + setval = function (integrator, set, val=NaN) + if set + println("::setting $x to $val @ $(integrator.t)s") + setter(integrator, val) + else + println("::saving $x @ $(integrator.t)s") + end + ModelingToolkit.save_callback_discretes!(integrator, sdc) + end + + setters[x] = setval + end + end + + set_input! = function (integrator, var, value) + setters[var](integrator, true, value) + u_modified!(integrator, true) + end + + finalize! = function (integrator) + for ky in keys(setters) + setters[ky](integrator, false) + end + end + + return sys, set_input!, finalize! +end + diff --git a/test/inputs.jl b/test/inputs.jl new file mode 100644 index 0000000000..471ad748c6 --- /dev/null +++ b/test/inputs.jl @@ -0,0 +1,64 @@ +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using Plots +using Test +using StaticArrays + + + +# ----------------------------------------- +# ----- example --------------------------- +# ----------------------------------------- + +vars = @variables begin + x(t)=1, [input=true] + + # states + y(t) = 0 +end + +eqs = [ + # equations + D(y) ~ x + +] + +@mtkcompile sys = System(eqs, t, vars, []) inputs=[x] +sys, set_input!, finalize! = ModelingToolkit.setup_inputs(sys); +prob = ODEProblem(sys, [], (0, 4)) + +# indeterminate form ----------------------- + +integrator = init(prob, Tsit5()) + +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 3.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 4.0) +step!(integrator, 1.0, true) + +finalize!(integrator) + +@test integrator.sol(0.0; idxs=sys.x) == 1.0 +@test integrator.sol(1.0; idxs=sys.x) == 2.0 +@test integrator.sol(2.0; idxs=sys.x) == 3.0 +@test integrator.sol(3.0; idxs=sys.x) == 4.0 +@test integrator.sol(4.0; idxs=sys.y) ≈ 10.0 + +# determinate form ----------------------- +input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) +sol = solve(prob, input, Tsit5(); input_funs = (set_input!, finalize!)); + + +@test sol(0.0; idxs=sys.x) == 1.0 +@test sol(1.0; idxs=sys.x) == 2.0 +@test sol(2.0; idxs=sys.x) == 3.0 +@test sol(3.0; idxs=sys.x) == 4.0 +@test sol(4.0; idxs=sys.y) ≈ 10.0 \ No newline at end of file From 2d088e65605ece53ccd9eb424e0ef8dc9e26e01f Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Sun, 14 Sep 2025 19:25:11 -0400 Subject: [PATCH 02/13] unbound_inputs --- src/inputs.jl | 131 +++++++++++++++++++++++++---------------------- test/inputs.jl | 11 ++-- test/runtests.jl | 1 + 3 files changed, 78 insertions(+), 65 deletions(-) diff --git a/src/inputs.jl b/src/inputs.jl index f895acac7d..89d44faa03 100644 --- a/src/inputs.jl +++ b/src/inputs.jl @@ -8,87 +8,94 @@ struct Input time::SVector end -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Union{Input, Vector{Input}}, args...; input_funs, kwargs...) +function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + n = length(data) + return Input(var, SVector{n}(data), SVector{n}(time)) +end - set_input!, finalize! = input_funs +struct InputFunctions + events::Tuple + vars::Tuple + setters::Tuple +end - tstops = Float64[] - callbacks = DiscreteCallback[] - if !isa(inputs, Vector) - inputs = [inputs] - end +InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) + +function set_input!(input_funs::InputFunctions, integrator, var, value::Real) + i = findfirst(isequal(var), input_funs.vars) + setter = input_funs.setters[i] + event = input_funs.events[i] + + setter(integrator, value) + save_callback_discretes!(integrator, event) + u_modified!(integrator, true) + return nothing +end - for input::Input in inputs - tstops = union(tstops, input.time) - - condition = (u,t,integrator) -> any(t .== input.time) - affect! = function (integrator) - i = findfirst(integrator.t .== input.time) - set_input!(integrator, input.var, input.data[i]) - end - callback = DiscreteCallback(condition, affect!) +function finalize!(input_funs::InputFunctions, integrator) - push!(callbacks, callback) + for i in eachindex(input_funs.vars) + save_callback_discretes!(integrator, input_funs.events[i]) end - # finalize! - t_end = prob.tspan[2] - condition = (u,t,integrator) -> (t == t_end) - affect! = (integrator) -> finalize!(integrator) - callback = DiscreteCallback(condition, affect!) + return nothing +end - push!(callbacks, callback) - push!(tstops, t_end) +(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) +(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) - return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) -end +function setup_inputs(sys, inputs = unbound_inputs(sys)) + + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + setters = [] + events = SymbolicDiscreteCallback[] + if !isempty(vars) + + for x in vars + affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = SymbolicDiscreteCallback(Inf, affect) + + push!(events, sdc) + end -function setup_inputs(sys) + @set! sys.discrete_events = events + @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - inputs = ModelingToolkit.unbound_inputs(sys) - setters = Dict{Num, Function}() + setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] + + end - if !isempty(inputs) - sdcs = ModelingToolkit.SymbolicDiscreteCallback[] - for x in inputs - affect = ModelingToolkit.ImperativeAffect((m, o, c, i)->m, modified=(;x)) - sdc = ModelingToolkit.SymbolicDiscreteCallback(Inf, affect) + return sys, InputFunctions(events, vars, setters) +end - push!(sdcs, sdc) - end - @set! sys.discrete_events = sdcs - sys = complete(sys) # @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - for (i,x) in enumerate(inputs) - setter = SymbolicIndexingInterface.setsym(sys, x) - sdc = sdcs[i] - setval = function (integrator, set, val=NaN) - if set - println("::setting $x to $val @ $(integrator.t)s") - setter(integrator, val) - else - println("::saving $x @ $(integrator.t)s") - end - ModelingToolkit.save_callback_discretes!(integrator, sdc) - end +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; input_funs::InputFunctions, kwargs...) - setters[x] = setval - end - end + tstops = Float64[] + callbacks = DiscreteCallback[] - set_input! = function (integrator, var, value) - setters[var](integrator, true, value) - u_modified!(integrator, true) - end + for input::Input in inputs - finalize! = function (integrator) - for ky in keys(setters) - setters[ky](integrator, false) + tstops = union(tstops, input.time) + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + @inbounds begin + i = findfirst(integrator.t .== input.time) + input_funs(integrator, input.var, input.data[i]) + end end + push!(callbacks, DiscreteCallback(condition, affect!)) + end - return sys, set_input!, finalize! -end + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> input_funs(integrator) + push!(callbacks, DiscreteCallback(condition, affect!)) + push!(tstops, t_end) + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end diff --git a/test/inputs.jl b/test/inputs.jl index 471ad748c6..d1ff2d11be 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -25,13 +25,18 @@ eqs = [ ] @mtkcompile sys = System(eqs, t, vars, []) inputs=[x] -sys, set_input!, finalize! = ModelingToolkit.setup_inputs(sys); +ins = ModelingToolkit.unbound_inputs(sys) +# ins_ = [sys.x] +sys, input_funs = ModelingToolkit.setup_inputs(sys, ins); prob = ODEProblem(sys, [], (0, 4)) # indeterminate form ----------------------- integrator = init(prob, Tsit5()) +set_input! = input_funs +finalize! = input_funs + set_input!(integrator, sys.x, 1.0) step!(integrator, 1.0, true) @@ -54,11 +59,11 @@ finalize!(integrator) # determinate form ----------------------- input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) -sol = solve(prob, input, Tsit5(); input_funs = (set_input!, finalize!)); +sol = solve(prob, [input], Tsit5(); input_funs); @test sol(0.0; idxs=sys.x) == 1.0 @test sol(1.0; idxs=sys.x) == 2.0 @test sol(2.0; idxs=sys.x) == 3.0 @test sol(3.0; idxs=sys.x) == 4.0 -@test sol(4.0; idxs=sys.y) ≈ 10.0 \ No newline at end of file +@test sol(4.0; idxs=sys.y) ≈ 10.0 diff --git a/test/runtests.jl b/test/runtests.jl index 522470b896..870ffa6433 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ end @safetestset "Direct Usage Test" include("direct.jl") @safetestset "System Linearity Test" include("linearity.jl") @safetestset "Input Output Test" include("input_output_handling.jl") + @safetestset "Inputs" include("inputs.jl") @safetestset "Clock Test" include("clock.jl") @safetestset "ODESystem Test" include("odesystem.jl") @safetestset "Dynamic Quantities Test" include("dq_units.jl") From f2d80da2234ae0e308791e3ee59a8e247cb46375 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 05:52:40 -0400 Subject: [PATCH 03/13] input_functions property and docs --- docs/src/basics/InputOutput.md | 72 +++++++++++++ src/ModelingToolkit.jl | 3 +- src/inputs.jl | 101 ------------------ src/systems/abstractsystem.jl | 3 +- src/systems/inputs.jl | 188 +++++++++++++++++++++++++++++++++ src/systems/system.jl | 9 +- src/systems/systems.jl | 5 + test/inputs.jl | 14 +-- 8 files changed, 281 insertions(+), 114 deletions(-) delete mode 100644 src/inputs.jl create mode 100644 src/systems/inputs.jl diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index b1eb2905df..33300616f9 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -87,6 +87,75 @@ See [Symbolic Metadata](@ref symbolic_metadata). Metadata specified when creatin See [Linearization](@ref linearization). +## Real-time Input Handling During Simulation + +ModelingToolkit supports setting input values during simulation for variables marked with the `[input=true]` metadata. This is useful for real-time simulations, hardware-in-the-loop testing, interactive simulations, or any scenario where input values need to be determined during integration rather than specified beforehand. + +To use this functionality, variables must be marked as inputs using the `[input=true]` metadata and specified in the `inputs` keyword argument of `@mtkcompile`. + +There are two approaches to handling inputs during simulation: + +### Determinate Form: Using `Input` Objects + +When all input values are known beforehand, you can use the [`Input`](@ref) type to specify input values at specific time points. The solver will automatically apply these values using discrete callbacks. + +```@example inputs +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +# Define system with an input variable +@variables x(t) [input=true] +@variables y(t) = 0 + +eqs = [D(y) ~ x] + +# Compile with inputs specified +@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] + +prob = ODEProblem(sys, [], (0, 4)) + +# Create an Input object with predetermined values +input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) + +# Solve with the input - solver handles callbacks automatically +sol = solve(prob, [input], Tsit5()) + +plot(sol) +``` + +Multiple `Input` objects can be passed in a vector to handle multiple input variables simultaneously. + +### Indeterminate Form: Manual Input Setting with `set_input!` + +When input values need to be computed on-the-fly or depend on external data sources, you can manually set inputs during integration using [`set_input!`](@ref). This approach requires explicit control of the integration loop. + +```@example inputs +# Initialize the integrator +integrator = init(prob, Tsit5()) + +# Manually set inputs and step through time +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 3.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 4.0) +step!(integrator, 1.0, true) + +# IMPORTANT: Must call finalize! to save all input callbacks +finalize!(integrator) + +plot(sol) +``` + +!!! warning "Always call `finalize!`" + When using `set_input!`, you must call [`finalize!`](@ref) after integration is complete. This ensures that all discrete callbacks associated with input variables are properly saved in the solution. Without this call, input values may not be correctly recorded when querying the solution. + ## Docstrings ```@index @@ -96,4 +165,7 @@ Pages = ["InputOutput.md"] ```@docs; canonical=false ModelingToolkit.generate_control_function ModelingToolkit.build_explicit_observed_function +ModelingToolkit.Input +ModelingToolkit.set_input! +ModelingToolkit.finalize! ``` diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 662588032f..08d2691ffb 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -159,6 +159,7 @@ include("constants.jl") include("utils.jl") +export set_input!, finalize!, Input include("systems/index_cache.jl") include("systems/parameter_buffer.jl") include("systems/abstractsystem.jl") @@ -169,6 +170,7 @@ include("systems/state_machines.jl") include("systems/analysis_points.jl") include("systems/imperative_affect.jl") include("systems/callbacks.jl") +include("systems/inputs.jl") include("systems/system.jl") include("systems/codegen_utils.jl") include("problems/docs.jl") @@ -224,7 +226,6 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") -include("inputs.jl") include("adjoints.jl") include("deprecations.jl") diff --git a/src/inputs.jl b/src/inputs.jl deleted file mode 100644 index 89d44faa03..0000000000 --- a/src/inputs.jl +++ /dev/null @@ -1,101 +0,0 @@ -using SymbolicIndexingInterface -using Setfield -using StaticArrays - -struct Input - var::Num - data::SVector - time::SVector -end - -function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) - n = length(data) - return Input(var, SVector{n}(data), SVector{n}(time)) -end - -struct InputFunctions - events::Tuple - vars::Tuple - setters::Tuple -end - -InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) - -function set_input!(input_funs::InputFunctions, integrator, var, value::Real) - i = findfirst(isequal(var), input_funs.vars) - setter = input_funs.setters[i] - event = input_funs.events[i] - - setter(integrator, value) - save_callback_discretes!(integrator, event) - u_modified!(integrator, true) - return nothing -end - -function finalize!(input_funs::InputFunctions, integrator) - - for i in eachindex(input_funs.vars) - save_callback_discretes!(integrator, input_funs.events[i]) - end - - return nothing -end - -(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) -(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) - -function setup_inputs(sys, inputs = unbound_inputs(sys)) - - vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] - setters = [] - events = SymbolicDiscreteCallback[] - if !isempty(vars) - - for x in vars - affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) - sdc = SymbolicDiscreteCallback(Inf, affect) - - push!(events, sdc) - end - - @set! sys.discrete_events = events - @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - - setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] - - end - - return sys, InputFunctions(events, vars, setters) -end - - - - -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; input_funs::InputFunctions, kwargs...) - - tstops = Float64[] - callbacks = DiscreteCallback[] - - for input::Input in inputs - - tstops = union(tstops, input.time) - condition = (u,t,integrator) -> any(t .== input.time) - affect! = function (integrator) - @inbounds begin - i = findfirst(integrator.t .== input.time) - input_funs(integrator, input.var, input.data[i]) - end - end - push!(callbacks, DiscreteCallback(condition, affect!)) - - end - - # finalize! - t_end = prob.tspan[2] - condition = (u,t,integrator) -> (t == t_end) - affect! = (integrator) -> input_funs(integrator) - push!(callbacks, DiscreteCallback(condition, affect!)) - push!(tstops, t_end) - - return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) -end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0bd05bb4b9..6d73e09ce1 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -789,7 +789,8 @@ const SYS_PROPS = [:eqs :index_cache :isscheduled :costs - :consolidate] + :consolidate + :input_functions] for prop in SYS_PROPS fname_get = Symbol(:get_, prop) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl new file mode 100644 index 0000000000..34dd90869a --- /dev/null +++ b/src/systems/inputs.jl @@ -0,0 +1,188 @@ +using SymbolicIndexingInterface +using Setfield +using StaticArrays +using OrdinaryDiffEqCore + +""" + Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + +Create an `Input` object that specifies predetermined input values for a variable at specific time points. + +# Arguments +- `var`: The symbolic variable (marked with `[input=true]` metadata) to be used as an input. +- `data`: A vector of real values that the input variable should take at the corresponding time points. +- `time`: A vector of time points at which the input values should be applied. Must be the same length as `data`. + +# Description +The `Input` struct is used with the extended `solve` method to provide time-varying inputs to a system +during simulation. When passed to `solve(prob, [input1, input2, ...], alg)`, the solver will automatically +set the input variable to the specified values at the specified times using discrete callbacks. + +This provides a "determinate form" of input handling where all input values are known a priori, +as opposed to setting inputs manually during integration with [`set_input!`](@ref). + +See also [`set_input!`](@ref), [`finalize!`](@ref) +""" +struct Input + var::Num + data::SVector + time::SVector +end + +function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + n = length(data) + return Input(var, SVector{n}(data), SVector{n}(time)) +end + +struct InputFunctions + events::Tuple{SymbolicDiscreteCallback} + vars::Tuple{SymbolicUtils.BasicSymbolic{Real}} + setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper} +end + +InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) + +""" + set_input!(integrator, var, value::Real) + +Set the value of an input variable during integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `var`: The symbolic input variable to set (must be marked with `[input=true]` metadata and included in the `inputs` keyword of `@mtkcompile`). +- `value`: The new real-valued input to assign to the variable. +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function allows you to manually set input values during integration, providing an "indeterminate form" +of input handling where inputs can be computed on-the-fly. This is useful when input values depend on +runtime conditions, external data sources, or interactive user input. + +After setting input values with `set_input!`, you must call [`finalize!`](@ref) at the end of integration +to ensure all discrete callbacks are properly saved. + +# Example +```julia +@variables x(t) [input=true] +@variables y(t) = 0 + +eqs = [D(y) ~ x] +@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] + +prob = ODEProblem(sys, [], (0, 4)) +integrator = init(prob, Tsit5()) + +# Set input and step forward +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +# Must call finalize! at the end +finalize!(integrator) +``` + +See also [`finalize!`](@ref), [`Input`](@ref) +""" +function set_input!(input_funs::InputFunctions, integrator::OrdinaryDiffEqCore.ODEIntegrator, var, value::Real) + i = findfirst(isequal(var), input_funs.vars) + setter = input_funs.setters[i] + event = input_funs.events[i] + + setter(integrator, value) + save_callback_discretes!(integrator, event) + u_modified!(integrator, true) + return nothing +end +set_input!(integrator, var, value::Real) = set_input!(get_input_functions(integrator.f.sys), integrator, var, value) + +""" + finalize!(integrator) + +Finalize all input callbacks at the end of integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function must be called after using [`set_input!`](@ref) to manually set input values during integration. +It ensures that all discrete callbacks associated with input variables are properly saved in the solution, +making the input values accessible when querying the solution at specific time points. + +Without calling `finalize!`, input values set with `set_input!` may not be correctly recorded in the +final solution object, leading to incorrect results when indexing the solution. + +See also [`set_input!`](@ref), [`Input`](@ref) +""" +function finalize!(input_funs::InputFunctions, integrator) + + for i in eachindex(input_funs.vars) + save_callback_discretes!(integrator, input_funs.events[i]) + end + + return nothing +end +finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integrator) + +(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) +(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) + +function build_input_functions(sys, inputs) + + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + setters = [] + events = SymbolicDiscreteCallback[] + if !isempty(vars) + + for x in vars + affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = SymbolicDiscreteCallback(Inf, affect) + + push!(events, sdc) + end + + @set! sys.discrete_events = events + @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + + setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] + + end + + @set! sys.input_functions = InputFunctions(events, vars, setters) + + return sys +end + + + + +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) + + tstops = Float64[] + callbacks = DiscreteCallback[] + + for input::Input in inputs + + tstops = union(tstops, input.time) + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + @inbounds begin + i = findfirst(integrator.t .== input.time) + set_input!(integrator, input.var, input.data[i]) + end + end + push!(callbacks, DiscreteCallback(condition, affect!)) + + end + + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> finalize!(integrator) + push!(callbacks, DiscreteCallback(condition, affect!)) + push!(tstops, t_end) + + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end diff --git a/src/systems/system.jl b/src/systems/system.jl index 6db36ebd36..7d90480ffd 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -259,6 +259,11 @@ struct System <: IntermediateDeprecationSystem The `Schedule` containing additional information about the simplified system. """ schedule::Union{Schedule, Nothing} + """ + $INTERNAL_FIELD_WARNING + Functions used to set input variables with `set_input!` and `finalize!` functions + """ + input_functions::Union{InputFunctions, Nothing} function System( tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, @@ -271,7 +276,7 @@ struct System <: IntermediateDeprecationSystem complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, is_initializesystem = false, is_discrete = false, isscheduled = false, - schedule = nothing; checks::Union{Bool, Int} = true) + schedule = nothing, input_functions = nothing; checks::Union{Bool, Int} = true) if is_initializesystem && iv !== nothing throw(ArgumentError(""" Expected initialization system to be time-independent. Found independent @@ -310,7 +315,7 @@ struct System <: IntermediateDeprecationSystem tstops, inputs, outputs, tearing_state, namespacing, complete, index_cache, ignored_connections, preface, parent, initializesystem, is_initializesystem, is_discrete, - isscheduled, schedule) + isscheduled, schedule, input_functions) end end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 4c52300239..23b5d6fdb1 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -51,6 +51,11 @@ function mtkcompile( @set! newsys.parent = complete(sys; split = false, flatten = false) end newsys = complete(newsys; split) + + if !isempty(inputs) + newsys = build_input_functions(newsys, inputs) + end + if newsys′ isa Tuple idxs = [parameter_index(newsys, i) for i in io[1]] return newsys, idxs diff --git a/test/inputs.jl b/test/inputs.jl index d1ff2d11be..0644b77050 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -3,8 +3,6 @@ using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq using Plots using Test -using StaticArrays - # ----------------------------------------- @@ -25,17 +23,15 @@ eqs = [ ] @mtkcompile sys = System(eqs, t, vars, []) inputs=[x] -ins = ModelingToolkit.unbound_inputs(sys) -# ins_ = [sys.x] -sys, input_funs = ModelingToolkit.setup_inputs(sys, ins); + +@test !isnothing(ModelingToolkit.get_input_functions(sys)) + prob = ODEProblem(sys, [], (0, 4)) # indeterminate form ----------------------- integrator = init(prob, Tsit5()) -set_input! = input_funs -finalize! = input_funs set_input!(integrator, sys.x, 1.0) step!(integrator, 1.0, true) @@ -58,8 +54,8 @@ finalize!(integrator) @test integrator.sol(4.0; idxs=sys.y) ≈ 10.0 # determinate form ----------------------- -input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) -sol = solve(prob, [input], Tsit5(); input_funs); +input = Input(sys.x, [1,2,3,4], [0,1,2,3]) +sol = solve(prob, [input], Tsit5()); @test sol(0.0; idxs=sys.x) == 1.0 From f49800bac958374e5c7409e2a4fe6412d8d5e8a3 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 17:31:24 -0400 Subject: [PATCH 04/13] full working version with docs --- docs/src/basics/InputOutput.md | 5 +++-- src/systems/inputs.jl | 25 +++++++++++++++++++++---- test/inputs.jl | 3 +-- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index 33300616f9..dc05672337 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -103,6 +103,7 @@ When all input values are known beforehand, you can use the [`Input`](@ref) type using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq +using Plots # Define system with an input variable @variables x(t) [input=true] @@ -121,7 +122,7 @@ input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) # Solve with the input - solver handles callbacks automatically sol = solve(prob, [input], Tsit5()) -plot(sol) +plot(sol; idxs=[x,y]) ``` Multiple `Input` objects can be passed in a vector to handle multiple input variables simultaneously. @@ -150,7 +151,7 @@ step!(integrator, 1.0, true) # IMPORTANT: Must call finalize! to save all input callbacks finalize!(integrator) -plot(sol) +plot(sol; idxs=[x,y]) ``` !!! warning "Always call `finalize!`" diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 34dd90869a..644a20b13f 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -131,9 +131,12 @@ finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integra function build_input_functions(sys, inputs) - vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + # Here we ensure the inputs have metadata marking the discrete variables as parameters. In some + # cases the inputs can be fed to this function before they are converted to parameters by mtkcompile. + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] setters = [] events = SymbolicDiscreteCallback[] + defaults = get_defaults(sys) if !isempty(vars) for x in vars @@ -141,16 +144,24 @@ function build_input_functions(sys, inputs) sdc = SymbolicDiscreteCallback(Inf, affect) push!(events, sdc) + + # ensure that the ODEProblem does not complain about missing parameter map + if !haskey(defaults, x) + push!(defaults, x => 0.0) + end + end @set! sys.discrete_events = events @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + @set! sys.defaults = defaults setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] - - end + + @set! sys.input_functions = InputFunctions(events, vars, setters) - @set! sys.input_functions = InputFunctions(events, vars, setters) + end + return sys end @@ -163,6 +174,7 @@ function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Inpu tstops = Float64[] callbacks = DiscreteCallback[] + # set_input! for input::Input in inputs tstops = union(tstops, input.time) @@ -174,6 +186,11 @@ function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Inpu end end push!(callbacks, DiscreteCallback(condition, affect!)) + + # DiscreteCallback doesn't hit on t==0, workaround... + if input.time[1] == 0 + prob.ps[input.var] = input.data[1] + end end diff --git a/test/inputs.jl b/test/inputs.jl index 0644b77050..b9540a2d8d 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -1,7 +1,6 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq -using Plots using Test @@ -10,7 +9,7 @@ using Test # ----------------------------------------- vars = @variables begin - x(t)=1, [input=true] + x(t), [input=true] # states y(t) = 0 From 4ff37595ad4e9c2fb9acfa1948ba3223fcd35b3f Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Sat, 13 Sep 2025 08:32:20 -0400 Subject: [PATCH 05/13] initial setup --- src/ModelingToolkit.jl | 1 + src/inputs.jl | 94 ++++++++++++++++++++++++++++++++++++++++++ test/inputs.jl | 64 ++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+) create mode 100644 src/inputs.jl create mode 100644 test/inputs.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6c79eb4fb1..3895228999 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -230,6 +230,7 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") +include("inputs.jl") include("adjoints.jl") include("deprecations.jl") diff --git a/src/inputs.jl b/src/inputs.jl new file mode 100644 index 0000000000..f895acac7d --- /dev/null +++ b/src/inputs.jl @@ -0,0 +1,94 @@ +using SymbolicIndexingInterface +using Setfield +using StaticArrays + +struct Input + var::Num + data::SVector + time::SVector +end + +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Union{Input, Vector{Input}}, args...; input_funs, kwargs...) + + set_input!, finalize! = input_funs + + tstops = Float64[] + callbacks = DiscreteCallback[] + if !isa(inputs, Vector) + inputs = [inputs] + end + + for input::Input in inputs + tstops = union(tstops, input.time) + + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + i = findfirst(integrator.t .== input.time) + set_input!(integrator, input.var, input.data[i]) + end + callback = DiscreteCallback(condition, affect!) + + push!(callbacks, callback) + end + + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> finalize!(integrator) + callback = DiscreteCallback(condition, affect!) + + push!(callbacks, callback) + push!(tstops, t_end) + + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end + +function setup_inputs(sys) + + inputs = ModelingToolkit.unbound_inputs(sys) + setters = Dict{Num, Function}() + + if !isempty(inputs) + sdcs = ModelingToolkit.SymbolicDiscreteCallback[] + for x in inputs + affect = ModelingToolkit.ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = ModelingToolkit.SymbolicDiscreteCallback(Inf, affect) + + push!(sdcs, sdc) + end + + @set! sys.discrete_events = sdcs + sys = complete(sys) # @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + + for (i,x) in enumerate(inputs) + setter = SymbolicIndexingInterface.setsym(sys, x) + sdc = sdcs[i] + + setval = function (integrator, set, val=NaN) + if set + println("::setting $x to $val @ $(integrator.t)s") + setter(integrator, val) + else + println("::saving $x @ $(integrator.t)s") + end + ModelingToolkit.save_callback_discretes!(integrator, sdc) + end + + setters[x] = setval + end + end + + set_input! = function (integrator, var, value) + setters[var](integrator, true, value) + u_modified!(integrator, true) + end + + finalize! = function (integrator) + for ky in keys(setters) + setters[ky](integrator, false) + end + end + + return sys, set_input!, finalize! +end + diff --git a/test/inputs.jl b/test/inputs.jl new file mode 100644 index 0000000000..471ad748c6 --- /dev/null +++ b/test/inputs.jl @@ -0,0 +1,64 @@ +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using Plots +using Test +using StaticArrays + + + +# ----------------------------------------- +# ----- example --------------------------- +# ----------------------------------------- + +vars = @variables begin + x(t)=1, [input=true] + + # states + y(t) = 0 +end + +eqs = [ + # equations + D(y) ~ x + +] + +@mtkcompile sys = System(eqs, t, vars, []) inputs=[x] +sys, set_input!, finalize! = ModelingToolkit.setup_inputs(sys); +prob = ODEProblem(sys, [], (0, 4)) + +# indeterminate form ----------------------- + +integrator = init(prob, Tsit5()) + +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 3.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 4.0) +step!(integrator, 1.0, true) + +finalize!(integrator) + +@test integrator.sol(0.0; idxs=sys.x) == 1.0 +@test integrator.sol(1.0; idxs=sys.x) == 2.0 +@test integrator.sol(2.0; idxs=sys.x) == 3.0 +@test integrator.sol(3.0; idxs=sys.x) == 4.0 +@test integrator.sol(4.0; idxs=sys.y) ≈ 10.0 + +# determinate form ----------------------- +input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) +sol = solve(prob, input, Tsit5(); input_funs = (set_input!, finalize!)); + + +@test sol(0.0; idxs=sys.x) == 1.0 +@test sol(1.0; idxs=sys.x) == 2.0 +@test sol(2.0; idxs=sys.x) == 3.0 +@test sol(3.0; idxs=sys.x) == 4.0 +@test sol(4.0; idxs=sys.y) ≈ 10.0 \ No newline at end of file From 41e262ad4b7538b87cab491bfbc8d127194f4577 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Sun, 14 Sep 2025 19:25:11 -0400 Subject: [PATCH 06/13] unbound_inputs --- src/inputs.jl | 131 +++++++++++++++++++++++++---------------------- test/inputs.jl | 11 ++-- test/runtests.jl | 1 + 3 files changed, 78 insertions(+), 65 deletions(-) diff --git a/src/inputs.jl b/src/inputs.jl index f895acac7d..89d44faa03 100644 --- a/src/inputs.jl +++ b/src/inputs.jl @@ -8,87 +8,94 @@ struct Input time::SVector end -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Union{Input, Vector{Input}}, args...; input_funs, kwargs...) +function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + n = length(data) + return Input(var, SVector{n}(data), SVector{n}(time)) +end - set_input!, finalize! = input_funs +struct InputFunctions + events::Tuple + vars::Tuple + setters::Tuple +end - tstops = Float64[] - callbacks = DiscreteCallback[] - if !isa(inputs, Vector) - inputs = [inputs] - end +InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) + +function set_input!(input_funs::InputFunctions, integrator, var, value::Real) + i = findfirst(isequal(var), input_funs.vars) + setter = input_funs.setters[i] + event = input_funs.events[i] + + setter(integrator, value) + save_callback_discretes!(integrator, event) + u_modified!(integrator, true) + return nothing +end - for input::Input in inputs - tstops = union(tstops, input.time) - - condition = (u,t,integrator) -> any(t .== input.time) - affect! = function (integrator) - i = findfirst(integrator.t .== input.time) - set_input!(integrator, input.var, input.data[i]) - end - callback = DiscreteCallback(condition, affect!) +function finalize!(input_funs::InputFunctions, integrator) - push!(callbacks, callback) + for i in eachindex(input_funs.vars) + save_callback_discretes!(integrator, input_funs.events[i]) end - # finalize! - t_end = prob.tspan[2] - condition = (u,t,integrator) -> (t == t_end) - affect! = (integrator) -> finalize!(integrator) - callback = DiscreteCallback(condition, affect!) + return nothing +end - push!(callbacks, callback) - push!(tstops, t_end) +(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) +(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) - return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) -end +function setup_inputs(sys, inputs = unbound_inputs(sys)) + + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + setters = [] + events = SymbolicDiscreteCallback[] + if !isempty(vars) + + for x in vars + affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = SymbolicDiscreteCallback(Inf, affect) + + push!(events, sdc) + end -function setup_inputs(sys) + @set! sys.discrete_events = events + @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - inputs = ModelingToolkit.unbound_inputs(sys) - setters = Dict{Num, Function}() + setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] + + end - if !isempty(inputs) - sdcs = ModelingToolkit.SymbolicDiscreteCallback[] - for x in inputs - affect = ModelingToolkit.ImperativeAffect((m, o, c, i)->m, modified=(;x)) - sdc = ModelingToolkit.SymbolicDiscreteCallback(Inf, affect) + return sys, InputFunctions(events, vars, setters) +end - push!(sdcs, sdc) - end - @set! sys.discrete_events = sdcs - sys = complete(sys) # @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - for (i,x) in enumerate(inputs) - setter = SymbolicIndexingInterface.setsym(sys, x) - sdc = sdcs[i] - setval = function (integrator, set, val=NaN) - if set - println("::setting $x to $val @ $(integrator.t)s") - setter(integrator, val) - else - println("::saving $x @ $(integrator.t)s") - end - ModelingToolkit.save_callback_discretes!(integrator, sdc) - end +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; input_funs::InputFunctions, kwargs...) - setters[x] = setval - end - end + tstops = Float64[] + callbacks = DiscreteCallback[] - set_input! = function (integrator, var, value) - setters[var](integrator, true, value) - u_modified!(integrator, true) - end + for input::Input in inputs - finalize! = function (integrator) - for ky in keys(setters) - setters[ky](integrator, false) + tstops = union(tstops, input.time) + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + @inbounds begin + i = findfirst(integrator.t .== input.time) + input_funs(integrator, input.var, input.data[i]) + end end + push!(callbacks, DiscreteCallback(condition, affect!)) + end - return sys, set_input!, finalize! -end + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> input_funs(integrator) + push!(callbacks, DiscreteCallback(condition, affect!)) + push!(tstops, t_end) + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end diff --git a/test/inputs.jl b/test/inputs.jl index 471ad748c6..d1ff2d11be 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -25,13 +25,18 @@ eqs = [ ] @mtkcompile sys = System(eqs, t, vars, []) inputs=[x] -sys, set_input!, finalize! = ModelingToolkit.setup_inputs(sys); +ins = ModelingToolkit.unbound_inputs(sys) +# ins_ = [sys.x] +sys, input_funs = ModelingToolkit.setup_inputs(sys, ins); prob = ODEProblem(sys, [], (0, 4)) # indeterminate form ----------------------- integrator = init(prob, Tsit5()) +set_input! = input_funs +finalize! = input_funs + set_input!(integrator, sys.x, 1.0) step!(integrator, 1.0, true) @@ -54,11 +59,11 @@ finalize!(integrator) # determinate form ----------------------- input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) -sol = solve(prob, input, Tsit5(); input_funs = (set_input!, finalize!)); +sol = solve(prob, [input], Tsit5(); input_funs); @test sol(0.0; idxs=sys.x) == 1.0 @test sol(1.0; idxs=sys.x) == 2.0 @test sol(2.0; idxs=sys.x) == 3.0 @test sol(3.0; idxs=sys.x) == 4.0 -@test sol(4.0; idxs=sys.y) ≈ 10.0 \ No newline at end of file +@test sol(4.0; idxs=sys.y) ≈ 10.0 diff --git a/test/runtests.jl b/test/runtests.jl index 522470b896..870ffa6433 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ end @safetestset "Direct Usage Test" include("direct.jl") @safetestset "System Linearity Test" include("linearity.jl") @safetestset "Input Output Test" include("input_output_handling.jl") + @safetestset "Inputs" include("inputs.jl") @safetestset "Clock Test" include("clock.jl") @safetestset "ODESystem Test" include("odesystem.jl") @safetestset "Dynamic Quantities Test" include("dq_units.jl") From c2050d3a57193df3dd8f4d652583f7d95dd9bd5b Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 05:52:40 -0400 Subject: [PATCH 07/13] input_functions property and docs --- docs/src/basics/InputOutput.md | 72 +++++++++++++ src/ModelingToolkit.jl | 3 +- src/inputs.jl | 101 ------------------ src/systems/abstractsystem.jl | 3 +- src/systems/inputs.jl | 188 +++++++++++++++++++++++++++++++++ src/systems/system.jl | 9 +- src/systems/systems.jl | 5 + test/inputs.jl | 14 +-- 8 files changed, 281 insertions(+), 114 deletions(-) delete mode 100644 src/inputs.jl create mode 100644 src/systems/inputs.jl diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index b1eb2905df..33300616f9 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -87,6 +87,75 @@ See [Symbolic Metadata](@ref symbolic_metadata). Metadata specified when creatin See [Linearization](@ref linearization). +## Real-time Input Handling During Simulation + +ModelingToolkit supports setting input values during simulation for variables marked with the `[input=true]` metadata. This is useful for real-time simulations, hardware-in-the-loop testing, interactive simulations, or any scenario where input values need to be determined during integration rather than specified beforehand. + +To use this functionality, variables must be marked as inputs using the `[input=true]` metadata and specified in the `inputs` keyword argument of `@mtkcompile`. + +There are two approaches to handling inputs during simulation: + +### Determinate Form: Using `Input` Objects + +When all input values are known beforehand, you can use the [`Input`](@ref) type to specify input values at specific time points. The solver will automatically apply these values using discrete callbacks. + +```@example inputs +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +# Define system with an input variable +@variables x(t) [input=true] +@variables y(t) = 0 + +eqs = [D(y) ~ x] + +# Compile with inputs specified +@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] + +prob = ODEProblem(sys, [], (0, 4)) + +# Create an Input object with predetermined values +input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) + +# Solve with the input - solver handles callbacks automatically +sol = solve(prob, [input], Tsit5()) + +plot(sol) +``` + +Multiple `Input` objects can be passed in a vector to handle multiple input variables simultaneously. + +### Indeterminate Form: Manual Input Setting with `set_input!` + +When input values need to be computed on-the-fly or depend on external data sources, you can manually set inputs during integration using [`set_input!`](@ref). This approach requires explicit control of the integration loop. + +```@example inputs +# Initialize the integrator +integrator = init(prob, Tsit5()) + +# Manually set inputs and step through time +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 3.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 4.0) +step!(integrator, 1.0, true) + +# IMPORTANT: Must call finalize! to save all input callbacks +finalize!(integrator) + +plot(sol) +``` + +!!! warning "Always call `finalize!`" + When using `set_input!`, you must call [`finalize!`](@ref) after integration is complete. This ensures that all discrete callbacks associated with input variables are properly saved in the solution. Without this call, input values may not be correctly recorded when querying the solution. + ## Docstrings ```@index @@ -96,4 +165,7 @@ Pages = ["InputOutput.md"] ```@docs; canonical=false ModelingToolkit.generate_control_function ModelingToolkit.build_explicit_observed_function +ModelingToolkit.Input +ModelingToolkit.set_input! +ModelingToolkit.finalize! ``` diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 3895228999..d39d657390 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -165,6 +165,7 @@ include("constants.jl") include("utils.jl") +export set_input!, finalize!, Input include("systems/index_cache.jl") include("systems/parameter_buffer.jl") include("systems/abstractsystem.jl") @@ -175,6 +176,7 @@ include("systems/state_machines.jl") include("systems/analysis_points.jl") include("systems/imperative_affect.jl") include("systems/callbacks.jl") +include("systems/inputs.jl") include("systems/system.jl") include("systems/codegen_utils.jl") include("problems/docs.jl") @@ -230,7 +232,6 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") -include("inputs.jl") include("adjoints.jl") include("deprecations.jl") diff --git a/src/inputs.jl b/src/inputs.jl deleted file mode 100644 index 89d44faa03..0000000000 --- a/src/inputs.jl +++ /dev/null @@ -1,101 +0,0 @@ -using SymbolicIndexingInterface -using Setfield -using StaticArrays - -struct Input - var::Num - data::SVector - time::SVector -end - -function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) - n = length(data) - return Input(var, SVector{n}(data), SVector{n}(time)) -end - -struct InputFunctions - events::Tuple - vars::Tuple - setters::Tuple -end - -InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) - -function set_input!(input_funs::InputFunctions, integrator, var, value::Real) - i = findfirst(isequal(var), input_funs.vars) - setter = input_funs.setters[i] - event = input_funs.events[i] - - setter(integrator, value) - save_callback_discretes!(integrator, event) - u_modified!(integrator, true) - return nothing -end - -function finalize!(input_funs::InputFunctions, integrator) - - for i in eachindex(input_funs.vars) - save_callback_discretes!(integrator, input_funs.events[i]) - end - - return nothing -end - -(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) -(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) - -function setup_inputs(sys, inputs = unbound_inputs(sys)) - - vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] - setters = [] - events = SymbolicDiscreteCallback[] - if !isempty(vars) - - for x in vars - affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) - sdc = SymbolicDiscreteCallback(Inf, affect) - - push!(events, sdc) - end - - @set! sys.discrete_events = events - @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - - setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] - - end - - return sys, InputFunctions(events, vars, setters) -end - - - - -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; input_funs::InputFunctions, kwargs...) - - tstops = Float64[] - callbacks = DiscreteCallback[] - - for input::Input in inputs - - tstops = union(tstops, input.time) - condition = (u,t,integrator) -> any(t .== input.time) - affect! = function (integrator) - @inbounds begin - i = findfirst(integrator.t .== input.time) - input_funs(integrator, input.var, input.data[i]) - end - end - push!(callbacks, DiscreteCallback(condition, affect!)) - - end - - # finalize! - t_end = prob.tspan[2] - condition = (u,t,integrator) -> (t == t_end) - affect! = (integrator) -> input_funs(integrator) - push!(callbacks, DiscreteCallback(condition, affect!)) - push!(tstops, t_end) - - return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) -end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 51a495e4ff..1c00079a9f 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -789,7 +789,8 @@ const SYS_PROPS = [:eqs :index_cache :isscheduled :costs - :consolidate] + :consolidate + :input_functions] for prop in SYS_PROPS fname_get = Symbol(:get_, prop) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl new file mode 100644 index 0000000000..34dd90869a --- /dev/null +++ b/src/systems/inputs.jl @@ -0,0 +1,188 @@ +using SymbolicIndexingInterface +using Setfield +using StaticArrays +using OrdinaryDiffEqCore + +""" + Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + +Create an `Input` object that specifies predetermined input values for a variable at specific time points. + +# Arguments +- `var`: The symbolic variable (marked with `[input=true]` metadata) to be used as an input. +- `data`: A vector of real values that the input variable should take at the corresponding time points. +- `time`: A vector of time points at which the input values should be applied. Must be the same length as `data`. + +# Description +The `Input` struct is used with the extended `solve` method to provide time-varying inputs to a system +during simulation. When passed to `solve(prob, [input1, input2, ...], alg)`, the solver will automatically +set the input variable to the specified values at the specified times using discrete callbacks. + +This provides a "determinate form" of input handling where all input values are known a priori, +as opposed to setting inputs manually during integration with [`set_input!`](@ref). + +See also [`set_input!`](@ref), [`finalize!`](@ref) +""" +struct Input + var::Num + data::SVector + time::SVector +end + +function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + n = length(data) + return Input(var, SVector{n}(data), SVector{n}(time)) +end + +struct InputFunctions + events::Tuple{SymbolicDiscreteCallback} + vars::Tuple{SymbolicUtils.BasicSymbolic{Real}} + setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper} +end + +InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) + +""" + set_input!(integrator, var, value::Real) + +Set the value of an input variable during integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `var`: The symbolic input variable to set (must be marked with `[input=true]` metadata and included in the `inputs` keyword of `@mtkcompile`). +- `value`: The new real-valued input to assign to the variable. +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function allows you to manually set input values during integration, providing an "indeterminate form" +of input handling where inputs can be computed on-the-fly. This is useful when input values depend on +runtime conditions, external data sources, or interactive user input. + +After setting input values with `set_input!`, you must call [`finalize!`](@ref) at the end of integration +to ensure all discrete callbacks are properly saved. + +# Example +```julia +@variables x(t) [input=true] +@variables y(t) = 0 + +eqs = [D(y) ~ x] +@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] + +prob = ODEProblem(sys, [], (0, 4)) +integrator = init(prob, Tsit5()) + +# Set input and step forward +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +# Must call finalize! at the end +finalize!(integrator) +``` + +See also [`finalize!`](@ref), [`Input`](@ref) +""" +function set_input!(input_funs::InputFunctions, integrator::OrdinaryDiffEqCore.ODEIntegrator, var, value::Real) + i = findfirst(isequal(var), input_funs.vars) + setter = input_funs.setters[i] + event = input_funs.events[i] + + setter(integrator, value) + save_callback_discretes!(integrator, event) + u_modified!(integrator, true) + return nothing +end +set_input!(integrator, var, value::Real) = set_input!(get_input_functions(integrator.f.sys), integrator, var, value) + +""" + finalize!(integrator) + +Finalize all input callbacks at the end of integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function must be called after using [`set_input!`](@ref) to manually set input values during integration. +It ensures that all discrete callbacks associated with input variables are properly saved in the solution, +making the input values accessible when querying the solution at specific time points. + +Without calling `finalize!`, input values set with `set_input!` may not be correctly recorded in the +final solution object, leading to incorrect results when indexing the solution. + +See also [`set_input!`](@ref), [`Input`](@ref) +""" +function finalize!(input_funs::InputFunctions, integrator) + + for i in eachindex(input_funs.vars) + save_callback_discretes!(integrator, input_funs.events[i]) + end + + return nothing +end +finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integrator) + +(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) +(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) + +function build_input_functions(sys, inputs) + + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + setters = [] + events = SymbolicDiscreteCallback[] + if !isempty(vars) + + for x in vars + affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) + sdc = SymbolicDiscreteCallback(Inf, affect) + + push!(events, sdc) + end + + @set! sys.discrete_events = events + @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + + setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] + + end + + @set! sys.input_functions = InputFunctions(events, vars, setters) + + return sys +end + + + + +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) + + tstops = Float64[] + callbacks = DiscreteCallback[] + + for input::Input in inputs + + tstops = union(tstops, input.time) + condition = (u,t,integrator) -> any(t .== input.time) + affect! = function (integrator) + @inbounds begin + i = findfirst(integrator.t .== input.time) + set_input!(integrator, input.var, input.data[i]) + end + end + push!(callbacks, DiscreteCallback(condition, affect!)) + + end + + # finalize! + t_end = prob.tspan[2] + condition = (u,t,integrator) -> (t == t_end) + affect! = (integrator) -> finalize!(integrator) + push!(callbacks, DiscreteCallback(condition, affect!)) + push!(tstops, t_end) + + return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) +end diff --git a/src/systems/system.jl b/src/systems/system.jl index 12f5d1d250..87f81403cd 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -259,6 +259,11 @@ struct System <: IntermediateDeprecationSystem The `Schedule` containing additional information about the simplified system. """ schedule::Union{Schedule, Nothing} + """ + $INTERNAL_FIELD_WARNING + Functions used to set input variables with `set_input!` and `finalize!` functions + """ + input_functions::Union{InputFunctions, Nothing} function System( tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, @@ -271,7 +276,7 @@ struct System <: IntermediateDeprecationSystem complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, is_initializesystem = false, is_discrete = false, isscheduled = false, - schedule = nothing; checks::Union{Bool, Int} = true) + schedule = nothing, input_functions = nothing; checks::Union{Bool, Int} = true) if is_initializesystem && iv !== nothing throw(ArgumentError(""" Expected initialization system to be time-independent. Found independent @@ -310,7 +315,7 @@ struct System <: IntermediateDeprecationSystem tstops, inputs, outputs, tearing_state, namespacing, complete, index_cache, ignored_connections, preface, parent, initializesystem, is_initializesystem, is_discrete, - isscheduled, schedule) + isscheduled, schedule, input_functions) end end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 49efb3e1d4..a5b0325129 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -51,6 +51,11 @@ function mtkcompile( @set! newsys.parent = complete(sys; split = false, flatten = false) end newsys = complete(newsys; split) + + if !isempty(inputs) + newsys = build_input_functions(newsys, inputs) + end + if newsys′ isa Tuple idxs = [parameter_index(newsys, i) for i in io[1]] return newsys, idxs diff --git a/test/inputs.jl b/test/inputs.jl index d1ff2d11be..0644b77050 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -3,8 +3,6 @@ using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq using Plots using Test -using StaticArrays - # ----------------------------------------- @@ -25,17 +23,15 @@ eqs = [ ] @mtkcompile sys = System(eqs, t, vars, []) inputs=[x] -ins = ModelingToolkit.unbound_inputs(sys) -# ins_ = [sys.x] -sys, input_funs = ModelingToolkit.setup_inputs(sys, ins); + +@test !isnothing(ModelingToolkit.get_input_functions(sys)) + prob = ODEProblem(sys, [], (0, 4)) # indeterminate form ----------------------- integrator = init(prob, Tsit5()) -set_input! = input_funs -finalize! = input_funs set_input!(integrator, sys.x, 1.0) step!(integrator, 1.0, true) @@ -58,8 +54,8 @@ finalize!(integrator) @test integrator.sol(4.0; idxs=sys.y) ≈ 10.0 # determinate form ----------------------- -input = ModelingToolkit.Input(sys.x, SA[1,2,3,4], SA[0,1,2,3]) -sol = solve(prob, [input], Tsit5(); input_funs); +input = Input(sys.x, [1,2,3,4], [0,1,2,3]) +sol = solve(prob, [input], Tsit5()); @test sol(0.0; idxs=sys.x) == 1.0 From 85e13e58f59b7901620975b83e36545540cf4c6a Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 17:31:24 -0400 Subject: [PATCH 08/13] full working version with docs --- docs/src/basics/InputOutput.md | 5 +++-- src/systems/inputs.jl | 25 +++++++++++++++++++++---- test/inputs.jl | 3 +-- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index 33300616f9..dc05672337 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -103,6 +103,7 @@ When all input values are known beforehand, you can use the [`Input`](@ref) type using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq +using Plots # Define system with an input variable @variables x(t) [input=true] @@ -121,7 +122,7 @@ input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) # Solve with the input - solver handles callbacks automatically sol = solve(prob, [input], Tsit5()) -plot(sol) +plot(sol; idxs=[x,y]) ``` Multiple `Input` objects can be passed in a vector to handle multiple input variables simultaneously. @@ -150,7 +151,7 @@ step!(integrator, 1.0, true) # IMPORTANT: Must call finalize! to save all input callbacks finalize!(integrator) -plot(sol) +plot(sol; idxs=[x,y]) ``` !!! warning "Always call `finalize!`" diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 34dd90869a..644a20b13f 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -131,9 +131,12 @@ finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integra function build_input_functions(sys, inputs) - vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] + # Here we ensure the inputs have metadata marking the discrete variables as parameters. In some + # cases the inputs can be fed to this function before they are converted to parameters by mtkcompile. + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] setters = [] events = SymbolicDiscreteCallback[] + defaults = get_defaults(sys) if !isempty(vars) for x in vars @@ -141,16 +144,24 @@ function build_input_functions(sys, inputs) sdc = SymbolicDiscreteCallback(Inf, affect) push!(events, sdc) + + # ensure that the ODEProblem does not complain about missing parameter map + if !haskey(defaults, x) + push!(defaults, x => 0.0) + end + end @set! sys.discrete_events = events @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + @set! sys.defaults = defaults setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] - - end + + @set! sys.input_functions = InputFunctions(events, vars, setters) - @set! sys.input_functions = InputFunctions(events, vars, setters) + end + return sys end @@ -163,6 +174,7 @@ function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Inpu tstops = Float64[] callbacks = DiscreteCallback[] + # set_input! for input::Input in inputs tstops = union(tstops, input.time) @@ -174,6 +186,11 @@ function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Inpu end end push!(callbacks, DiscreteCallback(condition, affect!)) + + # DiscreteCallback doesn't hit on t==0, workaround... + if input.time[1] == 0 + prob.ps[input.var] = input.data[1] + end end diff --git a/test/inputs.jl b/test/inputs.jl index 0644b77050..b9540a2d8d 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -1,7 +1,6 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq -using Plots using Test @@ -10,7 +9,7 @@ using Test # ----------------------------------------- vars = @variables begin - x(t)=1, [input=true] + x(t), [input=true] # states y(t) = 0 From ac9d23227bb6000c46be805487f2240b39511bc3 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 17:35:58 -0400 Subject: [PATCH 09/13] format --- src/systems/inputs.jl | 406 +++++++++++++++++++++--------------------- test/inputs.jl | 124 +++++++------ test/runtests.jl | 2 +- 3 files changed, 262 insertions(+), 270 deletions(-) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 644a20b13f..799d09f9b4 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -1,205 +1,201 @@ -using SymbolicIndexingInterface -using Setfield -using StaticArrays -using OrdinaryDiffEqCore - -""" - Input(var, data::Vector{<:Real}, time::Vector{<:Real}) - -Create an `Input` object that specifies predetermined input values for a variable at specific time points. - -# Arguments -- `var`: The symbolic variable (marked with `[input=true]` metadata) to be used as an input. -- `data`: A vector of real values that the input variable should take at the corresponding time points. -- `time`: A vector of time points at which the input values should be applied. Must be the same length as `data`. - -# Description -The `Input` struct is used with the extended `solve` method to provide time-varying inputs to a system -during simulation. When passed to `solve(prob, [input1, input2, ...], alg)`, the solver will automatically -set the input variable to the specified values at the specified times using discrete callbacks. - -This provides a "determinate form" of input handling where all input values are known a priori, -as opposed to setting inputs manually during integration with [`set_input!`](@ref). - -See also [`set_input!`](@ref), [`finalize!`](@ref) -""" -struct Input - var::Num - data::SVector - time::SVector -end - -function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) - n = length(data) - return Input(var, SVector{n}(data), SVector{n}(time)) -end - -struct InputFunctions - events::Tuple{SymbolicDiscreteCallback} - vars::Tuple{SymbolicUtils.BasicSymbolic{Real}} - setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper} -end - -InputFunctions(events::Vector, vars::Vector, setters::Vector) = InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) - -""" - set_input!(integrator, var, value::Real) - -Set the value of an input variable during integration. - -# Arguments -- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). -- `var`: The symbolic input variable to set (must be marked with `[input=true]` metadata and included in the `inputs` keyword of `@mtkcompile`). -- `value`: The new real-valued input to assign to the variable. -- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. - -# Description -This function allows you to manually set input values during integration, providing an "indeterminate form" -of input handling where inputs can be computed on-the-fly. This is useful when input values depend on -runtime conditions, external data sources, or interactive user input. - -After setting input values with `set_input!`, you must call [`finalize!`](@ref) at the end of integration -to ensure all discrete callbacks are properly saved. - -# Example -```julia -@variables x(t) [input=true] -@variables y(t) = 0 - -eqs = [D(y) ~ x] -@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] - -prob = ODEProblem(sys, [], (0, 4)) -integrator = init(prob, Tsit5()) - -# Set input and step forward -set_input!(integrator, sys.x, 1.0) -step!(integrator, 1.0, true) - -set_input!(integrator, sys.x, 2.0) -step!(integrator, 1.0, true) - -# Must call finalize! at the end -finalize!(integrator) -``` - -See also [`finalize!`](@ref), [`Input`](@ref) -""" -function set_input!(input_funs::InputFunctions, integrator::OrdinaryDiffEqCore.ODEIntegrator, var, value::Real) - i = findfirst(isequal(var), input_funs.vars) - setter = input_funs.setters[i] - event = input_funs.events[i] - - setter(integrator, value) - save_callback_discretes!(integrator, event) - u_modified!(integrator, true) - return nothing -end -set_input!(integrator, var, value::Real) = set_input!(get_input_functions(integrator.f.sys), integrator, var, value) - -""" - finalize!(integrator) - -Finalize all input callbacks at the end of integration. - -# Arguments -- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). -- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. - -# Description -This function must be called after using [`set_input!`](@ref) to manually set input values during integration. -It ensures that all discrete callbacks associated with input variables are properly saved in the solution, -making the input values accessible when querying the solution at specific time points. - -Without calling `finalize!`, input values set with `set_input!` may not be correctly recorded in the -final solution object, leading to incorrect results when indexing the solution. - -See also [`set_input!`](@ref), [`Input`](@ref) -""" -function finalize!(input_funs::InputFunctions, integrator) - - for i in eachindex(input_funs.vars) - save_callback_discretes!(integrator, input_funs.events[i]) - end - - return nothing -end -finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integrator) - -(input_funs::InputFunctions)(integrator, var, value::Real) = set_input!(input_funs, integrator, var, value) -(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) - -function build_input_functions(sys, inputs) - - # Here we ensure the inputs have metadata marking the discrete variables as parameters. In some - # cases the inputs can be fed to this function before they are converted to parameters by mtkcompile. - vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) for x in unwrap.(inputs)] - setters = [] - events = SymbolicDiscreteCallback[] - defaults = get_defaults(sys) - if !isempty(vars) - - for x in vars - affect = ImperativeAffect((m, o, c, i)->m, modified=(;x)) - sdc = SymbolicDiscreteCallback(Inf, affect) - - push!(events, sdc) - - # ensure that the ODEProblem does not complain about missing parameter map - if !haskey(defaults, x) - push!(defaults, x => 0.0) - end - - end - - @set! sys.discrete_events = events - @set! sys.index_cache = ModelingToolkit.IndexCache(sys) - @set! sys.defaults = defaults - - setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] - - @set! sys.input_functions = InputFunctions(events, vars, setters) - - end - - - return sys -end - - - - -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) - - tstops = Float64[] - callbacks = DiscreteCallback[] - - # set_input! - for input::Input in inputs - - tstops = union(tstops, input.time) - condition = (u,t,integrator) -> any(t .== input.time) - affect! = function (integrator) - @inbounds begin - i = findfirst(integrator.t .== input.time) - set_input!(integrator, input.var, input.data[i]) - end - end - push!(callbacks, DiscreteCallback(condition, affect!)) - - # DiscreteCallback doesn't hit on t==0, workaround... - if input.time[1] == 0 - prob.ps[input.var] = input.data[1] - end - - end - - # finalize! - t_end = prob.tspan[2] - condition = (u,t,integrator) -> (t == t_end) - affect! = (integrator) -> finalize!(integrator) - push!(callbacks, DiscreteCallback(condition, affect!)) - push!(tstops, t_end) - - return solve(prob, args...; tstops, callback=CallbackSet(callbacks...), kwargs...) -end +using SymbolicIndexingInterface +using Setfield +using StaticArrays +using OrdinaryDiffEqCore + +""" + Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + +Create an `Input` object that specifies predetermined input values for a variable at specific time points. + +# Arguments +- `var`: The symbolic variable (marked with `[input=true]` metadata) to be used as an input. +- `data`: A vector of real values that the input variable should take at the corresponding time points. +- `time`: A vector of time points at which the input values should be applied. Must be the same length as `data`. + +# Description +The `Input` struct is used with the extended `solve` method to provide time-varying inputs to a system +during simulation. When passed to `solve(prob, [input1, input2, ...], alg)`, the solver will automatically +set the input variable to the specified values at the specified times using discrete callbacks. + +This provides a "determinate form" of input handling where all input values are known a priori, +as opposed to setting inputs manually during integration with [`set_input!`](@ref). + +See also [`set_input!`](@ref), [`finalize!`](@ref) +""" +struct Input + var::Num + data::SVector + time::SVector +end + +function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) + n = length(data) + return Input(var, SVector{n}(data), SVector{n}(time)) +end + +struct InputFunctions + events::Tuple{SymbolicDiscreteCallback} + vars::Tuple{SymbolicUtils.BasicSymbolic{Real}} + setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper} +end + +function InputFunctions(events::Vector, vars::Vector, setters::Vector) + InputFunctions(Tuple(events), Tuple(vars), Tuple(setters)) +end + +""" + set_input!(integrator, var, value::Real) + +Set the value of an input variable during integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `var`: The symbolic input variable to set (must be marked with `[input=true]` metadata and included in the `inputs` keyword of `@mtkcompile`). +- `value`: The new real-valued input to assign to the variable. +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function allows you to manually set input values during integration, providing an "indeterminate form" +of input handling where inputs can be computed on-the-fly. This is useful when input values depend on +runtime conditions, external data sources, or interactive user input. + +After setting input values with `set_input!`, you must call [`finalize!`](@ref) at the end of integration +to ensure all discrete callbacks are properly saved. + +# Example +```julia +@variables x(t) [input=true] +@variables y(t) = 0 + +eqs = [D(y) ~ x] +@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] + +prob = ODEProblem(sys, [], (0, 4)) +integrator = init(prob, Tsit5()) + +# Set input and step forward +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +# Must call finalize! at the end +finalize!(integrator) +``` + +See also [`finalize!`](@ref), [`Input`](@ref) +""" +function set_input!(input_funs::InputFunctions, integrator::OrdinaryDiffEqCore.ODEIntegrator, var, value::Real) + i = findfirst(isequal(var), input_funs.vars) + setter = input_funs.setters[i] + event = input_funs.events[i] + + setter(integrator, value) + save_callback_discretes!(integrator, event) + u_modified!(integrator, true) + return nothing +end +function set_input!(integrator, var, value::Real) + set_input!(get_input_functions(integrator.f.sys), integrator, var, value) +end + +""" + finalize!(integrator) + +Finalize all input callbacks at the end of integration. + +# Arguments +- `integrator`: An ODE integrator object (from `init(prob, alg)` or available in callbacks). +- `input_funs` (optional): The `InputFunctions` object associated with the system. If not provided, it will be retrieved from `integrator.f.sys`. + +# Description +This function must be called after using [`set_input!`](@ref) to manually set input values during integration. +It ensures that all discrete callbacks associated with input variables are properly saved in the solution, +making the input values accessible when querying the solution at specific time points. + +Without calling `finalize!`, input values set with `set_input!` may not be correctly recorded in the +final solution object, leading to incorrect results when indexing the solution. + +See also [`set_input!`](@ref), [`Input`](@ref) +""" +function finalize!(input_funs::InputFunctions, integrator) + for i in eachindex(input_funs.vars) + save_callback_discretes!(integrator, input_funs.events[i]) + end + + return nothing +end +finalize!(integrator) = finalize!(get_input_functions(integrator.f.sys), integrator) + +function (input_funs::InputFunctions)(integrator, var, value::Real) + set_input!(input_funs, integrator, var, value) +end +(input_funs::InputFunctions)(integrator) = finalize!(input_funs, integrator) + +function build_input_functions(sys, inputs) + + # Here we ensure the inputs have metadata marking the discrete variables as parameters. In some + # cases the inputs can be fed to this function before they are converted to parameters by mtkcompile. + vars = SymbolicUtils.BasicSymbolic[isparameter(x) ? x : toparam(x) + for x in unwrap.(inputs)] + setters = [] + events = SymbolicDiscreteCallback[] + defaults = get_defaults(sys) + if !isempty(vars) + for x in vars + affect = ImperativeAffect((m, o, c, i)->m, modified = (; x)) + sdc = SymbolicDiscreteCallback(Inf, affect) + + push!(events, sdc) + + # ensure that the ODEProblem does not complain about missing parameter map + if !haskey(defaults, x) + push!(defaults, x => 0.0) + end + end + + @set! sys.discrete_events = events + @set! sys.index_cache = ModelingToolkit.IndexCache(sys) + @set! sys.defaults = defaults + + setters = [SymbolicIndexingInterface.setsym(sys, x) for x in vars] + + @set! sys.input_functions = InputFunctions(events, vars, setters) + end + + return sys +end + +function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) + tstops = Float64[] + callbacks = DiscreteCallback[] + + # set_input! + for input::Input in inputs + tstops = union(tstops, input.time) + condition = (u, t, integrator) -> any(t .== input.time) + affect! = function (integrator) + @inbounds begin + i = findfirst(integrator.t .== input.time) + set_input!(integrator, input.var, input.data[i]) + end + end + push!(callbacks, DiscreteCallback(condition, affect!)) + + # DiscreteCallback doesn't hit on t==0, workaround... + if input.time[1] == 0 + prob.ps[input.var] = input.data[1] + end + end + + # finalize! + t_end = prob.tspan[2] + condition = (u, t, integrator) -> (t == t_end) + affect! = (integrator) -> finalize!(integrator) + push!(callbacks, DiscreteCallback(condition, affect!)) + push!(tstops, t_end) + + return solve(prob, args...; tstops, callback = CallbackSet(callbacks...), kwargs...) +end diff --git a/test/inputs.jl b/test/inputs.jl index b9540a2d8d..0f82693ff4 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -1,64 +1,60 @@ -using ModelingToolkit -using ModelingToolkit: t_nounits as t, D_nounits as D -using OrdinaryDiffEq -using Test - - -# ----------------------------------------- -# ----- example --------------------------- -# ----------------------------------------- - -vars = @variables begin - x(t), [input=true] - - # states - y(t) = 0 -end - -eqs = [ - # equations - D(y) ~ x - -] - -@mtkcompile sys = System(eqs, t, vars, []) inputs=[x] - -@test !isnothing(ModelingToolkit.get_input_functions(sys)) - -prob = ODEProblem(sys, [], (0, 4)) - -# indeterminate form ----------------------- - -integrator = init(prob, Tsit5()) - - -set_input!(integrator, sys.x, 1.0) -step!(integrator, 1.0, true) - -set_input!(integrator, sys.x, 2.0) -step!(integrator, 1.0, true) - -set_input!(integrator, sys.x, 3.0) -step!(integrator, 1.0, true) - -set_input!(integrator, sys.x, 4.0) -step!(integrator, 1.0, true) - -finalize!(integrator) - -@test integrator.sol(0.0; idxs=sys.x) == 1.0 -@test integrator.sol(1.0; idxs=sys.x) == 2.0 -@test integrator.sol(2.0; idxs=sys.x) == 3.0 -@test integrator.sol(3.0; idxs=sys.x) == 4.0 -@test integrator.sol(4.0; idxs=sys.y) ≈ 10.0 - -# determinate form ----------------------- -input = Input(sys.x, [1,2,3,4], [0,1,2,3]) -sol = solve(prob, [input], Tsit5()); - - -@test sol(0.0; idxs=sys.x) == 1.0 -@test sol(1.0; idxs=sys.x) == 2.0 -@test sol(2.0; idxs=sys.x) == 3.0 -@test sol(3.0; idxs=sys.x) == 4.0 -@test sol(4.0; idxs=sys.y) ≈ 10.0 +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using Test + +# ----------------------------------------- +# ----- example --------------------------- +# ----------------------------------------- + +vars = @variables begin + x(t), [input=true] + + # states + y(t) = 0 +end + +eqs = [ +# equations + D(y) ~ x +] + +@mtkcompile sys=System(eqs, t, vars, []) inputs=[x] + +@test !isnothing(ModelingToolkit.get_input_functions(sys)) + +prob = ODEProblem(sys, [], (0, 4)) + +# indeterminate form ----------------------- + +integrator = init(prob, Tsit5()) + +set_input!(integrator, sys.x, 1.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 2.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 3.0) +step!(integrator, 1.0, true) + +set_input!(integrator, sys.x, 4.0) +step!(integrator, 1.0, true) + +finalize!(integrator) + +@test integrator.sol(0.0; idxs = sys.x) == 1.0 +@test integrator.sol(1.0; idxs = sys.x) == 2.0 +@test integrator.sol(2.0; idxs = sys.x) == 3.0 +@test integrator.sol(3.0; idxs = sys.x) == 4.0 +@test integrator.sol(4.0; idxs = sys.y) ≈ 10.0 + +# determinate form ----------------------- +input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) +sol = solve(prob, [input], Tsit5()); + +@test sol(0.0; idxs = sys.x) == 1.0 +@test sol(1.0; idxs = sys.x) == 2.0 +@test sol(2.0; idxs = sys.x) == 3.0 +@test sol(3.0; idxs = sys.x) == 4.0 +@test sol(4.0; idxs = sys.y) ≈ 10.0 diff --git a/test/runtests.jl b/test/runtests.jl index 870ffa6433..344ee2f595 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -104,7 +104,7 @@ end @safetestset "Fractional Differential Equations Tests" include("fractional_to_ordinary.jl") end end - + if GROUP == "All" || GROUP == "SymbolicIndexingInterface" @safetestset "SymbolicIndexingInterface test" include("symbolic_indexing_interface.jl") @safetestset "SciML Problem Input Test" include("sciml_problem_inputs.jl") From 9f04501710cc5de72fd77032945cd5ea47dc2353 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Tue, 28 Oct 2025 19:22:06 -0400 Subject: [PATCH 10/13] format docs --- docs/src/basics/InputOutput.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/basics/InputOutput.md b/docs/src/basics/InputOutput.md index dc05672337..4568e10b0e 100644 --- a/docs/src/basics/InputOutput.md +++ b/docs/src/basics/InputOutput.md @@ -112,7 +112,7 @@ using Plots eqs = [D(y) ~ x] # Compile with inputs specified -@mtkcompile sys = System(eqs, t, [x, y], []) inputs=[x] +@mtkcompile sys=System(eqs, t, [x, y], []) inputs=[x] prob = ODEProblem(sys, [], (0, 4)) @@ -122,7 +122,7 @@ input = Input(sys.x, [1, 2, 3, 4], [0, 1, 2, 3]) # Solve with the input - solver handles callbacks automatically sol = solve(prob, [input], Tsit5()) -plot(sol; idxs=[x,y]) +plot(sol; idxs = [x, y]) ``` Multiple `Input` objects can be passed in a vector to handle multiple input variables simultaneously. @@ -151,10 +151,11 @@ step!(integrator, 1.0, true) # IMPORTANT: Must call finalize! to save all input callbacks finalize!(integrator) -plot(sol; idxs=[x,y]) +plot(sol; idxs = [x, y]) ``` !!! warning "Always call `finalize!`" + When using `set_input!`, you must call [`finalize!`](@ref) after integration is complete. This ensures that all discrete callbacks associated with input variables are properly saved in the solution. Without this call, input values may not be correctly recorded when querying the solution. ## Docstrings From 9301b01edea4edda4687b50e501efcc5404cc464 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Wed, 29 Oct 2025 07:24:36 -0400 Subject: [PATCH 11/13] fixes concrete type --- src/systems/inputs.jl | 4 ++-- test/inputs.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 799d09f9b4..3b71b37d55 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -34,10 +34,10 @@ function Input(var, data::Vector{<:Real}, time::Vector{<:Real}) return Input(var, SVector{n}(data), SVector{n}(time)) end -struct InputFunctions +struct InputFunctions{S, O} events::Tuple{SymbolicDiscreteCallback} vars::Tuple{SymbolicUtils.BasicSymbolic{Real}} - setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper} + setters::Tuple{SymbolicIndexingInterface.ParameterHookWrapper{S, O}} end function InputFunctions(events::Vector, vars::Vector, setters::Vector) diff --git a/test/inputs.jl b/test/inputs.jl index 0f82693ff4..eb9314edd2 100644 --- a/test/inputs.jl +++ b/test/inputs.jl @@ -15,7 +15,6 @@ vars = @variables begin end eqs = [ -# equations D(y) ~ x ] From 3c003c9ed9fdc15f5179258f95c19148506435b7 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Wed, 29 Oct 2025 07:25:56 -0400 Subject: [PATCH 12/13] removed usings statements --- src/systems/inputs.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 3b71b37d55..048eb3db44 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -1,8 +1,3 @@ -using SymbolicIndexingInterface -using Setfield -using StaticArrays -using OrdinaryDiffEqCore - """ Input(var, data::Vector{<:Real}, time::Vector{<:Real}) From 0c2dd029ec18c24dcabf9a229cbe4fb373d499d4 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Wed, 29 Oct 2025 07:37:05 -0400 Subject: [PATCH 13/13] Fixed solve contract --- src/systems/inputs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/inputs.jl b/src/systems/inputs.jl index 048eb3db44..413ab5dbf7 100644 --- a/src/systems/inputs.jl +++ b/src/systems/inputs.jl @@ -163,7 +163,7 @@ function build_input_functions(sys, inputs) return sys end -function DiffEqBase.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) +function CommonSolve.solve(prob::SciMLBase.AbstractDEProblem, inputs::Vector{Input}, args...; kwargs...) tstops = Float64[] callbacks = DiscreteCallback[]