Solution interpolation does not infer type #2610

Open
hersle opened this issue Feb 22, 2025 · 14 comments

hersle commented Feb 22, 2025

The simple example

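```julia
using OrdinaryDiffEq
using InteractiveUtils  # provides @code_warntype outside the REPL

# Trivial out-of-place ODE du/dt = 1 with u(0) = 0 on t ∈ [0, 1]
prob = ODEProblem((u, p, t) -> [1.0], [0.0], (0.0, 1.0))
sol = solve(prob, Tsit5())

# Interpolate the solution at many time points and inspect the inferred return type
ts = range(0.0, 1.0, length = 100000)
@code_warntype sol(ts)
```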

infers the return type only as `Any`:

```
MethodInstance for (::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, SciMLBase.AutoSpecialize, var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing})(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64})
  from (sol::SciMLBase.AbstractODESolution)(t; ...) @ SciMLBase C:\Users\herma\.julia\packages\SciMLBase\dJpGC\src\solutions\ode_solutions.jl:216
Arguments
  sol::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, ODEFunction{false, SciMLBase.AutoSpecialize, var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}
  t::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}
Body::Any
1 ─ %1 = Core.apply_type(SciMLBase.Val, 0)::Core.Const(Val{0})
│   %2 = (sol)(t, %1)::Any
└──      return %2
```

This is what a flamegraph looks like:

```julia
@profview sol(ts)
```

[Image: flamegraph of `@profview sol(ts)`]

Is this to be expected?

hersle added the bug label Feb 22, 2025
oscardssmith self-assigned this Feb 22, 2025

ChrisRackauckas commented

That's not intentional.

hersle commented Feb 22, 2025

With

```julia
using Cthulhu
@descend sol(ts)
```

I have dug down to this as a possible source (screenshot for the colors), where it doesn't seem to infer the vector element type?

[Image: screenshot of the `@descend` output]

This also looks suspiciously similar to #1270.

ChrisRackauckas commented

Yeah, isolate that `do` function: why is the element type not inferable there? It seems like it should be fine...

hersle commented Feb 22, 2025

I moved the body of the `do j ...` block into a separate function `func(j, all_needed_args)`. Calling, e.g., `func(1, ...)` by itself inside `ode_interpolation` is inferred, but it is not inferred if I do `[func(1, ...) for i in 1:10]` 😕
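The check described above can be automated with `Test.@inferred`. That macro throws an error when the inferred return type of a call differs from the type the call actually returns. The sketch below uses a hypothetical stand-in, `interp_point`, for the extracted `do`-block body, because the real helper is not shown in the thread. With this toy body both forms infer, so the example shows the test harness and does not reproduce the bug.

```julia
using Test

# Hypothetical stand-in for the extracted `do`-block body; not the SciMLBase code.
interp_point(j, ts, vals) = vals[j] + 0.5 * (ts[j] - ts[1])

# The same call made once on its own and once inside a comprehension.
single_call(ts, vals) = interp_point(1, ts, vals)
in_comprehension(ts, vals) = [interp_point(1, ts, vals) for i in 1:10]

ts = collect(range(0.0, 1.0, length = 11))
vals = rand(11)

# Each line throws an error if inference does not produce the actual return type.
@inferred single_call(ts, vals)
@inferred in_comprehension(ts, vals)
```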

ChrisRackauckas commented

Alright... I'll leave that to @oscardssmith to figure out then 😅

devmotion commented

> But it is not inferred if I do `[func(1, ...) for i in 1:10]`

Did you try the let-block trick/workaround?

hersle commented Feb 22, 2025

> Did you try the let-block trick/workaround?

What do you mean?

hersle commented Feb 22, 2025

I can at least confirm the `Any` issue goes away if I "enforce" the first element of `vals` like so (in `ode_interpolation`):

```julia
...
vals = [dofunc(tvals[1], ts, tdir, timeseries, cache, idxs, deriv, ks, id, p, differential_vars, i₋₊ref, f)]
for t in tvals[2:end]
    push!(vals, dofunc(t, ts, tdir, timeseries, cache, idxs, deriv, ks, id, p, differential_vars, i₋₊ref, f))
end
vals = vals[idx]
...
```

But I don't think this is an optimal solution.
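A hedged alternative to the `push!` loop keeps the same idea, taking the element type from the first result, but preallocates the output instead of growing it. This is a generic sketch and not a tested patch for `ode_interpolation`. `collect_seeded` and `double` are made-up names, and `f` stands in for the `dofunc(...)` call. The sketch assumes a non-empty, 1-based `xs` whose results all share the type of the first result.

```julia
# Collect f.(xs) into a vector whose element type is taken from the first result.
# Assumes xs is non-empty, uses 1-based indexing, and every f(x) has the type of f(xs[1]).
function collect_seeded(f, xs)
    first_val = f(xs[1])
    vals = Vector{typeof(first_val)}(undef, length(xs))
    vals[1] = first_val
    for i in 2:length(xs)
        vals[i] = f(xs[i])  # converts to eltype(vals), or errors if that is impossible
    end
    return vals
end

double(x) = 2x
collect_seeded(double, [1.0, 2.0, 3.0])
```

Compared with the `push!` version, this avoids repeatedly growing `vals` and avoids the copy made by `tvals[2:end]`.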

devmotion commented Feb 22, 2025

> What do you mean?

That it no longer infers in the `[... for ...]` comprehension suggests that this might be an instance of JuliaLang/julia#15276. Closures that capture variables are known to be problematic for type inference, and the common workaround is to put a `let ... end` block around the closure that rebinds the captured variables: JuliaLang/julia#15276 (comment), https://github.com/c42f/FastClosures.jl
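The failure mode behind JuliaLang/julia#15276 can be reproduced without any SciML code. This sketch is adapted from the "Performance of captured variable" example in the Julia manual's performance tips, with the function names changed. Because `r` is reassigned before the closure captures it, `r` is stored in a `Core.Box`, and calls to the closure lose their concrete return type. Rebinding `r` with `let` gives the closure a fresh variable that is never reassigned. The example illustrates the general issue only; the thread does not confirm that this is what happens in `ode_interpolation`.

```julia
# Per the Julia performance tips: `r` is reassigned before the closure captures it,
# so it is stored in a `Core.Box` and calls to the closure are not concretely inferred.
function abmult_boxed(r::Int)
    if r < 0
        r = -r
    end
    return x -> x * r
end

# Rebinding `r` in a `let` block gives the closure a fresh variable that is never
# reassigned, so it is captured with its concrete type (`Int`).
function abmult_let(r::Int)
    if r < 0
        r = -r
    end
    return let r = r
        x -> x * r
    end
end

f_boxed = abmult_boxed(-3)
f_let = abmult_let(-3)

# Compare the return types that inference reports for a call with an `Int` argument.
Base.return_types(f_boxed, (Int,))
Base.return_types(f_let, (Int,))
```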

hersle commented Feb 22, 2025

Hmm ok, thanks. That's a rabbit hole I didn't know about 😅 I'm not sure I got it right, but I tried this ugly thing, which does not help:

```julia
function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
        continuity::Symbol = :left) where {I, D}
    @unpack ts, timeseries, ks, f, cache, differential_vars = id
    @inbounds tdir = sign(ts[end] - ts[1])
    idx = sortperm(tvals, rev = tdir < 0)
    # start the search thinking it's ts[1]-ts[2]
    i₋₊ref = Ref((1, 2))
    vals = map(idx) do j
        let j=j, tvals=tvals, tdir=tdir, i₋₊ref=i₋₊ref, ts=ts, timeseries=timeseries, ks=ks, f=f, cache=cache, differential_vars=differential_vars, continuity=continuity, id=id, idxs=idxs, deriv=deriv, p=p
            dofunc(j, tvals, tdir, i₋₊ref, ts, timeseries, ks, f, cache, differential_vars, continuity, id, idxs, deriv, p) # moved everything in the "do" into this function
        end
    end
    invpermute!(vals, idx)
    DiffEqArray(vals, tvals)
end
```

devmotion commented

You'd have to swap the `map(idx) do j` and the `let tvals = tvals, tdir = tdir, ...` (without the `j=j`): the closure, i.e., the whole `do j ... end`, has to be inside a `let` block.

hersle commented Feb 22, 2025

Oh, of course, now I get it:

```julia
function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
        continuity::Symbol = :left) where {I, D}
    @unpack ts, timeseries, ks, f, cache, differential_vars = id
    @inbounds tdir = sign(ts[end] - ts[1])
    idx = sortperm(tvals, rev = tdir < 0)
    # start the search thinking it's ts[1]-ts[2]
    i₋₊ref = Ref((1, 2))
    vals = nothing # must be declared outside "let"; should at least make types infer to Union{Nothing, ...}?
    let tvals=tvals, tdir=tdir, i₋₊ref=i₋₊ref, ts=ts, timeseries=timeseries, ks=ks, f=f, cache=cache, differential_vars=differential_vars, continuity=continuity, id=id, idxs=idxs, deriv=deriv, p=p
        vals = map(idx) do j
            dofunc(j, tvals, tdir, i₋₊ref, ts, timeseries, ks, f, cache, differential_vars, continuity, id, idxs, deriv, p) # moved everything in the "do" into this function
        end
    end
    invpermute!(vals, idx)
    DiffEqArray(vals, tvals)
end
```

Still doesn't help, unfortunately :( But thank you for the suggestion!

devmotion commented

One additional comment: you don't need the `vals = nothing`; you could instead write

```julia
vals = let ...
    map(..) do ..
        ...
    end
end
```
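Applied to a `map ... do` block, the `vals = let ... end` form might look like the toy sketch below. `scale_all` and `s` are made-up stand-ins. `s` plays the role of a captured variable that is reassigned before the closure is built, and the entire `do` block sits inside the `let`.

```julia
function scale_all(xs::Vector{Float64}, s::Float64)
    if s < 0
        s = -s  # reassignment: capturing `s` directly in the closure would box it
    end
    # The whole `do` block (the closure) sits inside the `let`, which rebinds `s`,
    # and the value of the `let` block is assigned to `vals` directly.
    vals = let s = s
        map(xs) do x
            x * s
        end
    end
    return vals
end

scale_all([1.0, 2.0, 3.0], -2.0)
```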

Another comment: in my experience, tuples sometimes lead to better type inference in broadcasting etc. than `Ref`s. You could check whether replacing the `Ref` with a tuple helps (if that doesn't break the function call).
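For reference, here are the two common spellings for shielding an argument from broadcasting. Both `Ref(y)` and the one-element tuple `(y,)` make `y` behave as a scalar. This thread does not establish whether the tuple form improves inference in `ode_interpolation`. Also, if `i₋₊ref` is updated inside the loop, an immutable tuple would not be a drop-in replacement. The sketch therefore only illustrates the syntax, and `add_total` is a made-up name.

```julia
# Add the sum of `offsets` to a single number `x`.
add_total(x, offsets) = x + sum(offsets)

xs = [1.0, 2.0, 3.0]
offsets = [10.0, 20.0]

# Both wrappers stop broadcasting from iterating over `offsets`,
# so the whole vector is passed to every call.
with_ref = add_total.(xs, Ref(offsets))
with_tuple = add_total.(xs, (offsets,))
```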

hersle commented Feb 22, 2025

Thanks, I appreciate the tips. That version also runs, but it doesn't fix the issue.
