Skip to content
1 change: 1 addition & 0 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
D(eq.lhs)
end
rhs = ModelingToolkit.expand_derivatives(D(eq.rhs))
rhs = fast_substitute(rhs, state.param_derivative_map)
substitution_dict = Dict(x.lhs => x.rhs
for x in out_eqs if x !== nothing && x.lhs isa Symbolic)
sub_rhs = substitute(rhs, substitution_dict)
Expand Down
18 changes: 17 additions & 1 deletion src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,23 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)

sys = ts.sys
eq = equations(ts)[ieq]
eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true)
eq = 0 ~ fast_substitute(
ModelingToolkit.derivative(
eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map)

vs = ModelingToolkit.vars(eq.rhs)
for v in vs
# parameters with unknown derivatives have a value of `nothing` in the map,
# so use `missing` as the default.
get(ts.param_derivative_map, v, missing) === nothing || continue
_original_eq = equations(ts)[ieq]
error("""
Encountered derivative of discrete variable `$(only(arguments(v)))` when \
differentiating equation `$(_original_eq)`. This may indicate a model error or a \
missing equation of the form `$v ~ ...` that defines this derivative.
""")
end

push!(equations(ts), eq)
# Analyze the new equation and update the graph/solvable_graph
# First, copy the previous incidence and add the derivative terms.
Expand Down
31 changes: 30 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
fullvars::Vector
structure::SystemStructure
extra_eqs::Vector
param_derivative_map::Dict{BasicSymbolic, Any}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
Expand Down Expand Up @@ -253,6 +254,12 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

function is_time_dependent_parameter(p, iv)
return iv !== nothing && isparameter(p) && iscall(p) &&
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
end

function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
Expand All @@ -264,6 +271,7 @@ function TearingState(sys; quick_cancel = false, check = true)
var2idx = Dict{Any, Int}()
symbolic_incidence = []
fullvars = []
param_derivative_map = Dict{BasicSymbolic, Any}()
var_counter = Ref(0)
var_types = VariableType[]
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
Expand All @@ -276,11 +284,23 @@ function TearingState(sys; quick_cancel = false, check = true)

vars = OrderedSet()
varsvec = []
eqs_to_retain = trues(length(eqs))
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
# we want to store `nothing` in the map because that means `fast_substitute`
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
# the differentiated expression and error.
param_derivative_map[eq′.lhs] = coalesce(eq′.rhs, nothing)
eqs_to_retain[i] = false
# change the equation if the RHS is `missing` so the rest of this loop works
eq′ = eq′.lhs ~ coalesce(eq′.rhs, 0.0)
end
if _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
Expand All @@ -295,6 +315,12 @@ function TearingState(sys; quick_cancel = false, check = true)
any(isequal(_var), ivs) && continue
if isparameter(_var) ||
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
if is_time_dependent_parameter(_var, iv) &&
!haskey(param_derivative_map, Differential(iv)(_var))
# Parameter derivatives default to zero - they stay constant
# between callbacks
param_derivative_map[Differential(iv)(_var)] = 0.0
end
continue
end
v = scalarize(v)
Expand Down Expand Up @@ -351,6 +377,9 @@ function TearingState(sys; quick_cancel = false, check = true)
eqs[i] = eqs[i].lhs ~ rhs
end
end
eqs = eqs[eqs_to_retain]
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

### Handle discrete variables
lowest_shift = Dict()
Expand Down Expand Up @@ -438,7 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true)
ts = TearingState(sys, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
Any[])
Any[], param_derivative_map)
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down
2 changes: 1 addition & 1 deletion test/state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ModelingToolkit, OrdinaryDiffEq, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

sts = @variables x1(t) x2(t) x3(t) x4(t)
params = @parameters u1(t) u2(t) u3(t) u4(t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hehe

params = @parameters u1 u2 u3 u4
eqs = [x1 + x2 + u1 ~ 0
x1 + x2 + x3 + u2 ~ 0
x1 + D(x3) + x4 + u3 ~ 0
Expand Down
109 changes: 109 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SparseArrays
using UnPack
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
using Symbolics: unwrap
using DataInterpolations
const ST = StructuralTransformations

# Define some variables
Expand Down Expand Up @@ -282,3 +283,111 @@ end
@test length(mapping) == 3
end
end

@testset "Issue#3480: Derivatives of time-dependent parameters" begin
@component function FilteredInput(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@component function FilteredInputExplicit(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t)[1:1] = [x0]
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
D(k[1]) ~ 1.0
dx ~ (k[1] - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@component function FilteredInputErr(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T
D(k) ~ missing]
return ODESystem(eqs, t, vars, params; systems, name)
end

@named sys = FilteredInputErr()
@test_throws ["derivative of discrete variable", "k(t)"] structural_simplify(sys)

@mtkbuild sys = FilteredInput()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test !(D(sys.k) in vs)

@mtkbuild sys = FilteredInputExplicit()
obsfn1 = ModelingToolkit.build_explicit_observed_function(sys, sys.ddx)
obsfn2 = ModelingToolkit.build_explicit_observed_function(sys, sys.dx)
u = [1.0]
p = MTKParameters(sys, [sys.k => [2.0], sys.T => 3.0])
@test obsfn1(u, p, 0.0) ≈ (1 - obsfn2(u, p, 0.0)) / 3.0

@testset "Called parameter still has derivative" begin
@component function FilteredInput2(; name, x0 = 0, T = 0.1)
ts = collect(0.0:0.1:10.0)
spline = LinearInterpolation(ts .^ 2, ts)
params = @parameters begin
(k::LinearInterpolation)(..) = spline
T = T
end
vars = @variables begin
x(t) = k(t)
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k(t) - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@mtkbuild sys = FilteredInput2()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test D(sys.k(t)) in vs
end
end
Loading