Skip to content

Commit

Permalink
Merge pull request #377 from SciML/ordering_update
Browse files Browse the repository at this point in the history
Updating the ordering used in MQS computation
  • Loading branch information
pogudingleb authored Jan 15, 2025
2 parents 85c2f52 + 619dc71 commit 2eeff8f
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 69 deletions.
6 changes: 3 additions & 3 deletions docs/src/tutorials/reparametrization.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ reparam[:new_ode]

In order to analyze this result, let us give more interpretable names to the new variables and parameters:

$I := I, \; \widetilde{E} := \alpha E, \widetilde{S} := \alpha, \; \widetilde{I} := \alpha (I + U), \; \gamma := \gamma,\;\delta := \delta,\;\widetilde{\beta} := \frac{\beta}{\alpha N}$
$I := I, \; \widetilde{E} := \alpha E, \widetilde{S} := \alpha, \; \widetilde{I} := \alpha (I + U), \; \gamma := \gamma,\;\delta := \delta,\;\widetilde{N} := \frac{\alpha N}{\beta}$

Then the reparametrize system becomes

$\begin{cases}
\widetilde{S}'(t) = -\widetilde{\beta} \widetilde{S}(t) \widetilde{I}(t),\\
\widetilde{E}'(t) = \widetilde{\beta} \widetilde{S}(t) \widetilde{I}(t) - \gamma \widetilde{E}(t),\\
\widetilde{S}'(t) = -\widetilde{S}(t) \widetilde{I}(t) / \widetilde{N},\\
\widetilde{E}'(t) = \widetilde{S}(t) \widetilde{I}(t) / \widetilde{N} - \gamma \widetilde{E}(t),\\
\widetilde{I}'(t) = -\delta \widetilde{I}(t) + \gamma\widetilde{E}(t),\\
I'(t) = \gamma\widetilde{E}(t) - \delta I(t),\\
y(t) = I(t)
Expand Down
16 changes: 12 additions & 4 deletions src/RationalFunctionFields/RationalFunctionField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ Applies the following passes:
rff::RationalFunctionField;
discard_redundant = true,
reversed_order = false,
priority_variables = [],
)
time_start = time_ns()
fracs = dennums_to_fractions(rff.dennums)
Expand All @@ -290,7 +291,11 @@ Applies the following passes:
end
# Remove redundant pass
if discard_redundant
sort!(fracs, lt = rational_function_cmp)
fracs_priority = filter(f -> issubset(vars(f), priority_variables), fracs)
fracs_rest = filter(f -> !(f in fracs_priority), fracs)
sort!(fracs_priority, lt = rational_function_cmp)
sort!(fracs_rest, lt = rational_function_cmp)
fracs = vcat(fracs_priority, fracs_rest)
@debug "The pool of fractions:\n$(join(map(repr, fracs), ",\n"))"
if reversed_order
non_redundant = collect(1:length(fracs))
Expand Down Expand Up @@ -323,7 +328,7 @@ Applies the following passes:
@debug "Out of $(length(fracs)) simplified generators there are $(length(non_redundant)) non redundant"
fracs = fracs[non_redundant]
end
sort!(fracs, lt = rational_function_cmp)
sort!(fracs, lt = (f, g) -> rational_function_cmp(f, g))
spring_cleaning_pass!(fracs)
_runtime_logger[:id_beautifulization] += (time_ns() - time_start) / 1e9
return fracs
Expand Down Expand Up @@ -552,6 +557,7 @@ Result is correct (in the Monte-Carlo sense) with probability at least `prob_thr
simplify = :standard,
check_variables = false, # almost always slows down and thus turned off
rational_interpolator = :VanDerHoevenLecerf,
priority_variables = [],
)
@info "Simplifying generating set. Simplification level: $simplify"
_runtime_logger[:id_groebner_time] = 0.0
Expand Down Expand Up @@ -620,8 +626,10 @@ Result is correct (in the Monte-Carlo sense) with probability at least `prob_thr
@debug """
Final cleaning and simplification of generators.
Out of $(length(new_fracs)) fractions $(length(new_fracs_unique)) are syntactically unique."""
runtime =
@elapsed new_fracs = beautiful_generators(RationalFunctionField(new_fracs_unique))
runtime = @elapsed new_fracs = beautiful_generators(
RationalFunctionField(new_fracs_unique),
priority_variables = priority_variables,
)
@debug "Checking inclusion with probability $prob_threshold"
runtime =
@elapsed result = issubfield(rff, RationalFunctionField(new_fracs), prob_threshold)
Expand Down
21 changes: 9 additions & 12 deletions src/RationalFunctionFields/rankings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function generating_set_rank(funcs)
end

