ForwardDiff Overload Fixes #629
Conversation
The overloads won't work with higher order derivatives or nested Duals. We would need a special way to handle higher order derivatives / maybe a general expression for
I think making the higher order/nested cases work out shouldn't be too hard. Derivatives are linear operators, so I think everything should compose nicely here.
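For illustration, here is a minimal sketch of why the nested/higher-order case should compose (this is not the PR's implementation, just the underlying property): the forward rule for x = A \ b is ∂x = A \ (∂b - ∂A * x), which is itself built out of \ and *, so feeding it Dual-of-Dual inputs is just the same rule applied twice. All names below are illustrative.

using ForwardDiff, LinearAlgebra

# x(θ) solves A(θ) x = b(θ) with A(θ) = [θ 1; 0 2] and b(θ) = [θ^2, 1]
solvefun(θ) = [θ 1.0; 0.0 2.0] \ [θ^2, 1.0]

# Nesting ForwardDiff.derivative pushes Dual-of-Dual numbers through the solve,
# i.e. exactly the nested-Dual case discussed above.
d2 = ForwardDiff.derivative(θ -> ForwardDiff.derivative(t -> solvefun(t)[1], θ), 1.0)

# Analytic check: x1(θ) = θ - 1/(2θ), so x1''(θ) = -1/θ^3, which is -1 at θ = 1
@show d2 ≈ -1.0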
@oscardssmith this actually does work recursively, I just had to make a few adjustments first. I'll test it with more nested Dual stuff, but this should be good pretty soon actually.
@oscardssmith this is mostly ready to go, but there is one problem with one of the SciMLSensitivity tests that doesn't run in the integration tests: https://github.com/SciML/SciMLSensitivity.jl/blob/e2916ef299f5e39998e2b758fd43e29e993a18c7/test/stiff_adjoints.jl#L24 This only happens on this branch of LinearSolve, and so far it only happens with:
using SciMLSensitivity
using OrdinaryDiffEq, ForwardDiff, Test
using ForwardDiff: Dual
using Zygote
function lotka_volterra(u, p, t)
x, y = u
α, β, δ, γ = p
[α * x - β * x * y, -δ * y + γ * x * y]
end
function lotka_volterra(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = α * x - β * x * y
du[2] = -δ * y + γ * x * y
end
u0 = [1.0, 1.0];
tspan = (0.0, 10.0);
p0 = [1.5, 1.0, 3.0, 1.0];
prob0 = ODEProblem{true, SciMLBase.FullSpecialize}(lotka_volterra, u0, tspan, p0);
# Solve the ODE and collect solutions at fixed intervals
target_data = solve(prob0, FBDF(), saveat=0:0.5:10.0, abstol=1e-10,
reltol=1e-12)
loss_function = function (p)
prob = remake(prob0; p = p)
prediction = solve(prob, RadauIIA3(); u0=convert.(eltype(p), prob0.u0), saveat = 0.0:0.5:10.0, abstol=1e-10,
reltol = 1e-12)
tmpdata = prediction[[1, 2], :]
tdata = target_data[[1, 2], :]
#Main.@infiltrate
# Calculate squared error
return sum(abs2, tmpdata - tdata)
end
p = [2.0, 1.2, 3.3, 1.5];
fdgrad = ForwardDiff.gradient(loss_function, p)
Zygote.gradient(loss_function,p)
That's very odd. It's unclear to me why Radau would be any different here.
@ChrisRackauckas or @Shreyas-Ekanathan, any idea why RadauIIA5 would be different here?
Complex numbers?
Seems very plausible. @jClugstor you might need a rule for complex matrices.
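If a rule for complex matrices did turn out to be needed, it would presumably just be the usual forward rule ∂x = A \ (∂b - ∂A * x) evaluated in complex arithmetic. A minimal sketch with toy data (nothing below comes from OrdinaryDiffEq or LinearSolve), checked against a central finite difference:

using LinearAlgebra

θ = 1.3
A(θ) = [θ+2im  1.0+0im; 0.0+0im  3.0-θ*im]
b(θ) = [θ^2 + 0im, 1.0 + θ*im]

x  = A(θ) \ b(θ)
∂A = [1.0+0im  0.0+0im; 0.0+0im  0.0-1.0im]   # dA/dθ
∂b = [2θ + 0im, 1.0im]                        # db/dθ

# Same forward rule as the real case, just with complex entries
∂x_rule = A(θ) \ (∂b - ∂A * x)

# Central finite-difference check
ε = 1e-6
∂x_fd = (A(θ + ε) \ b(θ + ε) - A(θ - ε) \ b(θ - ε)) / (2ε)
@show ∂x_rule ≈ ∂x_fd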
Yes, seems likely there's an issue there. Although it's strange because in the case of
Maybe a silly question, but when we call linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)) we want
Yeah, this relies on some weird aliasing.
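To spell out the aliasing contract being relied on (hypothetical names below, not the actual dolinsolve internals): the integrator keeps a reference to the buffer it passed as linu and reads it afterwards, so an overload that solves into a freshly allocated array and only returns it, without copying back, silently breaks the caller. A minimal sketch:

using LinearAlgebra

function solve_into!(linu, A, b)
    linu .= A \ b       # writes through the caller's buffer: aliasing preserved
    return linu
end

function solve_fresh(linu, A, b)
    return A \ b        # returns a new array: the caller's buffer is left stale
end

A  = [2.0 0.0; 0.0 4.0]
b  = [1.0, 1.0]
dw = zeros(2)           # stands in for a buffer like dw1 that the integrator keeps reusing

solve_into!(dw, A, b); @show dw      # [0.5, 0.25] -- the integrator sees the solution
dw .= 0
x = solve_fresh(dw, A, b); @show dw  # still zeros -- the aliasing contract is broken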
Yeah, that's my issue I believe. The complex stuff was a red herring; those linear solves don't even go through the overloads since they're
Ok, assuming that the MTK integration test is green, this all works with support for
But