ForwardDiff Overload Fixes #629


Merged · 32 commits · Jul 17, 2025

Conversation

jClugstor
Member

@jClugstor jClugstor commented Jul 2, 2025

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API


@jClugstor
Member Author

The overloads won't work with higher-order derivatives or nested Duals. We would need a special way to handle higher-order derivatives, or maybe a general expression for x_p at higher orders, so for now I just want to exclude nested Duals from going through this.
I did the following to try to exclude NestedDualLinearProblem from going through the overloads, but it's pretty restrictive. If I instead define SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Number, P}, then a NestedDualLinearProblem is also a SingleDualLinearProblem, which defeats the purpose. Essentially, I want to make sure that if the value of a Dual is another Dual, it goes through the normal path and not the ForwardDiffExt path. @oscardssmith I wonder if you have any thoughts on how I could do that?

# Define type for non-nested dual numbers
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Float64, P}

# Define type for nested dual numbers
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P}

const SingleDualLinearProblem = LinearProblem{
    <:Union{Number, <:AbstractArray, Nothing}, iip,
    <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
    <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
    <:Any
} where {iip}

const NestedDualLinearProblem = LinearProblem{
    <:Union{Number, <:AbstractArray, Nothing}, iip,
    <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
    <:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
    <:Any
} where {iip}

const DualALinearProblem = LinearProblem{
    <:Union{Number, <:AbstractArray, Nothing},
    iip,
    <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
    <:Union{Number, <:AbstractArray},
    <:Any
} where {iip}

const DualBLinearProblem = LinearProblem{
    <:Union{Number, <:AbstractArray, Nothing},
    iip,
    <:Union{Number, <:AbstractArray},
    <:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
    <:Any
} where {iip}

const DualAbstractLinearProblem = Union{
    SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}
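
The subtyping problem above can be sketched without ForwardDiff. `MiniDual` below is a hypothetical stand-in for `ForwardDiff.Dual` with only a value-type parameter (the real `Dual` also carries a tag and partials); it shows why a `V <: Number`-style bound would match nested duals too, while `V <: Float64` excludes them:

```julia
# Hypothetical stand-in for ForwardDiff.Dual, stripped down to one value-type
# parameter, just to illustrate the subtyping behavior discussed above.
struct MiniDual{V<:Real} <: Real
    value::V
    partial::V
end

const SingleMiniDual = MiniDual{V} where {V<:Float64}   # restrictive: excludes nesting
const AnyMiniDual    = MiniDual{V} where {V<:Real}      # permissive: MiniDual <: Real, so it leaks
const NestedMiniDual = MiniDual{V} where {V<:MiniDual}  # the value is itself a dual

@assert MiniDual{Float64} <: SingleMiniDual
@assert !(MiniDual{MiniDual{Float64}} <: SingleMiniDual)  # nested duals excluded
@assert MiniDual{MiniDual{Float64}} <: AnyMiniDual        # a Number-style bound matches nesting too
@assert MiniDual{MiniDual{Float64}} <: NestedMiniDual
```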

@oscardssmith
Member

I think making the higher order/nested cases work out shouldn't be too hard. Derivatives are a linear operator, so I think everything should compose nicely here.
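
A minimal sketch of why the composition works, using only LinearAlgebra (the matrices and parameterization here are made up for illustration): differentiating A(p) x(p) = b(p) gives A ẋ = ḃ − Ȧ x, so the tangent is itself the solution of a linear system with the same matrix A, and the rule can be applied again for nesting. A finite-difference check confirms the identity:

```julia
using LinearAlgebra

# For A(p) x(p) = b(p), differentiating in p gives  A ẋ = ḃ - Ȧ x,
# so the dual (tangent) part is another solve with the same matrix A.
A(p) = [2.0 + p  1.0; 0.5  3.0 - p]
b(p) = [1.0 + 2p, p^2]
x(p) = A(p) \ b(p)

p = 0.3
Ȧ = [1.0 0.0; 0.0 -1.0]        # dA/dp
ḃ = [2.0, 2p]                  # db/dp
ẋ = A(p) \ (ḃ - Ȧ * x(p))      # forward-mode tangent of the solve

# Central finite difference as an independent check
h = 1e-6
fd = (x(p + h) - x(p - h)) / (2h)
@assert isapprox(ẋ, fd; atol = 1e-6)
```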

@jClugstor
Member Author

Using the general Leibniz rule for products I can get this:
[image: expression obtained from the general Leibniz rule]
For the partials we might need to use this:
[image: expression for the partials]
I think handling the partials will be the trickiest part of this.
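
For reference, the general Leibniz rule for the n-th derivative of a product, applied to A(p) x(p) = b(p), lets you solve for the highest derivative of x in terms of the lower ones (this is the standard identity; the exact form in the images above is not reproduced here):

```latex
\frac{d^n}{dp^n}\bigl(A(p)\,x(p)\bigr)
  = \sum_{k=0}^{n} \binom{n}{k}\, A^{(k)}(p)\, x^{(n-k)}(p) = b^{(n)}(p)
\quad\Longrightarrow\quad
x^{(n)} = A^{-1}\Bigl(b^{(n)} - \sum_{k=1}^{n} \binom{n}{k}\, A^{(k)}\, x^{(n-k)}\Bigr)
```

Each x^{(n)} is again the solution of a linear system with the same matrix A, which is what makes the recursion over nested Duals work.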

@jClugstor
Member Author

@oscardssmith this actually does work recursively, I just had to make a few adjustments first. I'll test it with more nested Dual stuff but this should be good pretty soon actually.

@jClugstor jClugstor force-pushed the forwarddiff_overloads branch from dc898eb to 7abb243 Compare July 14, 2025 20:16
@jClugstor jClugstor marked this pull request as ready for review July 15, 2025 14:40
@jClugstor
Member Author

@oscardssmith this is mostly ready to go, but there is one problem with one of the SciMLSensitivity tests that doesn't run in the integration tests. https://github.com/SciML/SciMLSensitivity.jl/blob/e2916ef299f5e39998e2b758fd43e29e993a18c7/test/stiff_adjoints.jl#L24
In the stiff adjoint tests when I try to differentiate loss_function the solve becomes unstable: Warning: At t=1.5507202241421e-7, dt was forced below floating point epsilon 2.6469779601696886e-23 ... .

This only happens on this branch of LinearSolve, and so far it only happens with RadauIIA5 and RadauIIA9. RadauIIA3 works, and so do all of the other implicit solvers I've tested.

using SciMLSensitivity
using OrdinaryDiffEq, ForwardDiff, Test
using ForwardDiff: Dual
using Zygote


function lotka_volterra(u, p, t)
    x, y = u
    α, β, δ, γ = p
    [α * x - β * x * y, -δ * y + γ * x * y]
end

function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
end

u0 = [1.0, 1.0];
tspan = (0.0, 10.0);
p0 = [1.5, 1.0, 3.0, 1.0];
prob0 = ODEProblem{true, SciMLBase.FullSpecialize}(lotka_volterra, u0, tspan, p0);
# Solve the ODE and collect solutions at fixed intervals
target_data = solve(prob0, FBDF(), saveat=0:0.5:10.0, abstol=1e-10,
    reltol=1e-12)


loss_function = function (p)
    prob = remake(prob0; p = p)
    prediction = solve(prob, RadauIIA3(); u0 = convert.(eltype(p), prob0.u0),
        saveat = 0.0:0.5:10.0, abstol = 1e-10, reltol = 1e-12)
    tmpdata = prediction[[1, 2], :]
    tdata = target_data[[1, 2], :]
    # Calculate squared error
    return sum(abs2, tmpdata - tdata)
end
p = [2.0, 1.2, 3.3, 1.5];

fdgrad = ForwardDiff.gradient(loss_function, p)

Zygote.gradient(loss_function,p)

@oscardssmith
Member

That's very odd. It's unclear to me why Radau would be any different here.

@oscardssmith
Member

@ChrisRackauckas or @Shreyas-Ekanathan any idea why RadauIIA5 would be different here?

@ChrisRackauckas
Member

Complex numbers?

@oscardssmith
Member

Seems very plausible. @jClugstor you might need a rule for complex matrices.

@jClugstor
Member Author

Yes, it seems likely there's an issue there. Although it's strange, because in the case of Complex{Dual{ ... values in A or b, the overloads aren't even hit, since then it isn't a DualLinearProblem.

@jClugstor
Member Author

Maybe a silly question, but when we call dolinsolve (from OrdinaryDiffEqDifferentiation) like this

  linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff),
                linu = _vec(dw1))

we want dw1 to end up with the same value as linres1.u, right?

@oscardssmith
Member

Yeah, this relies on some weird aliasing.
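
A minimal sketch of that aliasing contract, using only LinearAlgebra (the buffer names echo the dolinsolve call above but are otherwise made up): the caller hands in a buffer and expects the solve to fill it in place, so any reference to that buffer sees the solution afterwards.

```julia
using LinearAlgebra

# The caller owns dw1 and passes it as the solution buffer (like `linu = _vec(dw1)`
# in dolinsolve); the in-place solve must write the result through that alias.
A = [4.0 1.0; 1.0 3.0]
b = [1.0, 2.0]
dw1 = zeros(2)          # caller-owned buffer
linu = dw1              # same array, no copy

ldiv!(linu, lu(A), b)   # in-place solve writes through the alias

@assert dw1 ≈ A \ b     # dw1 was updated via the alias, as the caller expects
```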

@jClugstor
Member Author

Yeah, that's my issue, I believe. The complex stuff was a red herring: those linear solves don't go through the overloads at all, since they're Complex{Dual{ ... and not Dual{Complex{ .... I'll just need to fix up the aliasing.

@jClugstor
Member Author

Ok, assuming that the MTK integration test is green, this all works with support for

  • Nested Duals, so Hessian / Jacobian of gradient works
  • Sparse / Float only solvers work, including KLUFactorization and UMFPACKFactorization
  • anything aliased with cache.u should also be updated when the cache is solved, preserving previous behavior
  • Default solvers are chosen based on the type of the lowest level number i.e. the number type that actually goes through the linear solver

But

  • this does not work with problems where the values are Complex{Dual. I don't think it would be difficult to implement, but for now those problems keep the old behavior and send the Duals through the solvers unchanged
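
The "lowest level number" rule from the list above can be sketched as a recursive unwrap of the value type. `Wrap` and `basenumtype` here are hypothetical stand-ins (the real code would dispatch on ForwardDiff.Dual):

```julia
# Hypothetical sketch: recursively unwrap a dual-like wrapper's value type to
# find the number type that actually reaches the linear solver.
struct Wrap{V<:Real} <: Real
    value::V
end

basenumtype(::Type{T}) where {T<:Real} = T
basenumtype(::Type{Wrap{V}}) where {V} = basenumtype(V)

@assert basenumtype(Float64) == Float64
@assert basenumtype(Wrap{Float64}) == Float64
@assert basenumtype(Wrap{Wrap{Float32}}) == Float32  # nested duals unwrap all the way down
```

The default-solver choice can then be driven by `basenumtype(eltype(A))` rather than the outer Dual type.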

@ChrisRackauckas ChrisRackauckas merged commit 47eeb04 into SciML:main Jul 17, 2025
23 of 27 checks passed