"""
rational_function_cmp(f, g; by=:naive)
rational_function_cmp(f, g; by=:naive, priority_variables=[])
Returns `true` if `f < g`.
Expand All @@ -47,27 +47,24 @@ permutation.*
Provides keyword argument `by`, a sorting order. Possible options are:
- `:rank`: Compare by rank.
- `:naive`: Compare features one by one. Features in the order of importance:
- Constant fractions are smaller.
- Fractions with constant denominators are smaller.
- Fractions with less terms and total degree are smaller.
- Fractions with smaller leading monomial in numerator / denominator are
smaller.
- Sum of total degrees
- Total number of terms
- Total degree of the denominator
- Leading terms of denominator and numerator
"""
function rational_function_cmp(f, g; by = :naive)
if by === :naive
flag = compare_rational_func_by(f, g, !is_constant)
flag == 1 && return false
flag == -1 && return true
flag = compare_rational_func_by(f, g, is_constant, :denominator)
flag = compare_rational_func_by(f, g, total_degree, :additive)
flag == 1 && return false
flag == -1 && return true
flag = compare_rational_func_by(f, g, length, :additive)
flag == 1 && return false
flag == -1 && return true
flag = compare_rational_func_by(f, g, total_degree)
flag = compare_rational_func_by(f, g, total_degree, :denominator)
flag == 1 && return false
flag == -1 && return true
flag = compare_rational_func_by(f, g, leading_monomial)
# promotes constants in denominators
flag = compare_rational_func_by(f, g, leading_monomial, :denominator)
flag == 1 && return false
flag == -1 && return true
flag = compare_rational_func_by(f, g, collect monomials)
Expand Down
18 changes: 2 additions & 16 deletions src/global_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ are identifiable functions containing or not the state variables
Dict(:with_states => Array{Array{P, 1}, 1}(), :no_states => Array{Array{P, 1}, 1}())
varnames = [var_to_str(p) for p in ode.parameters]
if with_states
append!(varnames, map(var_to_str, ode.x_vars))
varnames = vcat(map(var_to_str, ode.x_vars), varnames)
end
bring, _ = Nemo.polynomial_ring(base_ring(ode.poly_ring), varnames)

Expand Down Expand Up @@ -60,21 +60,7 @@ are identifiable functions containing or not the state variables
end
end

# Returned entities live in a new ring, different from the one
# attached to the input ODE.
# The new ring includes only parameter variables and, occasionally, states.
new_vars = ode.parameters
if with_states
new_vars = vcat(new_vars, ode.x_vars)
end
new_ring, _ = polynomial_ring(Nemo.QQ, map(Symbol, new_vars))
new_coeff_lists = empty(coeff_lists)
for (key, coeff_list) in coeff_lists
new_coeff_lists[key] =
map(coeffs -> map(c -> parent_ring_change(c, new_ring), coeffs), coeff_list)
end

return new_coeff_lists, new_ring
return coeff_lists, bring
end

# ------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions src/identifiable_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ function _find_identifiable_functions(
seed = seed,
simplify = simplify,
rational_interpolator = rational_interpolator,
priority_variables = [parent_ring_change(p, bring) for p in ode.parameters],
)
else
id_funcs_fracs = dennums_to_fractions(id_funcs)
Expand Down
11 changes: 11 additions & 0 deletions test/RationalFunctionFields/rational_function_cmp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@testset "Rational function comparison" begin
R, (a, b, c) = QQ["a", "b", "c"]

