From 8d74f76af05af61a5cdbe50d49d6339d579c10db Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 1 Dec 2021 21:16:53 -0500 Subject: [PATCH 1/3] Fix the enzyme default check --- src/concrete_solve.jl | 2 +- test/complex_adjoints.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 9d35a7a7f..3612407fc 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -29,7 +29,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem,SDEProblem}, try Enzyme.autodiff(Enzyme.Duplicated(du, du), u0,p,prob.tspan[1]) do out,u,_p,t - f(out, u, _p, t) + prob.f(out, u, _p, t) nothing end true diff --git a/test/complex_adjoints.jl b/test/complex_adjoints.jl index 824a6bf1a..500162b65 100644 --- a/test/complex_adjoints.jl +++ b/test/complex_adjoints.jl @@ -51,3 +51,17 @@ dp2 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = QuadratureAd dp3 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = BacksolveAdjoint()), p)[1] @test dp1 ≈ dp2 ≈ dp3 @test eltype(dp1) <: Float64 + +function fiip(du,u,p,t) + du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] + du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] +end +p = [1.5,1.0,3.0,1.0]; u0 = [1.0; 1.0] +prob = ODEProblem(fiip,complex(u0),(0.0,10.0),complex(p)) + +function sum_of_solution(u0, p) + _prob = remake(prob,u0=u0,p=p) + real(sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1))) +end + +dx = Zygote.gradient(sum_of_solution, complex(u0), complex(p)) From 4cffcc8d4f75172a536e529ac89e1edadcbaf455 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 1 Dec 2021 21:49:25 -0500 Subject: [PATCH 2/3] run on v1.6 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9d94ef58b..17870ffe3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,7 +23,7 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: 1 + version: 1.6 - uses: actions/cache@v1 env: cache-name: cache-artifacts From abc75c33cc6e3188777912b5a0b1bc6450a325da Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 1 Dec 2021 22:30:59 -0500 Subject: [PATCH 3/3] fix downstream --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 3612407fc..d8f8be15a 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -611,7 +611,7 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::TrackerAdjoint, out,pullback = Tracker.forward(tracker_adjoint_forwardpass,u0,p) function tracker_adjoint_backpass(ybar) tmp = if eltype(ybar) <: Number - ybar + Array(ybar) elseif typeof(ybar[1]) <: Array return Array(ybar) else