Skip to content

Commit 6b22f32

Browse files
fix: make some SII impls of IndexCache more type-stable
1 parent a2d9efa commit 6b22f32

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

src/systems/index_cache.jl

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -410,14 +410,16 @@ function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
410410
parameter_index(ic, sym) !== nothing
411411
end
412412

413-
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
414-
if sym isa Symbol
415-
sym = get(ic.symbol_to_variable, sym, nothing)
416-
sym === nothing && return nothing
417-
end
418-
sym = unwrap(sym)
419-
validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray &&
420-
symbolic_has_known_size(sym)
413+
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap})
414+
parameter_index(ic, unwrap(sym))
415+
end
416+
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Symbol)
417+
sym = get(ic.symbol_to_variable, sym, nothing)
418+
sym === nothing && return nothing
419+
parameter_index(ic, sym)
420+
end
421+
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::SymbolicT)
422+
validate_size = Symbolics.isarraysymbolic(sym) && symbolic_has_known_size(sym)
421423
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
422424
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
423425
elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing
@@ -464,23 +466,14 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
464466
idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...))
465467
end
466468

467-
function check_index_map(idxmap, sym)
468-
if (idx = get(idxmap, sym, nothing)) !== nothing
469-
return idx
470-
elseif !isa(sym, Symbol) && (!iscall(sym) || operation(sym) !== getindex) &&
471-
hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
472-
return idx
473-
end
469+
function check_index_map(idxmap::Dict{SymbolicT, V}, sym::SymbolicT)::Union{V, Nothing} where {V}
470+
idx = get(idxmap, sym, nothing)
471+
idx === nothing || return idx
474472
dsym = default_toterm(sym)
475473
isequal(sym, dsym) && return nothing
476-
if (idx = get(idxmap, dsym, nothing)) !== nothing
477-
idx
478-
elseif !isa(dsym, Symbol) && (!iscall(dsym) || operation(dsym) !== getindex) &&
479-
hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
480-
idx
481-
else
482-
nothing
483-
end
474+
idx = get(idxmap, dsym, nothing)
475+
idx === nothing || return idx
476+
return nothing
484477
end
485478

486479
function reorder_parameters(

0 commit comments

Comments
 (0)