@test rational_function_cmp(a, b * c)
@test rational_function_cmp(a, b + c)
@test rational_function_cmp(a * b, a // b)
@test rational_function_cmp(a // b, b // a)
@test rational_function_cmp(a // b, a * b + b * c + c^2)
@test rational_function_cmp(a^2, a^3)
@test rational_function_cmp((a^2 + a * b + c) // (a - b), (a - b) // (a^2 + a * b + c))
end
85 changes: 52 additions & 33 deletions test/identifiable_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ ode = StructuralIdentifiability.@ODEmodel(
x'(t) = (-V_m * x(t)) / (k_m + x(t)) + k01 * x(t),
y(t) = c * x(t)
)
ident_funcs = [k01, k_m // V_m, V_m * c]
ident_funcs = [k01, c * k_m, V_m * c]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

# Parameters b and c enter the io equations only as the product b * c.
Expand Down Expand Up @@ -72,7 +72,7 @@ ode = StructuralIdentifiability.@ODEmodel(
x2'(t) = k2 * x1(t) - (k3 + k4) * x2(t),
y(t) = x1(t)
)
ident_funcs = [(k1 + k2), (k1 + k2 + k3 + k4), ((k1 + k2) * (k3 + k4) - k2 * k3)]
ident_funcs = [k1 + k2, k3 + k4, k2 * k3]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

# Diagonal with simple spectrum and observable states
Expand Down Expand Up @@ -123,7 +123,7 @@ ode = StructuralIdentifiability.@ODEmodel(
x2'(t) = p3 * x1^2 + p4 * x1 * x2,
y(t) = x1
)
ident_funcs = [p1 + p4, p2 * p3 - p4 * p1]
ident_funcs = [p1 + p4, -p2 * p3 + p4 * p1]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

# Goowdin oscillator
Expand Down Expand Up @@ -159,7 +159,7 @@ ode = StructuralIdentifiability.@ODEmodel(
Q'(t) = -gamma * Q(t) + psi * I(t),
y1(t) = Q(t)
)
ident_funcs = [gamma, beta // psi, gamma * psi - v - psi, gamma * psi - v * psi]
ident_funcs = [gamma * psi - psi * v, beta // psi, gamma, psi * v - psi - v]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

# Bilirubin2_io.
Expand Down Expand Up @@ -211,7 +211,7 @@ ode = StructuralIdentifiability.@ODEmodel(
y1(t) = x4(t),
y2(t) = x5(t)
)
ident_funcs = [k7, k6, k5, k10^2, k9 * k10, k8 + 1 // 2 * k10]
ident_funcs = [k7, k5, k6, k10 * k9, k9^2, k10 + 2 * k8]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

# SLIQR
Expand All @@ -223,17 +223,15 @@ ode = StructuralIdentifiability.@ODEmodel(
y(t) = In(t) * Ninv
)
ident_funcs = [
g + a,
s + g + a,
(a * e) // (a + e * s - s),
b,
a + g,
(
a^2 * e * s + a^2 * g + 3 * a * e * g * s - a * e * s^2 - 2 * a * g * s +
e^2 * g * s^2 - 2 * e * g * s^2 + g * s^2
) // (a + e * s - s),
s,
Ninv,
b,
(e * s - s + a) // (e * s^2 * g - s^2 * g - s^2 * a + s * g * a + s * a^2),
e * s * g + s * a + g * a,
(e * s^2 + e * s * g - s^2 - s * g + g * a + a^2) //
(e * s^2 * g - s^2 * g - s^2 * a + s * g * a + s * a^2),
e * s * g * a,
2 * e * Ninv * s * g + 2 * Ninv * s * a + 2 * Ninv * g * a,
]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

Expand All @@ -251,7 +249,7 @@ ident_funcs = [
T,
Dd,
e - rR + dr * T + d * T + g - r + a * T,
(2 * rR * d - 2 * dr * r) // (dr - d),
(d * rR - dr * r) // (d - dr),
(dr^2 + d^2 + 2 * d * a + a^2) // (dr * d + dr * a),
(e * dr - e * d + rR * a + dr * g - d * g - r * a) // (dr - d),
(e * dr^2 - e * dr * d + rR * dr * a + dr * d * g - dr * r * a - d^2 * g) //
Expand Down Expand Up @@ -826,7 +824,7 @@ ident_funcs = [α, x1^2 + x2^2]
push!(test_cases, (ode = ode, with_states = true, ident_funcs = ident_funcs))

ode = StructuralIdentifiability.@ODEmodel(x'(t) = a * x(t) + b * u(t), y(t) = c * x(t))
ident_funcs = [b * c, a, x // b]
ident_funcs = [a, b * c, x * c]
push!(test_cases, (ode = ode, with_states = true, ident_funcs = ident_funcs))

# llw1987 model
Expand All @@ -843,7 +841,7 @@ ident_funcs = [
p2 * p4 // one(x1),
(p3 + p1) // one(x1),
(p2 * x2 + p4 * x1) // one(x1),
(p3 - p1) // (p2 * x2 - p4 * x1),
(p2 * x2 - p4 * x1) // (p3 - p1),
]
push!(test_cases, (ode = ode, with_states = true, ident_funcs = ident_funcs))

Expand Down Expand Up @@ -883,26 +881,26 @@ ode = StructuralIdentifiability.@ODEmodel(
y3(t) = a3 * pS6(t)
)
ident_funcs = [
(EGF_EGFR * reaction_9_k1) // pS6,
reaction_8_k1,
a3 // reaction_5_k1,
reaction_3_k1,
reaction_2_k2,
reaction_2_k1 // reaction_5_k1,
a1 // reaction_5_k1,
pAkt * reaction_5_k1,
pEGFR * reaction_5_k1,
pS6 * reaction_5_k1,
reaction_5_k2,
reaction_3_k1,
reaction_6_k1,
a2 // reaction_5_k1,
reaction_7_k1,
reaction_4_k1,
reaction_2_k1 * pAkt_S6,
reaction_2_k1 * S6,
reaction_5_k1 * pAkt_S6,
pEGFR_Akt * reaction_2_k1,
Akt * reaction_2_k1,
a1 * pAkt_S6,
pEGFR * reaction_2_k1,
pAkt * reaction_2_k1,
a3 * pAkt_S6,
pS6 * reaction_2_k1,
a2 * pAkt_S6,
reaction_9_k1 * reaction_2_k1 * EGF_EGFR,
reaction_1_k1 - reaction_9_k1 - reaction_1_k2,
pAkt_S6 * reaction_5_k1,
reaction_1_k1 - reaction_1_k2 - reaction_9_k1,
pEGFR_Akt * reaction_5_k1,
S6 * reaction_5_k1,
Akt * reaction_5_k1,
]
push!(test_cases, (ode = ode, with_states = true, ident_funcs = ident_funcs))

Expand Down Expand Up @@ -933,18 +931,34 @@ ode = StructuralIdentifiability.@ODEmodel(
y2(t) = k6 * (x1 + x2 + 2x3),
y3(t) = k7 * EpoR_A
)
ident_funcs = [k3, k1 // k7, k5 // k2, k6 // k2, k7 * EpoR_A]
ident_funcs = [k2 // k6, k3, EpoR_A * k7, EpoR_A * k1, k5 // k6]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs))

ode = @ODEmodel(x1'(t) = x1, x2'(t) = x2, y(t) = x1 + x2(t))
ident_funcs = [x1 + x2]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs, with_states = true))

# SEUIR model from the reparametrization docs
# https://docs.sciml.ai/StructuralIdentifiability/stable/tutorials/reparametrization/
ode = @ODEmodel(
S'(t) = -b * (U(t) + I(t)) * S(t) / N,
E'(t) = b * (U(t) + I(t)) * S(t) / N - g * E(t),
U'(t) = (1 - a) * g * E(t) - d * U(t),
I'(t) = a * g * E(t) - d * I(t),
y(t) = I(t)
)
ident_funcs = [I, d, g, S * a, E * a, U * a + I * a, (N * a) // b]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs, with_states = true))

ode = @ODEmodel(x'(t) = alpha * x(t)^2, y(t) = x(t)^2)
ident_funcs = [alpha^2, x * alpha]
push!(test_cases, (ode = ode, ident_funcs = ident_funcs, with_states = true))

# TODO: verify that Maple returns the same
@testset "Identifiable functions of parameters" begin
p = 0.99
for case in test_cases
for simplify in [:weak, :standard] # :strong]
for simplify in [:weak, :standard] #:strong?
ode = case.ode
true_ident_funcs = case.ident_funcs
with_states = false
Expand Down Expand Up @@ -980,6 +994,11 @@ push!(test_cases, (ode = ode, ident_funcs = ident_funcs, with_states = true))
StructuralIdentifiability.RationalFunctionField(true_ident_funcs),
p,
)
if simplify != :weak
@info gens(parent(numerator(first(result_funcs))))
# To keep track of changes in the simplification
@test Set(result_funcs) == Set(true_ident_funcs)
end
end
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ using StructuralIdentifiability:
x_equations,
y_equations,
inputs,
quotient_basis
quotient_basis,
rational_function_cmp

const GROUP = get(ENV, "GROUP", "All")

Expand Down

0 comments on commit 2eeff8f

Please sign in to comment.