diff --git a/.gitignore b/.gitignore index f3fc3a6b4..80be6a962 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ *.jl.*.cov *.jl.mem Manifest.toml -/docs/build/ \ No newline at end of file +/docs/build/ +test/Manifest.toml +test/gpu/Manifest.toml \ No newline at end of file diff --git a/Project.toml b/Project.toml index ffd78dd20..b926bda17 100644 --- a/Project.toml +++ b/Project.toml @@ -34,15 +34,14 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ADTypes = "0.1, 0.2" @@ -61,52 +60,26 @@ FunctionProperties = "0.1" FunctionWrappersWrappers = "0.1" Functors = "0.4" GPUArraysCore = "0.1" +LinearAlgebra = "<0.0.1, 1" LinearSolve = "2" +Markdown = "<0.0.1, 1" OrdinaryDiffEq = "6.19.1" Parameters = "0.12" PreallocationTools = "0.4.4" QuadGK = "2.1" +Random = "<0.0.1, 1" RandomNumbers = "1.5.3" RecursiveArrayTools = "2.4.2, 3" Reexport = "0.2, 1.0" ReverseDiff = "1.9" SciMLBase = "1.66.0, 2" SciMLOperators = "0.1, 0.2, 0.3" -SimpleNonlinearSolve = "0.1.8" SparseDiffTools = "2.5" +StaticArrays = "1.8.0" StaticArraysCore = "1.4" Statistics = "1" StochasticDiffEq = "6.20" Tracker = "0.2" TruncatedStacktraces = "1.2" Zygote = "0.6" -ZygoteRules = "0.2" julia = "1.9" - -[extras] -AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" -Calculus = 
"49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" -OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86" -OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" -OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "StaticArrays", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "SciMLNLSolve", "OptimizationOptimisers", "Functors", "Lux"] diff --git a/docs/Project.toml b/docs/Project.toml index 91899da42..476acd130 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,23 +1,16 @@ [deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DelayDiffEq = 
"bcd4f6db-9728-5f36-b5f7-82caef46ccdb" +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" @@ -26,36 +19,28 @@ OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -CUDA = "4, 5" Calculus = "0.5" ComponentArrays = "0.15" DataInterpolations = "3.10, 4" -DiffEqBase = "6.106" +# DiffEqBase = "6.106" 
DiffEqCallbacks = "2.24" -DiffEqFlux = "1.52, 2" +DiffEqFlux = "3" DiffEqNoiseProcess = "5.14" -DifferentialEquations = "7.6" Documenter = "1" -Flux = "0.13, 0.14" ForwardDiff = "0.10" -GraphNeuralNetworks = "0.5, 0.6" -IterTools = "1.4" -Lux = "0.5" -MLDatasets = "0.7" -NNlib = "0.9" -Optimisers = "0.3" +# GraphNeuralNetworks = "0.5, 0.6" +# IterTools = "1.4" +Lux = "0.5.7" +# MLDatasets = "0.7" Optimization = "3.9" OptimizationNLopt = "0.1" OptimizationOptimJL = "0.1" @@ -67,9 +52,8 @@ QuadGK = "2.6" RecursiveArrayTools = "2.32" ReverseDiff = "1.14" SciMLSensitivity = "7.11" -SimpleChains = "0.4" -StaticArrays = "1.5" +# SimpleChains = "0.4" Statistics = "1" StochasticDiffEq = "6.56" Tracker = "0.2" -Zygote = "0.6" +Zygote = "0.6" \ No newline at end of file diff --git a/docs/src/Benchmark.md b/docs/src/Benchmark.md index 3c4038b31..7e2b91419 100644 --- a/docs/src/Benchmark.md +++ b/docs/src/Benchmark.md @@ -6,9 +6,9 @@ From our [recent papers](https://arxiv.org/abs/1812.01892), it's clear that `Enz especially when the program is set up to be fully non-allocating mutating functions. Thus for all benchmarking, especially with PDEs, this should be done. Neural network libraries don't make use of mutation effectively [except for SimpleChains.jl](https://julialang.org/blog/2022/04/simple-chains/), so we recommend creating a -neural ODE / universal ODE with `ZygoteVJP` and Flux first, but then check the correctness by moving the +neural ODE / universal ODE with `ZygoteVJP` and Lux first, but then check the correctness by moving the implementation over to SimpleChains and if possible `EnzymeVJP`. This can be an order of magnitude improvement -(or more) in many situations over all the previous benchmarks using Zygote and Flux, and thus it's +(or more) in many situations over all the previous benchmarks using Zygote and Lux, and thus it's highly recommended in scenarios that require performance. ## Vs Torchdiffeq 1 million and less ODEs @@ -33,12 +33,12 @@ at this time. 
Quick summary: - `BacksolveAdjoint` can be the fastest (but use with caution!); about 25% faster - - Using `ZygoteVJP` is faster than other vjp choices with FastDense due to the overloads + - Using `ZygoteVJP` is faster than other vjp choices for larger neural networks + - `ReverseDiffVJP(compile = true)` works well for small Lux neural networks ```julia using DiffEqFlux, - OrdinaryDiffEq, Flux, Optim, Plots, SciMLSensitivity, - Zygote, BenchmarkTools, Random + OrdinaryDiffEq, Lux, SciMLSensitivity, Zygote, BenchmarkTools, Random, ComponentArrays u0 = Float32[2.0; 0.0] datasize = 30 @@ -53,116 +53,39 @@ end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -dudt2 = FastChain((x, p) -> x .^ 3, - FastDense(2, 50, tanh), - FastDense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) Random.seed!(100) -p = initial_params(dudt2) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) - -function loss_neuralode(p) - pred = Array(prob_neuralode(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode, p) -# 2.709 ms (56506 allocations: 6.62 MiB) - -prob_neuralode_interpolating = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) - -function loss_neuralode_interpolating(p) - pred = Array(prob_neuralode_interpolating(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_interpolating, p) -# 5.501 ms (103835 allocations: 2.57 MiB) - -prob_neuralode_interpolating_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())) - -function loss_neuralode_interpolating_zygote(p) - pred = Array(prob_neuralode_interpolating_zygote(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_interpolating_zygote, p) -# 2.899 
ms (56150 allocations: 6.61 MiB) - -prob_neuralode_backsolve = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true))) - -function loss_neuralode_backsolve(p) - pred = Array(prob_neuralode_backsolve(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_backsolve, p) -# 4.871 ms (85855 allocations: 2.20 MiB) - -prob_neuralode_quad = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))) - -function loss_neuralode_quad(p) - pred = Array(prob_neuralode_quad(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_quad, p) -# 11.748 ms (79549 allocations: 3.87 MiB) - -prob_neuralode_backsolve_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = BacksolveAdjoint(autojacvec = TrackerVJP())) - -function loss_neuralode_backsolve_tracker(p) - pred = Array(prob_neuralode_backsolve_tracker(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_backsolve_tracker, p) -# 27.604 ms (186143 allocations: 12.22 MiB) - -prob_neuralode_backsolve_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())) - -function loss_neuralode_backsolve_zygote(p) - pred = Array(prob_neuralode_backsolve_zygote(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end - -@btime Zygote.gradient(loss_neuralode_backsolve_zygote, p) -# 2.091 ms (49883 allocations: 6.28 MiB) - -prob_neuralode_backsolve_false = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(false))) - -function loss_neuralode_backsolve_false(p) - pred = Array(prob_neuralode_backsolve_false(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss +for sensealg in (InterpolatingAdjoint(autojacvec = ZygoteVJP()), + 
InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)), + BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)), + BacksolveAdjoint(autojacvec = ZygoteVJP()), + BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)), + BacksolveAdjoint(autojacvec = TrackerVJP()), + QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)), + TrackerAdjoint()) + prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps, + sensealg = sensealg) + ps, st = Lux.setup(Random.default_rng(), prob_neuralode) + ps = ComponentArray(ps) + + loss_neuralode = function (u0, p, st) + pred = Array(first(prob_neuralode(u0, p, st))) + loss = sum(abs2, ode_data .- pred) + return loss + end + + t = @belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st) + println("$(sensealg) took $(t)s") end -@btime Zygote.gradient(loss_neuralode_backsolve_false, p) -# 4.822 ms (9956 allocations: 1.03 MiB) - -prob_neuralode_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, - sensealg = TrackerAdjoint()) - -function loss_neuralode_tracker(p) - pred = Array(prob_neuralode_tracker(u0, p)) - loss = sum(abs2, ode_data .- pred) - return loss -end +# InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), false, false) took 0.029134224s +# InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), false, false) took 0.001657377s +# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), true, false) took 0.002477057s +# BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), true, false) took 0.031533335s +# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}}(ReverseDiffVJP{false}(), true, false) took 0.004605386s +# BacksolveAdjoint{0, true, Val{:central}, TrackerVJP}(TrackerVJP(false), true, false) took 0.044568018s +# QuadratureAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), 1.0e-6, 0.001) took 0.002489559s +# TrackerAdjoint() took 0.003759097s -@btime 
Zygote.gradient(loss_neuralode_tracker, p) -# 12.614 ms (76346 allocations: 3.12 MiB) ``` diff --git a/docs/src/examples/dae/physical_constraints.md b/docs/src/examples/dae/physical_constraints.md index d975df51e..1189e17c2 100644 --- a/docs/src/examples/dae/physical_constraints.md +++ b/docs/src/examples/dae/physical_constraints.md @@ -10,7 +10,7 @@ terms must add to one. An example of this is as follows: ```@example dae using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, - DifferentialEquations, Plots + OrdinaryDiffEq, Plots using Random rng = Random.default_rng() @@ -74,7 +74,7 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100) ```@example dae2 using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, - DifferentialEquations, Plots + OrdinaryDiffEq, Plots using Random rng = Random.default_rng() @@ -133,8 +133,8 @@ Because this is a DAE, we need to make sure to use a **compatible solver**. ### Neural Network Layers Next, we create our layers using `Lux.Chain`. We use this instead of `Flux.Chain` because it -is more suited to SciML applications (similarly for -`Lux.Dense`). The input to our network will be the initial conditions fed in as `u₀`. +is more suited to SciML applications (similarly for `Lux.Dense`). The input to our network +will be the initial conditions fed in as `u₀`. ```@example dae2 nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh), diff --git a/docs/src/examples/dde/delay_diffeq.md b/docs/src/examples/dde/delay_diffeq.md index a5f63c6ea..72b5eabff 100644 --- a/docs/src/examples/dde/delay_diffeq.md +++ b/docs/src/examples/dde/delay_diffeq.md @@ -5,8 +5,8 @@ supported. 
For example, we can build a layer with a delay differential equation like: ```@example dde -using DifferentialEquations, Optimization, SciMLSensitivity, - OptimizationPolyalgorithms +using OrdinaryDiffEq, Optimization, SciMLSensitivity, OptimizationPolyalgorithms, + DelayDiffEq # Define the same LV equation, but including a delay parameter function delay_lotka_volterra!(du, u, h, p, t) @@ -32,8 +32,7 @@ prob_dde = DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0), function predict_dde(p) return Array(solve(prob_dde, MethodOfSteps(Tsit5()), - u0 = u0, p = p, saveat = 0.1, - sensealg = ReverseDiffAdjoint())) + u0 = u0, p = p, saveat = 0.1, sensealg = ReverseDiffAdjoint())) end loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p)) diff --git a/docs/src/examples/hybrid_jump/bouncing_ball.md b/docs/src/examples/hybrid_jump/bouncing_ball.md index 900fbbab5..0bd8401c8 100644 --- a/docs/src/examples/hybrid_jump/bouncing_ball.md +++ b/docs/src/examples/hybrid_jump/bouncing_ball.md @@ -8,7 +8,8 @@ data. Assume we have data for the ball's height after 15 seconds. Let's first start by implementing the ODE: ```@example bouncing_ball -using Optimization, OptimizationPolyalgorithms, SciMLSensitivity, DifferentialEquations +using Optimization, + OptimizationPolyalgorithms, SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks function f(du, u, p, t) du[1] = u[2] diff --git a/docs/src/examples/hybrid_jump/hybrid_diffeq.md b/docs/src/examples/hybrid_jump/hybrid_diffeq.md index 987a2efe0..9176441e8 100644 --- a/docs/src/examples/hybrid_jump/hybrid_diffeq.md +++ b/docs/src/examples/hybrid_jump/hybrid_diffeq.md @@ -8,7 +8,9 @@ model and the universal differential equation is trained to uncover the missing dynamical equations. 
```@example -using DiffEqFlux, Flux, DifferentialEquations, Plots +using DiffEqFlux, ComponentArrays, Random, + Lux, OrdinaryDiffEq, Plots, Optimization, OptimizationOptimisers, DiffEqCallbacks + u0 = Float32[2.0; 0.0] datasize = 100 tspan = (0.0f0, 10.5f0) @@ -25,37 +27,36 @@ t = range(tspan[1], tspan[2], length = datasize) prob = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob, Tsit5(), callback = cb_, saveat = t)) -dudt2 = Flux.Chain(Flux.Dense(2, 50, tanh), - Flux.Dense(50, 2)) -p, re = Flux.destructure(dudt2) # use this p as the initial condition! + +dudt2 = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) +ps, st = Lux.setup(Random.default_rng(), dudt2) function dudt(du, u, p, t) du[1:2] .= -u[1:2] - du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end]) + du[3:end] .= first(dudt2(u[1:2], p, st)) end z0 = Float32[u0; u0] prob = ODEProblem(dudt, z0, tspan) affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end] -callback = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) +cb = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) -function predict_n_ode() +function predict_n_ode(p) _prob = remake(prob, p = p) - Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = callback, saveat = t, - sensealg = ReverseDiffAdjoint()))[1:2, - :] + Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, saveat = t, + sensealg = ReverseDiffAdjoint()))[1:2, :] #Array(solve(prob,Tsit5(),u0=z0,p=p,saveat=t))[1:2,:] end -function loss_n_ode() - pred = predict_n_ode() +function loss_n_ode(p, _) + pred = predict_n_ode(p) loss = sum(abs2, ode_data .- pred) loss end -loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE +loss_n_ode(ps, nothing) -cba = function (; doplot = false) #callback function to observe training - pred = predict_n_ode() +cba = function (p, l; doplot = false) #callback function to observe training + pred = predict_n_ode(p) display(sum(abs2, ode_data .- pred)) # plot current prediction 
against data pl = scatter(t, ode_data[1, :], label = "data") @@ -63,11 +64,9 @@ cba = function (; doplot = false) #callback function to observe training display(plot(pl)) return false end -cba() -ps = Flux.params(p) -data = Iterators.repeated((), 200) -Flux.train!(loss_n_ode, ps, data, ADAM(0.05), cb = cba) +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), + ComponentArray(ps)), Adam(0.05); callback = cba, maxiters = 1000) ``` ![Hybrid Universal Differential Equation](https://user-images.githubusercontent.com/1814174/91687561-08fc5900-eb2e-11ea-9f26-6b794e1e1248.gif) diff --git a/docs/src/examples/neural_ode/minibatch.md b/docs/src/examples/neural_ode/minibatch.md index 66140a20c..8e22aa5db 100644 --- a/docs/src/examples/neural_ode/minibatch.md +++ b/docs/src/examples/neural_ode/minibatch.md @@ -67,7 +67,7 @@ function cb() end end -opt = ADAM(0.05) +opt = Adam(0.05) Flux.train!(loss_adjoint, Flux.params(θ), ncycle(train_loader, numEpochs), opt, cb = Flux.throttle(cb, 10)) @@ -176,7 +176,7 @@ function cb() end end -opt = ADAM(0.05) +opt = Adam(0.05) Flux.train!(loss_adjoint, Flux.params(θ), ncycle(train_loader, numEpochs), opt, cb = Flux.throttle(cb, 10)) ``` diff --git a/docs/src/examples/neural_ode/neural_gde.md b/docs/src/examples/neural_ode/neural_gde.md index bc03fcc06..25a5f8036 100644 --- a/docs/src/examples/neural_ode/neural_gde.md +++ b/docs/src/examples/neural_ode/neural_gde.md @@ -125,7 +125,7 @@ function train() st = st |> device ## Optimizer - opt = Optimisers.ADAM(0.01f0) + opt = Optimisers.Adam(0.01f0) st_opt = Optimisers.setup(opt, ps) ## Training Loop @@ -303,7 +303,7 @@ st = st |> device ### Optimizer -For this task, we will be using the `ADAM` optimizer with a learning rate of `0.01`. +For this task, we will be using the `Adam` optimizer with a learning rate of `0.01`. 
```julia opt = Optimisers.Adam(0.01f0) diff --git a/docs/src/examples/ode/exogenous_input.md b/docs/src/examples/ode/exogenous_input.md index 3dacfcf3b..4b56a3e04 100644 --- a/docs/src/examples/ode/exogenous_input.md +++ b/docs/src/examples/ode/exogenous_input.md @@ -40,7 +40,7 @@ In the following example, a discrete exogenous input signal `ex` is defined and used as an input into the neural network of a neural ODE system. ```@example exogenous -using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization, +using OrdinaryDiffEq, Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationPolyalgorithms, OptimizationOptimisers, Plots, Random rng = Random.default_rng() diff --git a/docs/src/examples/ode/prediction_error_method.md b/docs/src/examples/ode/prediction_error_method.md index 98e2c9863..5ba313f5e 100644 --- a/docs/src/examples/ode/prediction_error_method.md +++ b/docs/src/examples/ode/prediction_error_method.md @@ -19,7 +19,7 @@ In both of these examples, we may make use of measurements we have of the evolut We start by defining a model of the pendulum. The model takes a parameter $L$ corresponding to the length of the pendulum. ```@example PEM -using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Plots, Statistics, +using OrdinaryDiffEq, Optimization, OptimizationPolyalgorithms, Plots, Statistics, DataInterpolations, ForwardDiff tspan = (0.1, 20.0) diff --git a/docs/src/examples/ode/second_order_adjoints.md b/docs/src/examples/ode/second_order_adjoints.md index 464f012c6..2a5ef7cfd 100644 --- a/docs/src/examples/ode/second_order_adjoints.md +++ b/docs/src/examples/ode/second_order_adjoints.md @@ -14,8 +14,8 @@ with Hessian-vector products (never forming the Hessian) for large parameter optimizations. 
```@example secondorderadjoints -using Flux, DiffEqFlux, Optimization, OptimizationOptimisers, DifferentialEquations, - Plots, Random, OptimizationOptimJL +using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationOptimisers, + OrdinaryDiffEq, Plots, Random, OptimizationOptimJL u0 = Float32[2.0; 0.0] datasize = 30 @@ -30,13 +30,14 @@ end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -dudt2 = Flux.Chain(x -> x .^ 3, - Flux.Dense(2, 50, tanh), - Flux.Dense(50, 2)) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +ps, st = Lux.setup(Random.default_rng(), prob_neuralode) +ps = ComponentArray(ps) +prob_neuralode = Lux.Experimental.StatefulLuxLayer(prob_neuralode, ps, st) function predict_neuralode(p) - Array(prob_neuralode(u0, p)[1]) + Array(prob_neuralode(u0, p)) end function loss_neuralode(p) @@ -72,14 +73,12 @@ end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) -optprob1 = Optimization.OptimizationProblem(optf, prob_neuralode.p) -pstart = Optimization.solve(optprob1, ADAM(0.01), callback = callback, maxiters = 100).u +optprob1 = Optimization.OptimizationProblem(optf, prob_neuralode.ps) +pstart = Optimization.solve(optprob1, Adam(0.01), callback = callback, maxiters = 100).u optprob2 = Optimization.OptimizationProblem(optf, pstart) pmin = Optimization.solve(optprob2, NewtonTrustRegion(), callback = callback, maxiters = 200) -pmin = Optimization.solve(optprob2, Optim.KrylovTrustRegion(), callback = callback, - maxiters = 200) ``` Note that we do not demonstrate `Newton()` because we have not found a single diff --git a/docs/src/examples/ode/second_order_neural.md b/docs/src/examples/ode/second_order_neural.md index 15f29fe7c..0ca5a5606 100644 --- a/docs/src/examples/ode/second_order_neural.md +++ b/docs/src/examples/ode/second_order_neural.md @@ 
-21,19 +21,20 @@ neural network by the mass!) An example of training a neural network on a second order ODE is as follows: ```@example secondorderneural -using DifferentialEquations, - Flux, Optimization, OptimizationOptimisers, RecursiveArrayTools, - Random +using OrdinaryDiffEq, Lux, Optimization, OptimizationOptimisers, RecursiveArrayTools, + Random, ComponentArrays u0 = Float32[0.0; 2.0] du0 = Float32[0.0; 0.0] tspan = (0.0f0, 1.0f0) t = range(tspan[1], tspan[2], length = 20) -model = Flux.Chain(Flux.Dense(2, 50, tanh), Flux.Dense(50, 2)) -p, re = Flux.destructure(model) +model = Chain(Dense(2, 50, tanh), Dense(50, 2)) +ps, st = Lux.setup(Random.default_rng(), model) +ps = ComponentArray(ps) +model = Lux.Experimental.StatefulLuxLayer(model, ps, st) -ff(du, u, p, t) = re(p)(u) +ff(du, u, p, t) = model(u, p) -prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, p) +prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps) function predict(p) @@ -47,18 +48,16 @@ function loss_n_ode(p) sum(abs2, correct_pos .- pred[1:2, :]), pred end -data = Iterators.repeated((), 1000) -opt = ADAM(0.01) - -l1 = loss_n_ode(p) +l1 = loss_n_ode(ps) callback = function (p, l, pred) println(l) l < 0.01 end + adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), adtype) -optprob = Optimization.OptimizationProblem(optf, p) +optprob = Optimization.OptimizationProblem(optf, ps) -res = Optimization.solve(optprob, opt; callback = callback, maxiters = 1000) +res = Optimization.solve(optprob, Adam(0.01); callback = callback, maxiters = 1000) ``` diff --git a/docs/src/examples/optimal_control/feedback_control.md b/docs/src/examples/optimal_control/feedback_control.md index e3320cacc..61f4f9a9b 100644 --- a/docs/src/examples/optimal_control/feedback_control.md +++ b/docs/src/examples/optimal_control/feedback_control.md @@ -10,40 +10,34 @@ on the current state of the dynamical system that will control the second equation to stay close to 1. 
```@example udeneuralcontrol -using Flux, Optimization, OptimizationPolyalgorithms, - SciMLSensitivity, Zygote, DifferentialEquations, Plots, Random +using Lux, Optimization, OptimizationPolyalgorithms, ComponentArrays, + SciMLSensitivity, Zygote, OrdinaryDiffEq, Plots, Random rng = Random.default_rng() -u0 = 1.1 +u0 = [1.1] tspan = (0.0, 25.0) tsteps = 0.0:1.0:25.0 -model_univ = Flux.Chain(Flux.Dense(2, 16, tanh), - Flux.Dense(16, 16, tanh), - Flux.Dense(16, 1)) |> f64 - -# The model weights are destructured into a vector of parameters -p_model, re = Flux.destructure(model_univ) -n_weights = length(p_model) +model_univ = Chain(Dense(2, 16, tanh), Dense(16, 16, tanh), Dense(16, 1)) +ps, st = Lux.setup(Random.default_rng(), model_univ) +p_model = ComponentArray(ps) # Parameters of the second equation (linear dynamics) p_system = [0.5, -0.5] - -p_all = [p_model; p_system] -θ = [u0; p_all] +p_all = ComponentArray(; p_model, p_system) +θ = ComponentArray(; u0, p_all) function dudt_univ!(du, u, p, t) # Destructure the parameters - model_weights = p[1:n_weights] - α = p[end - 1] - β = p[end] + model_weights = p.p_model + α, β = p.p_system # The neural network outputs a control taken by the system # The system then produces an output model_control, system_output = u # Dynamics of the control and system - dmodel_control = re(model_weights)(u)[1] + dmodel_control = first(model_univ(u, model_weights, st))[1] dsystem_output = α * system_output + β * model_control # Update in place @@ -51,11 +45,11 @@ function dudt_univ!(du, u, p, t) du[2] = dsystem_output end -prob_univ = ODEProblem(dudt_univ!, [0.0, u0], tspan, p_all) +prob_univ = ODEProblem(dudt_univ!, [0.0, u0[1]], tspan, p_all) sol_univ = solve(prob_univ, Tsit5(), abstol = 1e-8, reltol = 1e-6) function predict_univ(θ) - return Array(solve(prob_univ, Tsit5(), u0 = [0.0, θ[1]], p = θ[2:end], + return Array(solve(prob_univ, Tsit5(), u0 = [0.0, θ.u0[1]], p = θ.p_all, sensealg = InterpolatingAdjoint(autojacvec = 
ReverseDiffVJP(true)), saveat = tsteps)) end @@ -67,7 +61,7 @@ l = loss_univ(θ) ```@example udeneuralcontrol list_plots = [] iter = 0 -callback = function (θ, l) +cb = function (θ, l) global list_plots, iter if iter == 0 @@ -88,6 +82,5 @@ end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_univ(x), adtype) optprob = Optimization.OptimizationProblem(optf, θ) -result_univ = Optimization.solve(optprob, PolyOpt(), - callback = callback) +result_univ = Optimization.solve(optprob, PolyOpt(), callback = cb) ``` diff --git a/docs/src/examples/optimal_control/optimal_control.md b/docs/src/examples/optimal_control/optimal_control.md index a995f31d7..f7e0ddc58 100644 --- a/docs/src/examples/optimal_control/optimal_control.md +++ b/docs/src/examples/optimal_control/optimal_control.md @@ -36,17 +36,23 @@ will first reduce control cost (the last term) by 10x in order to bump the netwo of a local minimum. This looks like: ```@example neuraloptimalcontrol -using Flux, DifferentialEquations, Optimization, OptimizationNLopt, OptimizationOptimisers, - SciMLSensitivity, Zygote, Plots, Statistics, Random +using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationNLopt, + OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random rng = Random.default_rng() tspan = (0.0f0, 8.0f0) -ann = Flux.Chain(Flux.Dense(1, 32, tanh), Flux.Dense(32, 32, tanh), Flux.Dense(32, 1)) -θ, re = Flux.destructure(ann) + +ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1)) +ps, st = Lux.setup(rng, ann) +p = ComponentArray(ps) + +θ, ax = getdata(p), getaxes(p) + function dxdt_(dx, x, p, t) + ps = ComponentArray(p, ax) x1, x2 = x dx[1] = x[2] - dx[2] = re(p)([t])[1]^3 + dx[2] = first(ann([t], ps, st))[1]^3 end x0 = [-4.0f0, 0.0f0] ts = Float32.(collect(0.0:0.01:tspan[2])) @@ -59,17 +65,20 @@ function predict_adjoint(θ) end function loss_adjoint(θ) x = predict_adjoint(θ) + ps = ComponentArray(θ, ax) mean(abs2, 4.0 .- x[1, :]) + 
2mean(abs2, x[2, :]) + - mean(abs2, [first(re(θ)([t])) for t in ts]) / 10 + mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) / 10 end l = loss_adjoint(θ) -callback = function (θ, l; doplot = false) +cb = function (θ, l; doplot = true) println(l) + ps = ComponentArray(θ, ax) + if doplot p = plot(solve(remake(prob, p = θ), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3) - plot!(p, ts, [first(re(θ)([t])) for t in ts], label = "u(t)", lw = 3) + plot!(p, ts, [first(first(ann([t], ps, st))) for t in ts], label = "u(t)", lw = 3) display(p) end @@ -78,7 +87,7 @@ end # Display the ODE with the current parameter values. -callback(θ, l) +cb(θ, l) # Setup and run the optimization @@ -87,11 +96,10 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype) optprob = Optimization.OptimizationProblem(optf, θ) -res1 = Optimization.solve(optprob, ADAM(0.005), callback = callback, maxiters = 100) +res1 = Optimization.solve(optprob, Adam(0.01), callback = cb, maxiters = 100) optprob2 = Optimization.OptimizationProblem(optf, res1.u) -res2 = Optimization.solve(optprob2, - NLopt.LD_LBFGS(), maxiters = 100) +res2 = Optimization.solve(optprob2, NLopt.LD_LBFGS(), callback = cb, maxiters = 100) ``` Now that the system is in a better behaved part of parameter space, we return to @@ -100,23 +108,23 @@ the original loss function to finish the optimization: ```@example neuraloptimalcontrol function loss_adjoint(θ) x = predict_adjoint(θ) + ps = ComponentArray(θ, ax) mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) + - mean(abs2, [first(re(θ)([t])) for t in ts]) + mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) end optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype) optprob3 = Optimization.OptimizationProblem(optf3, res2.u) -res3 = Optimization.solve(optprob3, - NLopt.LD_LBFGS(), maxiters = 100) +res3 = Optimization.solve(optprob3, NLopt.LD_LBFGS(), maxiters = 100) ``` Now let's see what we received: 
```@example neuraloptimalcontrol l = loss_adjoint(res3.u) -callback(res3.u, l) +cb(res3.u, l) p = plot(solve(remake(prob, p = res3.u), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3) -plot!(p, ts, [first(re(res3.u)([t])) for t in ts], label = "u(t)", lw = 3) +plot!(p, ts, [first(first(ann([t], ComponentArray(res3.u, ax), st))) for t in ts], label = "u(t)", lw = 3) ``` ![](https://user-images.githubusercontent.com/1814174/81859169-db65b280-9532-11ea-8394-dbb5efcd4036.png) diff --git a/docs/src/examples/pde/pde_constrained.md b/docs/src/examples/pde/pde_constrained.md index cc0c098ff..e861b5bd8 100644 --- a/docs/src/examples/pde/pde_constrained.md +++ b/docs/src/examples/pde/pde_constrained.md @@ -5,7 +5,7 @@ This example uses a prediction model to optimize the one-dimensional Heat Equati ```@example pde using DelimitedFiles, Plots -using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote +using OrdinaryDiffEq, Optimization, OptimizationPolyalgorithms, Zygote # Problem setup parameters: Lx = 10.0 @@ -74,7 +74,7 @@ LOSS = [] # Loss accumulator PRED = [] # prediction accumulator PARS = [] # parameters accumulator -callback = function (θ, l, pred) #callback function to observe training +cb = function (θ, l, pred) #callback function to observe training display(l) append!(PRED, [pred]) append!(LOSS, l) @@ -82,7 +82,7 @@ callback = function (θ, l, pred) #callback function to observe training false end -callback(ps, loss(ps)...) # Testing callback function +cb(ps, loss(ps)...) # Testing callback function # Let see prediction vs. 
Truth scatter(sol[:, end], label = "Truth", size = (800, 500)) @@ -92,7 +92,7 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, ps) -res = Optimization.solve(optprob, PolyOpt(), callback = callback) +res = Optimization.solve(optprob, PolyOpt(), callback = cb) @show res.u # returns [0.999999999613485, 0.9999999991343996] ``` @@ -102,8 +102,7 @@ res = Optimization.solve(optprob, PolyOpt(), callback = callback) ```@example pde2 using DelimitedFiles, Plots -using DifferentialEquations, Optimization, OptimizationPolyalgorithms, - Zygote +using OrdinaryDiffEq, Optimization, OptimizationPolyalgorithms, Zygote ``` ### Parameters @@ -233,7 +232,7 @@ size(pred), size(sol), size(t) # Checking sizes #### Optimizer -The optimizers `ADAM` with a learning rate of 0.01 and `BFGS` are directly passed in +The optimizers `Adam` with a learning rate of 0.01 and `BFGS` are directly passed in training (see below) #### Callback @@ -247,7 +246,7 @@ LOSS = [] # Loss accumulator PRED = [] # prediction accumulator PARS = [] # parameters accumulator -callback = function (θ, l, pred) #callback function to observe training +cb = function (θ, l, pred) #callback function to observe training display(l) append!(PRED, [pred]) append!(LOSS, l) @@ -255,7 +254,7 @@ callback = function (θ, l, pred) #callback function to observe training false end -callback(ps, loss(ps)...) # Testing callback function +cb(ps, loss(ps)...) 
# Testing callback function ``` ### Plotting Prediction vs Ground Truth @@ -281,7 +280,7 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, ps) -res = Optimization.solve(optprob, PolyOpt(), callback = callback) +res = Optimization.solve(optprob, PolyOpt(), callback = cb) @show res.u # returns [0.999999999613485, 0.9999999991343996] ``` diff --git a/docs/src/examples/sde/SDE_control.md b/docs/src/examples/sde/SDE_control.md index 04b2f5417..0e4533eeb 100644 --- a/docs/src/examples/sde/SDE_control.md +++ b/docs/src/examples/sde/SDE_control.md @@ -20,9 +20,7 @@ follow a full explanation of the definition and training process: ```@example # load packages -using DiffEqFlux -using SciMLSensitivity -using Optimization, OptimizationOptimisers +using DiffEqFlux, SciMLSensitivity, Optimization, OptimizationOptimisers using StochasticDiffEq, DiffEqCallbacks, DiffEqNoiseProcess using Zygote, Statistics, LinearAlgebra, Random using Lux, Random, ComponentArrays @@ -293,7 +291,7 @@ visualization_callback(p_nn, l; doplot = true) # training loop @info "Start Training.." -# optimize the parameters for a few epochs with ADAM on time span +# optimize the parameters for a few epochs with Adam on time span # Setup and run the optimization adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) @@ -646,7 +644,7 @@ end ### Training -We use the `ADAM` optimizer to optimize the parameters of the neural network. +We use the `Adam` optimizer to optimize the parameters of the neural network. In each epoch, we draw new initial quantum states, compute the forward evolution, and, subsequently, the gradients of the loss function with respect to the parameters of the neural network. @@ -656,7 +654,7 @@ sensitivity methods. The necessary correction between Ito and Stratonovich integ is computed under the hood in the SciMLSensitivity package. 
```@example sdecontrol -# optimize the parameters for a few epochs with ADAM on time span +# optimize the parameters for a few epochs with Adam on time span # Setup and run the optimization adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) diff --git a/docs/src/examples/sde/optimization_sde.md b/docs/src/examples/sde/optimization_sde.md index 5b0103b9f..22729c14a 100644 --- a/docs/src/examples/sde/optimization_sde.md +++ b/docs/src/examples/sde/optimization_sde.md @@ -15,7 +15,7 @@ is a stochastic process. Each time we solve this equation we get a different solution, so we need a sensible data source. ```@example sde -using DifferentialEquations, SciMLSensitivity, Plots +using StochasticDiffEq, SciMLSensitivity, Plots function lotka_volterra!(du, u, p, t) x, y = u @@ -34,7 +34,7 @@ end p = [1.5, 1.0, 3.0, 1.0, 0.3, 0.3] prob = SDEProblem(lotka_volterra!, multiplicative_noise!, u0, tspan, p) -sol = solve(prob) +sol = solve(prob, SOSRI()) plot(sol) ``` @@ -91,7 +91,7 @@ pinit = [1.2, 0.8, 2.5, 0.8, 0.1, 0.1] adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) -@time res = Optimization.solve(optprob, ADAM(0.05), callback = cb2, maxiters = 100) +@time res = Optimization.solve(optprob, Adam(0.05), callback = cb2, maxiters = 100) ``` Notice that **both the parameters of the deterministic drift equations and the @@ -121,7 +121,7 @@ In this example, we will find the parameters of the SDE that force the solution to be close to the constant 1. 
```@example sde -using DifferentialEquations, DiffEqFlux, Optimization, OptimizationOptimisers, Plots +using StochasticDiffEq, DiffEqFlux, Optimization, OptimizationOptimisers, Plots function lotka_volterra!(du, u, p, t) x, y = u @@ -149,7 +149,7 @@ loss_sde(p) = sum(abs2, x - 1 for x in predict_sde(p)) ``` For this training process, because the loss function is stochastic, we will use -the `ADAM` optimizer from Flux.jl. The `Optimization.solve` function is the same as +the `Adam` optimizer from Flux.jl. The `Optimization.solve` function is the same as before. However, to speed up the training process, we will use a global counter so that way we only plot the current results every 10 iterations. This looks like: @@ -171,8 +171,7 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_sde(x), adtype) optprob = Optimization.OptimizationProblem(optf, p) -result_sde = Optimization.solve(optprob, ADAM(0.1), - callback = callback, maxiters = 100) +result_sde = Optimization.solve(optprob, Adam(0.1), callback = callback, maxiters = 100) ``` ![](https://user-images.githubusercontent.com/1814174/51399524-2c6abf80-1b14-11e9-96ae-0192f7debd03.gif) diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index 62be46a22..5e3be8dcc 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -28,7 +28,7 @@ Let's first define a differential equation we wish to solve. We will choose the Lotka-Volterra equation. 
This is done via DifferentialEquations.jl using: ```@example diffode -using DifferentialEquations +using OrdinaryDiffEq function lotka_volterra!(du, u, p, t) du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] diff --git a/docs/src/index.md b/docs/src/index.md index 5ce6cdde2..5289824e9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -41,8 +41,7 @@ Pkg.add("SciMLSensitivity") The highest level interface is provided by the function `solve`: ```julia -solve(prob, args...; sensealg = InterpolatingAdjoint(), - checkpoints = sol.t, kwargs...) +solve(prob, args...; sensealg = InterpolatingAdjoint(), checkpoints = sol.t, kwargs...) ``` `solve` is fully compatible with automatic differentiation libraries @@ -195,7 +194,7 @@ use and swap out the ODE solver between any common interface compatible library, - [… etc. many other choices!](https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/) In addition, due to the composability of the system, none of the components are directly -tied to the Flux.jl machine learning framework. For example, you can [use SciMLSensitivity.jl +tied to the Lux.jl machine learning framework. For example, you can [use SciMLSensitivity.jl to generate TensorFlow graphs and train the neural network with TensorFlow.jl](https://youtu.be/n2MwJ1guGVQ?t=284), [use PyTorch arrays via Torch.jl](https://github.com/FluxML/Torch.jl), and more all with single line code changes by utilizing the underlying code generation. The tutorials shown here diff --git a/docs/src/tutorials/data_parallel.md b/docs/src/tutorials/data_parallel.md index 63397aeef..6eb4cb0b5 100644 --- a/docs/src/tutorials/data_parallel.md +++ b/docs/src/tutorials/data_parallel.md @@ -17,7 +17,10 @@ matrix multiplication). 
Thus for example, with `Chain` we can define an ODE: ```@example dataparallel -using Lux, DiffEqFlux, DifferentialEquations, CUDA, Random +using Lux, DiffEqFlux, OrdinaryDiffEq, LuxCUDA, Random + +gdev = gpu_device() + rng = Random.default_rng() dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) @@ -51,7 +54,7 @@ GPU: ```@example dataparallel xs = Float32.([0 1 2 0 0 0]) -prob = ODEProblem(f, Lux.gpu(u0), (0.0f0, 1.0f0), Lux.gpu(p)) +prob = ODEProblem(f, gdev(u0), (0.0f0, 1.0f0), gdev(p)) solve(prob, Tsit5()) ``` @@ -83,7 +86,7 @@ The following is a full copy-paste example for the multithreading. Distributed and GPU minibatching are described below. ```@example dataparallel -using DifferentialEquations, Optimization, OptimizationOptimisers +using OrdinaryDiffEq, Optimization, OptimizationOptimisers pa = [1.0] u0 = [3.0] θ = [u0; pa] @@ -108,7 +111,7 @@ callback = function (θ, l) # callback function to observe training false end -opt = ADAM(0.1) +opt = Adam(0.1) l1 = loss_serial(θ) adtype = Optimization.AutoZygote() @@ -198,7 +201,7 @@ using Distributed addprocs(4) @everywhere begin - using DifferentialEquations, Optimization, OptimizationOptimisers + using OrdinaryDiffEq, Optimization, OptimizationOptimisers function f(u, p, t) 1.01u .* p end @@ -224,7 +227,7 @@ callback = function (θ, l) # callback function to observe training false end -opt = ADAM(0.1) +opt = Adam(0.1) loss_distributed(θ) = sum(abs2, 1.0 .- Array(model1(θ, EnsembleDistributed()))) l1 = loss_distributed(θ) @@ -278,7 +281,7 @@ callback = function (θ, l) # callback function to observe training false end -opt = ADAM(0.1) +opt = Adam(0.1) loss_gpu(θ) = sum(abs2, 1.0 .- Array(model1(θ, EnsembleGPUArray()))) l1 = loss_gpu(θ) diff --git a/docs/src/tutorials/parameter_estimation_ode.md b/docs/src/tutorials/parameter_estimation_ode.md index 2626933ad..ee6d6a3e6 100644 --- a/docs/src/tutorials/parameter_estimation_ode.md +++ b/docs/src/tutorials/parameter_estimation_ode.md @@ -6,7 +6,7 @@ If you 
want to just get things running, try the following! Explanation will follow. ```@example optode_cp -using DifferentialEquations, +using OrdinaryDiffEq, Optimization, OptimizationPolyalgorithms, SciMLSensitivity, Zygote, Plots @@ -62,8 +62,8 @@ result_ode = Optimization.solve(optprob, PolyOpt(), ## Explanation -First, let's create a Lotka-Volterra ODE using DifferentialEquations.jl. For -more details, [see the DifferentialEquations.jl documentation](https://docs.sciml.ai/DiffEqDocs/stable/). The Lotka-Volterra equations have the form: +First, let's create a Lotka-Volterra ODE using OrdinaryDiffEq.jl. For +more details, [see the OrdinaryDiffEq.jl documentation](https://docs.sciml.ai/DiffEqDocs/stable/). The Lotka-Volterra equations have the form: ```math \begin{aligned} @@ -73,7 +73,7 @@ more details, [see the DifferentialEquations.jl documentation](https://docs.scim ``` ```@example optode -using DifferentialEquations, +using OrdinaryDiffEq, Optimization, OptimizationPolyalgorithms, SciMLSensitivity, Zygote, Plots @@ -122,7 +122,7 @@ function loss(p) end ``` -Lastly, we use the `Optimization.solve` function to train the parameters using `ADAM` to +Lastly, we use the `Optimization.solve` function to train the parameters using `Adam` to arrive at parameters which optimize for our goal. `Optimization.solve` allows defining a callback that will be called at each step of our training loop. 
It takes in the current parameter vector and the returns of the last call to the loss diff --git a/docs/src/tutorials/training_tips/divergence.md b/docs/src/tutorials/training_tips/divergence.md index 014f99527..d9eecccfe 100644 --- a/docs/src/tutorials/training_tips/divergence.md +++ b/docs/src/tutorials/training_tips/divergence.md @@ -28,8 +28,7 @@ end A full example making use of this trick is: ```@example divergence -using DifferentialEquations, - SciMLSensitivity, Optimization, OptimizationOptimisers, +using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationNLopt, Plots function lotka_volterra!(du, u, p, t) @@ -44,14 +43,14 @@ u0 = [1.0, 1.0] tspan = (0.0, 10.0) p = [1.5, 1.0, 3.0, 1.0] prob = ODEProblem(lotka_volterra!, u0, tspan, p) -sol = solve(prob, saveat = 0.1) +sol = solve(prob, Tsit5(); saveat = 0.1) plot(sol) dataset = Array(sol) scatter!(sol.t, dataset') tmp_prob = remake(prob, p = [1.2, 0.8, 2.5, 0.8]) -tmp_sol = solve(tmp_prob) +tmp_sol = solve(tmp_prob, Tsit5()) plot(tmp_sol) scatter!(sol.t, dataset') @@ -70,14 +69,16 @@ adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) -res = Optimization.solve(optprob, ADAM(), maxiters = 1000) +res = Optimization.solve(optprob, Adam(), maxiters = 1000) # res = Optimization.solve(optprob,NLopt.LD_LBFGS(), maxiters = 1000) ### errors! ``` -You might notice that `AutoZygote` (default) fails for the above `Optimization.solve` call with Optim's optimizers, which happens because -of Zygote's behavior for zero gradients, in which case it returns `nothing`. To avoid such issues, you can just use a different version of the same check which compares the size of the obtained -solution and the data we have, shown below, which is easier to AD. 
+You might notice that `AutoZygote` (default) fails for the above `Optimization.solve` call +with Optim's optimizers, which happens because of Zygote's behavior for zero gradients, in +which case it returns `nothing`. To avoid such issues, you can just use a different version +of the same check which compares the size of the obtained solution and the data we have, +shown below, which is easier to AD. ```julia function loss(p) diff --git a/docs/src/tutorials/training_tips/local_minima.md b/docs/src/tutorials/training_tips/local_minima.md index b5b3808ca..db57e215d 100644 --- a/docs/src/tutorials/training_tips/local_minima.md +++ b/docs/src/tutorials/training_tips/local_minima.md @@ -16,10 +16,9 @@ before, except with one small twist: we wish to find the neural ODE that fits on `(0,5.0)`. Naively, we use the same training strategy as before: ```@example iterativefit -using DifferentialEquations, - ComponentArrays, SciMLSensitivity, Optimization, - OptimizationOptimisers -using Lux, Plots, Random +using OrdinaryDiffEq, + ComponentArrays, SciMLSensitivity, Optimization, OptimizationOptimisers +using Lux, Plots, Random, Zygote rng = Random.default_rng() u0 = Float32[2.0; 0.0] @@ -79,7 +78,7 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) result_neuralode = Optimization.solve(optprob, - ADAM(0.05), callback = callback, + Adam(0.05), callback = callback, maxiters = 300) pred = predict_neuralode(result_neuralode.u) @@ -108,7 +107,7 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = Optimization.OptimizationProblem(optf, pinit) result_neuralode2 = Optimization.solve(optprob, - ADAM(0.05), callback = callback, + Adam(0.05), callback = callback, maxiters = 300) pred = predict_neuralode(result_neuralode2.u) @@ -128,7 +127,7 @@ end optprob = Optimization.OptimizationProblem(optf, result_neuralode2.u) result_neuralode3 = Optimization.solve(optprob, - 
ADAM(0.05), maxiters = 300, + Adam(0.05), maxiters = 300, callback = callback) pred = predict_neuralode(result_neuralode3.u) @@ -148,7 +147,7 @@ end optprob = Optimization.OptimizationProblem(optf, result_neuralode3.u) result_neuralode4 = Optimization.solve(optprob, - ADAM(0.01), maxiters = 500, + Adam(0.01), maxiters = 500, callback = callback) pred = predict_neuralode(result_neuralode4.u) @@ -167,7 +166,9 @@ time span and (0, 10), any longer and more iterations will be required. Alternat one could use a mix of (3) and (4), or breaking up the trajectory into chunks and just (4). ```@example resetic -using Flux, Plots, DifferentialEquations, SciMLSensitivity +using OrdinaryDiffEq, + ComponentArrays, SciMLSensitivity, Optimization, OptimizationOptimisers +using Lux, Plots, Random, Zygote #Starting example with tspan (0, 5) u0 = Float32[2.0; 0.0] @@ -183,27 +184,28 @@ end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -#Using flux here to easily demonstrate the idea, but this can be done with Optimization.solve! -dudt2 = Chain(Dense(2, 16, tanh), - Dense(16, 2)) +dudt2 = Chain(Dense(2, 16, tanh), Dense(16, 2)) -p, re = Flux.destructure(dudt2) # use this p as the initial condition! -dudt(u, p, t) = re(p)(u) # need to restructure for backprop! 
+ps, st = Lux.setup(Random.default_rng(), dudt2) +p = ComponentArray(ps) +dudt(u, p, t) = first(dudt2(u, p, st)) prob = ODEProblem(dudt, u0, tspan) -function predict_n_ode() - Array(solve(prob, u0 = u0, p = p, saveat = tsteps)) +function predict_n_ode(pu0) + u0 = pu0.u0 + p = pu0.p + Array(solve(prob, Tsit5(), u0 = u0, p = p, saveat = tsteps)) end -function loss_n_ode() - pred = predict_n_ode() +function loss_n_ode(pu0, _) + pred = predict_n_ode(pu0) sqnorm(x) = sum(abs2, x) loss = sum(abs2, ode_data .- pred) loss end -function callback(; doplot = true) #callback function to observe training - pred = predict_n_ode() +function callback(p, l; doplot = true) #callback function to observe training + pred = predict_n_ode(p) display(sum(abs2, ode_data .- pred)) if doplot # plot current prediction against data @@ -213,22 +215,42 @@ function callback(; doplot = true) #callback function to observe training end return false end -predict_n_ode() -loss_n_ode() -callback() -data = Iterators.repeated((), 1000) +p_init = ComponentArray(; u0 = u0, p = p) -#Specify to flux to include both the initial conditions (IC) and parameters of the NODE to train -Flux.train!(loss_n_ode, Flux.params(u0, p), data, - Flux.Optimise.ADAM(0.05), cb = callback) +predict_n_ode(p_init) +loss_n_ode(p_init, nothing) + +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), p_init), + Adam(0.05); callback = callback, maxiters = 1000) + +function predict_n_ode2(p) + Array(solve(prob, Tsit5(), u0 = u0, p = p, saveat = tsteps)) +end + +function loss_n_ode2(p, _) + pred = predict_n_ode2(p) + sqnorm(x) = sum(abs2, x) + loss = sum(abs2, ode_data .- pred) + loss +end + +function callback2(p, l; doplot = true) #callback function to observe training + pred = predict_n_ode2(p) + display(sum(abs2, ode_data .- pred)) + if doplot + # plot current prediction against data + pl = plot(tsteps, ode_data[1, :], label = "data") + plot!(pl, tsteps, pred[1, :], label = "prediction") + 
display(plot(pl)) + end + return false +end #Here we reset the IC back to the original and train only the NODE parameters u0 = Float32[2.0; 0.0] -Flux.train!(loss_n_ode, Flux.params(p), data, - Flux.Optimise.ADAM(0.05), cb = callback) - -callback() +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode2, AutoZygote()), p_init.p), + Adam(0.05); callback = callback2, maxiters = 1000) #Now use the same technique for a longer tspan (0, 10) datasize = 30 @@ -238,23 +260,19 @@ tsteps = range(tspan[1], tspan[2], length = datasize) prob_trueode = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) -dudt2 = Chain(Dense(2, 16, tanh), - Dense(16, 2)) +dudt2 = Chain(Dense(2, 16, tanh), Dense(16, 2)) -p, re = Flux.destructure(dudt2) # use this p as the initial condition! -dudt(u, p, t) = re(p)(u) # need to restrcture for backprop! +ps, st = Lux.setup(Random.default_rng(), dudt2) +p = ComponentArray(ps) +dudt(u, p, t) = first(dudt2(u, p, st)) prob = ODEProblem(dudt, u0, tspan) -data = Iterators.repeated((), 1500) - -Flux.train!(loss_n_ode, Flux.params(u0, p), data, - Flux.Optimise.ADAM(0.05), cb = callback) - -u0 = Float32[2.0; 0.0] -Flux.train!(loss_n_ode, Flux.params(p), data, - Flux.Optimise.ADAM(0.05), cb = callback) +p_init = ComponentArray(; u0 = u0, p = p) +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), p_init), + Adam(0.05); callback = callback, maxiters = 1000) -callback() +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode2, AutoZygote()), p_init.p), + Adam(0.05); callback = callback2, maxiters = 1000) ``` And there we go, a set of robust strategies for fitting an equation that would otherwise diff --git a/docs/src/tutorials/training_tips/multiple_nn.md b/docs/src/tutorials/training_tips/multiple_nn.md index 1e9b1a3cd..e07ef2462 100644 --- a/docs/src/tutorials/training_tips/multiple_nn.md +++ b/docs/src/tutorials/training_tips/multiple_nn.md @@ -8,7 +8,7 @@ The 
following is a fully working demo on the Fitzhugh-Nagumo ODE: ```@example using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt, - OptimizationOptimisers, DifferentialEquations, Random + OptimizationOptimisers, OrdinaryDiffEq, Random rng = Random.default_rng() Random.seed!(rng, 1) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 778db4404..e62dab353 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -9,11 +9,10 @@ using StochasticDiffEq import DiffEqNoiseProcess import RandomNumbers: Xorshifts using Random -import ZygoteRules, Zygote, ReverseDiff +import Zygote, ReverseDiff import ArrayInterface import Enzyme import GPUArraysCore -using StaticArraysCore using ADTypes using SparseDiffTools using SciMLOperators @@ -28,8 +27,7 @@ using FunctionProperties: hasbranching using Markdown using Reexport -import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, - project_type, _eltype_projectto, rrule +import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented abstract type SensitivityFunction end abstract type TransformedFunction end @@ -58,7 +56,6 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") -include("staticarrays.jl") export extract_local_sensitivities diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 52212466f..25609c50c 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1010,8 +1010,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre SciMLBase.AbstractDDEProblem, SciMLBase.AbstractSDEProblem, SciMLBase.AbstractSDDEProblem, - SciMLBase.AbstractRODEProblem, - NonlinearProblem, SteadyStateProblem, + SciMLBase.AbstractRODEProblem }, alg, sensealg::ZygoteAdjoint, u0, p, originator::SciMLBase.ADOriginator, @@ -1023,6 +1022,20 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre p) end +# NOTE: This is needed to 
prevent a method ambiguity error +function DiffEqBase._concrete_solve_adjoint(prob::Union{ + NonlinearProblem, + SteadyStateProblem, + }, alg, sensealg::ZygoteAdjoint, + u0, p, originator::SciMLBase.ADOriginator, + args...; kwargs...) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + Zygote.pullback((u0, p) -> solve(prob, alg, args...; u0 = u0, p = p, + sensealg = SensitivityADPassThrough(), + kwargs_filtered...), u0, + p) +end + function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscreteProblem, SciMLBase.AbstractODEProblem, SciMLBase.AbstractDAEProblem, diff --git a/src/lss.jl b/src/lss.jl index 96fce512f..565f90474 100644 --- a/src/lss.jl +++ b/src/lss.jl @@ -370,7 +370,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, for (j, u) in enumerate(ures) vtmp = @view v[((n0 + j - 2) * numindvar + 1):((n0 + j - 1) * numindvar)] # final gradient result for ith parameter - accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, n0 + j - 1) + lss_accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, n0 + j - 1) if dg_val isa Tuple res[i] += dot(dg_val[1], vtmp) @@ -415,7 +415,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, for (j, u) in enumerate(sol.u) vtmp = @view v[((j - 1) * numindvar + 1):(j * numindvar)] # final gradient result for ith parameter - accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, j) + lss_accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, j) if dg_val isa Tuple res[i] += dot(dg_val[1], vtmp) * window[j] res[i] += dg_val[2][i] * window[j] @@ -449,7 +449,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, for (j, u) in enumerate(sol.u) vtmp = @view v[((j - 1) * numindvar + 1):(j * numindvar)] # final gradient result for ith parameter - accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, j) + lss_accumulate_cost!(u, uf.p, uf.t, sensealg, diffcache, j) if dg_val isa Tuple res[i] += dot(dg_val[1], vtmp) * window[j] res[i] += 
dg_val[2][i] * window[j] @@ -461,7 +461,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS, return res end -function accumulate_cost!(u, p, t, sensealg::ForwardLSS, diffcache, indx) +function lss_accumulate_cost!(u, p, t, sensealg::ForwardLSS, diffcache, indx) @unpack dgdu, dgdp, dg_val, pgpu, pgpu_config, pgpp, pgpp_config, uf = diffcache if dgdu === nothing diff --git a/src/parameters_handling.jl b/src/parameters_handling.jl index 4fee1ca6a..7bd578515 100644 --- a/src/parameters_handling.jl +++ b/src/parameters_handling.jl @@ -11,7 +11,8 @@ recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x) recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} = map(recursive_copyto!, values(y), values(x)) recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x) -recursive_copyto!(y, x::Nothing) = y +recursive_copyto!(y, ::Nothing) = y +recursive_copyto!(::Nothing, ::Nothing) = nothing """ neg!(x) @@ -35,6 +36,7 @@ recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_sub!, values(y), values(x))) recursive_sub!(y::T, x::T) where {T} = fmap(recursive_sub!, y, x) recursive_sub!(y, ::Nothing) = y +recursive_sub!(::Nothing, ::Nothing) = nothing """ recursive_add!(y, x) @@ -47,6 +49,7 @@ recursive_add!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_add!(values(y), values(x))) recursive_add!(y::T, x::T) where {T} = fmap(recursive_add!, y, x) recursive_add!(y, ::Nothing) = y +recursive_add!(::Nothing, ::Nothing) = nothing """ allocate_vjp(λ, x) diff --git a/src/staticarrays.jl b/src/staticarrays.jl deleted file mode 100644 index 7e608954f..000000000 --- a/src/staticarrays.jl +++ /dev/null @@ -1,23 +0,0 @@ -### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here -function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArraysCore.SArray) - dy = reshape(dx, axes(project.elements)) # allows for 
dx::OffsetArray - dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) - return project_type(project)(dz...) -end - -### Project SArray to SArray -function ProjectTo(x::StaticArraysCore.SArray{S, T}) where {S, T} - return ProjectTo{StaticArraysCore.SArray}(; element = _eltype_projectto(T), axes = S) -end - -function (project::ProjectTo{StaticArraysCore.SArray})(dx::AbstractArray{S, M}) where {S, M} - return StaticArraysCore.SArray{project.axes}(dx) -end - -### Adjoint for SArray constructor - -function rrule(::Type{T}, x::Tuple) where {T <: StaticArraysCore.SArray} - project_x = ProjectTo(x) - Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) - return T(x), Array_pullback -end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index 3585399f8..79273f69d 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -15,7 +15,7 @@ end TruncatedStacktraces.@truncate_stacktrace SteadyStateAdjointSensitivityFunction function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, f, - colorvec, needs_jac) + colorvec, needs_jac) @unpack p, u0 = sol.prob diffcache, y = adjointdiffcache(g, sensealg, false, sol, dgdu, dgdp, f, alg; @@ -29,27 +29,22 @@ function SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp λ, vjp, linsolve) end +@inline __needs_concrete_A(l) = LinearSolve.needs_concrete_A(l) +@inline __needs_concrete_A(::Nothing) = false + @noinline function SteadyStateAdjointProblem(sol, sensealg::SteadyStateAdjoint, alg, - dgdu::DG1 = nothing, dgdp::DG2 = nothing, - g::G = nothing; kwargs...) where {DG1, DG2, G} + dgdu::DG1 = nothing, dgdp::DG2 = nothing, g::G = nothing; + kwargs...) 
where {DG1, DG2, G} @unpack f, p, u0 = sol.prob - if sol.prob isa NonlinearProblem - f = ODEFunction(f) - end + sol.prob isa NonlinearProblem && (f = ODEFunction(f)) dgdu === nothing && dgdp === nothing && g === nothing && error("Either `dgdu`, `dgdp`, or `g` must be specified.") - needs_jac = if has_adjoint(f) - false - # TODO: What is the correct heuristic? Can we afford to compute Jacobian for - # cases where the length(u0) > 50 and if yes till what threshold - elseif sensealg.linsolve === nothing - length(u0) ≤ 50 - else - LinearSolve.needs_concrete_A(sensealg.linsolve) - end + needs_jac = ifelse(has_adjoint(f), false, + ifelse(sensealg.linsolve === nothing, length(u0) ≤ 50, + __needs_concrete_A(sensealg.linsolve))) p === DiffEqBase.NullParameters() && error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") diff --git a/test/HybridNODE.jl b/test/HybridNODE.jl index 9f5a26b2a..70f261c32 100644 --- a/test/HybridNODE.jl +++ b/test/HybridNODE.jl @@ -1,6 +1,5 @@ -using SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks, Flux -using Random, Test -using Zygote +using SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks, Lux, ComponentArrays +using Optimization, OptimizationOptimisers, Random, Test, Zygote function test_hybridNODE(sensealg) Random.seed!(12345) @@ -9,10 +8,10 @@ function test_hybridNODE(sensealg) t = range(tspan[1], tspan[2], length = datalength) target = 3.0 * (1:datalength) ./ datalength # some dummy data to fit to cbinput = rand(1, datalength) #some external ODE contribution - pmodel = Chain(Dense(2, 10, init = zeros), - Dense(10, 2, init = zeros)) - p, re = Flux.destructure(pmodel) - dudt(u, p, t) = re(p)(u) + pmodel = Chain(Dense(2, 10, init_weight = zeros32), Dense(10, 2, init_weight = zeros32)) + ps, st = Lux.setup(Xoshiro(0), pmodel) + ps = ComponentArray{Float64}(ps) + dudt(u, p, t) = 
first(pmodel(u, p, st)) # callback changes the first component of the solution every time # t is an integer @@ -23,28 +22,27 @@ function test_hybridNODE(sensealg) callback = PresetTimeCallback(collect(1:datalength), (int) -> affect!(int, cbinput)) # ODE with Callback - prob = ODEProblem(dudt, [0.0, 1.0], tspan, p) + prob = ODEProblem(dudt, [0.0, 1.0], tspan, ps) function predict_n_ode(p) arr = Array(solve(prob, Tsit5(), - p = p, sensealg = sensealg, saveat = 2.0, callback = callback))[1, - 2:2:end] + p = p, sensealg = sensealg, saveat = 2.0, callback = callback))[1, 2:2:end] return arr[1:datalength] end - function loss_n_ode() + function loss_n_ode(p, _) pred = predict_n_ode(p) loss = sum(abs2, target .- pred) ./ datalength end - cb = function () #callback function to observe training - pred = predict_n_ode(p) - display(loss_n_ode()) + cb = function (p, l) #callback function to observe training + @show l + return false end @show sensealg - Flux.train!(loss_n_ode, Flux.params(p), Iterators.repeated((), 20), ADAM(0.005), - cb = cb) - @test loss_n_ode() < 0.5 + res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps), + Adam(0.005); callback = cb, maxiters = 200) + @test loss_n_ode(res.u, nothing) < 0.5 println(" ") end @@ -70,14 +68,15 @@ function test_hybridNODE2(sensealg) ode_data = Array(sol)[1:2, 1:end]' ## Make model - dudt2 = Chain(Dense(4, 50, tanh), - Dense(50, 2)) - p, re = Flux.destructure(dudt2) # use this p as the initial condition! 
+ dudt2 = Chain(Dense(4, 50, tanh), Dense(50, 2)) + ps, st = Lux.setup(Xoshiro(0), dudt2) + ps = ComponentArray{Float32}(ps) + function affect!(integrator) integrator.u[3:4] = -3 * integrator.u[1:2] end function ODEfunc(dx, x, p, t) - dx[1:2] .= re(p)(x) + dx[1:2] .= first(dudt2(x, p, st)) dx[3:4] .= 0.0f0 end z0 = u0 @@ -86,34 +85,27 @@ function test_hybridNODE2(sensealg) initial_affect = true) ## Initialize learning functions - function predict_n_ode() - _prob = remake(prob, p = p) - Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, save_everystep = false, - save_start = true, sensealg = sensealg))[1:2, - :] - end - function loss_n_ode() - pred = predict_n_ode()[1:2, 1:end]' + function predict_n_ode(ps) + Array(solve(prob, Tsit5(), u0 = z0, p = ps, callback = cb, save_everystep = false, + save_start = true, sensealg = sensealg))[1:2, :] + end + function loss_n_ode(ps, _) + pred = predict_n_ode(ps)[1:2, 1:end]' loss = sum(abs2, ode_data .- pred) loss end - loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE - cba = function () #callback function to observe training - pred = predict_n_ode()[1:2, 1:end]' - display(sum(abs2, ode_data .- pred)) + + cba = function (p, loss) #callback function to observe training + @show loss return false end - cba() - - ## Learn - ps = Flux.params(p) - data = Iterators.repeated((), 25) @show sensealg - Flux.train!(loss_n_ode, ps, data, ADAM(0.0025), cb = cba) + res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps), + Adam(0.0025); callback = cba, maxiters = 200) - @test loss_n_ode() < 0.5 + @test loss_n_ode(res.u, nothing) < 0.5 println(" ") end @@ -142,14 +134,16 @@ function test_hybridNODE3(sensealg) true_data = reshape(ode_data, (2, length(t), 1)) true_data = convert.(Float32, true_data) callback_data = true_data * 1.0f-3 - train_dataloader = Flux.Data.DataLoader((true_data = true_data, - callback_data = callback_data), batchsize = 1) - dudt2 = Chain(Dense(2, 50, tanh), - 
Dense(50, 2)) - p, re = Flux.destructure(dudt2) + + data = (true_data[:, :, 1], callback_data[:, :, 1]) + dudt2 = Chain(Dense(2, 50, tanh), Dense(50, 2)) + ps, st = Lux.setup(Xoshiro(0), dudt2) + ps = ComponentArray{Float32}(ps) + function dudt(du, u, p, t) - du .= re(p)(u) + du .= first(dudt2(u, p, st)) end + z0 = Float32[2.0; 0.0] prob = ODEProblem(dudt, z0, tspan) @@ -159,62 +153,44 @@ function test_hybridNODE3(sensealg) DiscreteCallback(condition, affect!, save_positions = (false, false)) end - function predict_n_ode(true_data_0, callback_data, sense) + function predict_n_ode(p, true_data_0, callback_data, sense) _prob = remake(prob, p = p, u0 = true_data_0) solve(_prob, Tsit5(), callback = callback_(callback_data), saveat = t, sensealg = sense) end - function loss_n_ode(true_data, callback_data) - sol = predict_n_ode((vec(true_data[:, 1, :])), callback_data, sensealg) + function loss_n_ode(p, (true_data, callback_data)) + sol = predict_n_ode(p, (vec(true_data[:, 1, :])), callback_data, sensealg) pred = Array(sol) - loss = Flux.mse((true_data[:, :, 1]), pred) + loss = sum(abs2, true_data[:, :, 1] .- pred) loss end - ps = Flux.params(p) - opt = ADAM(0.1) - epochs = 10 - function cb1(true_data, callback_data) - display(loss_n_ode(true_data, callback_data)) + cba = function (p, loss) #callback function to observe training + @show loss return false end - function train!(loss, ps, data, opt, cb) - ps = Params(ps) - for (true_data, callback_data) in data - gs = gradient(ps) do - loss(true_data, callback_data) - end - Flux.update!(opt, ps, gs) - cb(true_data, callback_data) - end - return nothing - end + @show sensealg + + res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps, + data), Adam(0.01); maxiters = 1000, callback = cba) + loss = loss_n_ode(res.u, (true_data, callback_data)) - Flux.@epochs epochs train!(loss_n_ode, Params(ps), train_dataloader, opt, cb1) - loss = loss_n_ode(true_data[:, :, 1], callback_data) - @info loss @test 
loss < 0.5 end -@testset "PresetTimeCallback" begin - test_hybridNODE(ForwardDiffSensitivity()) - test_hybridNODE(BacksolveAdjoint()) - test_hybridNODE(InterpolatingAdjoint()) - test_hybridNODE(QuadratureAdjoint()) +@testset "PresetTimeCallback: $(sensealg)" for sensealg in [ForwardDiffSensitivity(), + BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()] + test_hybridNODE(sensealg) end -@testset "PeriodicCallback" begin - test_hybridNODE2(ReverseDiffAdjoint()) - test_hybridNODE2(BacksolveAdjoint()) - test_hybridNODE2(InterpolatingAdjoint()) - test_hybridNODE2(QuadratureAdjoint()) +@testset "PeriodicCallback: $(sensealg)" for sensealg in [ReverseDiffAdjoint(), + BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()] + test_hybridNODE2(sensealg) end -@testset "tprevCallback" begin - test_hybridNODE3(ReverseDiffAdjoint()) - test_hybridNODE3(BacksolveAdjoint()) - test_hybridNODE3(InterpolatingAdjoint()) - test_hybridNODE3(QuadratureAdjoint()) +@testset "tprevCallback: $(sensealg)" for sensealg in [ReverseDiffAdjoint(), + BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()] + test_hybridNODE3(sensealg) end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..ce98e9f9e --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,77 @@ +[deps] +AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" +DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = 
"f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +AlgebraicMultigrid = "0.6.0" +Aqua = "0.8.4" +Calculus = "0.5.1" +ChainRulesCore = "0.10.7, 1" +ComponentArrays = "0.15.5" +DelayDiffEq = "5.43.2" +DiffEqBase = "6.142.0" +DiffEqCallbacks = "2.34.0" +DiffEqNoiseProcess = "5.19.0" +Distributed = "<0.0.1, 1" +Distributions = "0.25" +FiniteDiff = "2" +ForwardDiff = "0.10" +Functors = "0.4.5" +LinearAlgebra = "<0.0.1, 1" +LinearSolve = "2.20.2" +Lux = "0.5.10" +NLsolve = "4.5.1" +NonlinearSolve = "3.0.1" +Optimization = "3.19.3" +OptimizationOptimisers = "0.1.6" +OrdinaryDiffEq = "6.60.0" +Pkg = "<0.0.1, 1" +PreallocationTools = "0.4" +QuadGK = "2.9.1" +Random = "<0.0.1, 1" +RecursiveArrayTools = 
"2.4.2" +ReverseDiff = "1.15.1" +SafeTestsets = "0.1.0" +SparseArrays = "<0.0.1, 1" +StaticArrays = "1.8.0" +Statistics = "<0.0.1, 1" +SteadyStateDiffEq = "2.0.1" +StochasticDiffEq = "6.63.2" +Test = "<0.0.1, 1" +Tracker = "0.2.30" +Zygote = "0.6.67" diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 1e9aee892..f775be137 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -1,5 +1,4 @@ -using SciMLSensitivity, OrdinaryDiffEq, StaticArrays, QuadGK, ForwardDiff, - Zygote +using SciMLSensitivity, OrdinaryDiffEq, StaticArrays, QuadGK, ForwardDiff, Zygote using Test ##StaticArrays rrule @@ -59,20 +58,17 @@ sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) dg_disc(u, p, t, i; outtype = nothing) = u du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) # adj_prob = ODEAdjointProblem(sol, - QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = SciMLSensitivity.ZygoteVJP()), + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP()), Tsit5(), tsteps, dg_disc) adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) integrand = AdjointSensitivityIntegrand(sol, adj_sol, - QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = SciMLSensitivity.ZygoteVJP())) + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP())) res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) @test adj_sol[end]≈du0 rtol=1e-12 @@ -96,9 +92,7 @@ function dg_disc(du, u, p, t, i) end du1, dp1 = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(abstol = 1e-14, - reltol = 1e-14, - autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = 
ZygoteVJP())) @test du0≈du1 rtol=1e-12 @test dp≈dp1 rtol=1e-12 @@ -145,20 +139,17 @@ function dg(u, p, t) end du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, - sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) adj_prob = ODEAdjointProblem(sol, - QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = SciMLSensitivity.ZygoteVJP()), + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP()), Tsit5(), nothing, nothing, nothing, dg, nothing, g) adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) integrand = AdjointSensitivityIntegrand(sol, adj_sol, - QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, - autojacvec = SciMLSensitivity.ZygoteVJP())) + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP())) res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) @test adj_sol[end]≈du0 rtol=1e-12 @@ -191,10 +182,8 @@ f_dp = ForwardDiff.gradient(G_p, p) ## concrete solve du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, - abstol = 1e-10, reltol = 1e-10, - saveat = tsteps, - sensealg = QuadratureAdjoint(abstol = 1e-14, - reltol = 1e-14, + abstol = 1e-10, reltol = 1e-10, saveat = tsteps, + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = ZygoteVJP()))), u0, p) diff --git a/test/adjoint_param.jl b/test/adjoint_param.jl index 5f5b9fdcd..fca3af9f5 100644 --- a/test/adjoint_param.jl +++ b/test/adjoint_param.jl @@ -1,9 +1,4 @@ -using Test -using OrdinaryDiffEq -using SciMLSensitivity -using ForwardDiff -using QuadGK -using Zygote +using Test, OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, QuadGK, Zygote abstol = 1e-12 reltol = 1e-12 @@ -31,8 +26,7 @@ res_quad = adjoint_sensitivities(sol, Vern9(), dgdu_continuous = dgdu, reltol = reltol, sensealg = QuadratureAdjoint()) 
res_back = adjoint_sensitivities(sol, Vern9(), dgdu_continuous = dgdu, dgdp_continuous = dgdp, abstol = abstol, - reltol = reltol, - sensealg = BacksolveAdjoint(checkpointing = true)) # it's blowing up + reltol = reltol, sensealg = BacksolveAdjoint(checkpointing = true)) function G(p) tmp_prob = remake(prob, p = p, u0 = convert.(eltype(p), prob.u0)) diff --git a/test/adjoint_shapes.jl b/test/adjoint_shapes.jl index dc74dea11..0c6510a1f 100644 --- a/test/adjoint_shapes.jl +++ b/test/adjoint_shapes.jl @@ -21,16 +21,13 @@ end policy_params = ones(2, 2) z0 = zeros(3) fwd_sol = solve(ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), - Tsit5(), - u0 = z0, - p = policy_params) + Tsit5(), u0 = z0, p = policy_params) sensealg = InterpolatingAdjoint() -sensealg = SciMLSensitivity.setvjp(sensealg, SciMLSensitivity.inplace_vjp(fwd_sol.prob, fwd_sol.prob.u0, fwd_sol.prob.p, true)) +sensealg = SciMLSensitivity.setvjp(sensealg, + SciMLSensitivity.inplace_vjp(fwd_sol.prob, fwd_sol.prob.u0, fwd_sol.prob.p, true)) -solve(ODEAdjointProblem(fwd_sol, - sensealg, - Tsit5(), +solve(ODEAdjointProblem(fwd_sol, sensealg, Tsit5(), [1.0], (out, x, p, t, i) -> (out .= 1)), Tsit5()) A = ones(2, 2) @@ -47,13 +44,8 @@ end policy_params = ones(2, 2) z0 = zeros(3) -fwd_sol = solve(ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), - u0 = z0, +fwd_sol = solve(ODEProblem(aug_dynamics!, z0, (0.0, 1.0), policy_params), u0 = z0, p = policy_params, Tsit5()) -solve(ODEAdjointProblem(fwd_sol, - sensealg, - Tsit5(), - [1.0], - (out, x, p, t, i) -> (out .= 1)), - Tsit5()) +solve(ODEAdjointProblem(fwd_sol, sensealg, Tsit5(), [1.0], + (out, x, p, t, i) -> (out .= 1)), Tsit5()) diff --git a/test/aqua.jl b/test/aqua.jl new file mode 100644 index 000000000..f713e96a8 --- /dev/null +++ b/test/aqua.jl @@ -0,0 +1,13 @@ +using SciMLSensitivity, Aqua, DiffEqBase + +@testset "Aqua" begin + Aqua.find_persistent_tasks_deps(SciMLSensitivity) + Aqua.test_ambiguities(SciMLSensitivity, recursive = false) + 
Aqua.test_deps_compat(SciMLSensitivity) + Aqua.test_piracies(SciMLSensitivity; treat_as_own = [DiffEqBase._concrete_solve_adjoint, + DiffEqBase._concrete_solve_forward]) + Aqua.test_project_extras(SciMLSensitivity) + Aqua.test_stale_deps(SciMLSensitivity) + Aqua.test_unbound_args(SciMLSensitivity) + Aqua.test_undefined_exports(SciMLSensitivity) +end diff --git a/test/callback_reversediff.jl b/test/callback_reversediff.jl index fcde31270..5d188468f 100644 --- a/test/callback_reversediff.jl +++ b/test/callback_reversediff.jl @@ -1,4 +1,5 @@ -using OrdinaryDiffEq, Flux, SciMLSensitivity, DiffEqCallbacks, Test +using OrdinaryDiffEq, Lux, ComponentArrays, SciMLSensitivity, DiffEqCallbacks, Test +using Optimization, OptimizationOptimisers, Zygote using Random Random.seed!(1234) @@ -28,13 +29,13 @@ t = range(tspan[1], tspan[2], length = datasize) prob = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob, Tsit5(), callback = CallbackSet(cbPreTime, cbFctCall), saveat = t)) -dudt2 = Chain(Dense(2, 50, tanh), - Dense(50, 2)) -p, re = Flux.destructure(dudt2) # use this p as the initial condition! 
+dudt2 = Chain(Dense(2, 50, tanh), Dense(50, 2)) +ps, st = Lux.setup(Random.default_rng(), dudt2) +ps = ComponentArray(ps) function dudt(du, u, p, t) du[1:2] .= -u[1:2] - du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end]) + du[3:end] .= first(dudt2(u[1:2], p, st)) #re(p)(u[3:end]) end z0 = Float32[u0; u0] prob = ODEProblem(dudt, z0, tspan) @@ -42,33 +43,24 @@ prob = ODEProblem(dudt, z0, tspan) affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end] cb = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) -function predict_n_ode() - _prob = remake(prob, p = p) - Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, saveat = t, - sensealg = ReverseDiffAdjoint()))[1:2, - :] - #Array(solve(prob,Tsit5(),u0=z0,p=p,saveat=t))[1:2,:] +function predict_n_ode(ps) + Array(solve(prob, Tsit5(), u0 = z0, p = ps, callback = cb, saveat = t, + sensealg = ReverseDiffAdjoint()))[1:2, :] end -function loss_n_ode() - pred = predict_n_ode() +function loss_n_ode(ps, _) + pred = predict_n_ode(ps) loss = sum(abs2, ode_data .- pred) loss end -loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE +loss_n_ode(ps, nothing) -cba = function (; doplot = false) #callback function to observe training - pred = predict_n_ode() - display(sum(abs2, ode_data .- pred)) - # plot current prediction against data - #pl = scatter(t,ode_data[1,:],label="data") - #scatter!(pl,t,pred[1,:],label="prediction") - #display(plot(pl)) +cb1 = function (p, l) + @show l return false end -cba() -ps = Flux.params(p) -data = Iterators.repeated((), 200) -Flux.train!(loss_n_ode, ps, data, ADAM(0.05), cb = cba) -@test loss_n_ode() < 0.4 +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps), + Adam(0.05); callback = cb1, maxiters = 100) + +@test loss_n_ode(res.u, nothing) < 0.4 diff --git a/test/complex_no_u.jl b/test/complex_no_u.jl index 88f53fbff..2beb0603d 100644 --- a/test/complex_no_u.jl +++ b/test/complex_no_u.jl @@ -1,21 +1,23 @@ -using 
OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, Optimization, OptimizationFlux, Flux +using OrdinaryDiffEq, ComponentArrays, Random, + SciMLSensitivity, LinearAlgebra, Optimization, OptimizationOptimisers, Lux nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) |> f64 -initial, re = Flux.destructure(nn) +ps, st = Lux.setup(Random.default_rng(), nn) +ps = ComponentArray(ps) function ode2!(u, p, t) - f1, f2 = re(p)([t]) .+ im + f1, f2 = first(nn([t], p, st)) .+ im [-f1^2; f2] end tspan = (0.0, 10.0) -prob = ODEProblem(ode2!, Complex{Float64}[0; 0], tspan, initial) +prob = ODEProblem(ode2!, Complex{Float64}[0; 0], tspan, ps) -function loss(p) +loss = function (p) sol = last(solve(prob, Tsit5(), p = p, sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP(allow_nothing = true)))) return norm(sol) end optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optf, initial) -res = Optimization.solve(optprob, ADAM(0.1), maxiters = 100) +optprob = Optimization.OptimizationProblem(optf, ps) +res = Optimization.solve(optprob, Adam(0.1), maxiters = 100) diff --git a/test/distributed.jl b/test/distributed.jl index 4f3908a8e..f68d9d334 100644 --- a/test/distributed.jl +++ b/test/distributed.jl @@ -1,4 +1,4 @@ -using Distributed, Flux +using Distributed, Optimization, OptimizationOptimisers addprocs(2) @everywhere begin @@ -8,7 +8,9 @@ addprocs(2) u0 = [3.0] end -function model4() +function model_distributed(pu0) + pa = pu0[1:1] + u0 = pu0[2:2] prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) function prob_func(prob, i, repeat) @@ -21,19 +23,16 @@ function model4() end # loss function -loss() = sum(abs2, 1.0 .- Array(model4())) +loss = (p, _) -> sum(abs2, 1.0 .- Array(model_distributed(p))) -data = Iterators.repeated((), 10) - -cb = function () # callback function to observe training - @show loss() +cb = function (p, l) # callback function to observe training + @info loss=l + return 
false end -pa = [1.0] -u0 = [3.0] -opt = Flux.ADAM(0.1) -println("Starting to train") -l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) -l2 = loss() +l1 = loss([1.0, 3.0], nothing) +@show l1 +res = solve(OptimizationProblem(OptimizationFunction(loss, AutoZygote()), + [1.0, 3.0]), Adam(0.1); callback = cb, maxiters = 100) +l2 = loss(res.u, nothing) @test 10l2 < l1 diff --git a/test/ensembles.jl b/test/ensembles.jl index c0c997039..15edea4d1 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -1,81 +1,34 @@ -using Flux, OrdinaryDiffEq, Test - -pa = [1.0] -u0 = [3.0] -function model1() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) +using SciMLSensitivity, OrdinaryDiffEq, Optimization, OptimizationOptimisers, Test, Zygote +@testset "$(i): EnsembleAlg = $(alg)" for (i, alg) in enumerate((EnsembleSerial(), + EnsembleThreads(), EnsembleSerial())) function prob_func(prob, i, repeat) remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) end + function model(p) + prob = ODEProblem((u, p, t) -> 1.01u .* p, p[1:1], (0.0, 1.0), p[2:2]) - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, trajectories = 100) -end - -# loss function -loss() = sum(abs2, 1.0 .- Array(model1())) - -data = Iterators.repeated((), 10) - -cb = function () # callback function to observe training - @show loss() -end - -opt = ADAM(0.1) -println("Starting to train") -l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) -l2 = loss() -@test 10l2 < l1 - -function model2() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) + sim = solve(ensemble_prob, Tsit5(), alg, saveat = 0.1, trajectories = 100) + return i == 3 ? 
sim.u : sim end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(), saveat = 0.1, - trajectories = 100).u -end -loss() = sum(abs2, [sum(abs2, 1.0 .- u) for u in model2()]) - -pa = [1.0] -u0 = [3.0] -opt = ADAM(0.1) -println("Starting to train") -l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) -l2 = loss() -@test 10l2 < l1 - -function model3() - prob = ODEProblem((u, p, t) -> 1.01u .* p, u0, (0.0, 1.0), pa) - - function prob_func(prob, i, repeat) - remake(prob, u0 = 0.5 .+ i / 100 .* prob.u0) + # loss function + loss = if i == 3 + (p, _) -> sum(abs2, [sum(abs2, 1.0 .- u) for u in model(p)]) + else + (p, _) -> sum(abs2, 1.0 .- Array(model(p))) end - ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) - sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), saveat = 0.1, trajectories = 100) -end - -# loss function -loss() = sum(abs2, 1.0 .- Array(model3())) - -data = Iterators.repeated((), 10) + cb = function (p, l) # callback function to observe training + @info alg=alg loss=l + return false + end -cb = function () # callback function to observe training - @show loss() + l1 = loss([1.0, 3.0], nothing) + @show l1 + res = solve(OptimizationProblem(OptimizationFunction(loss, AutoZygote()), + [1.0, 3.0]), Adam(0.1); callback = cb, maxiters = 10) + l2 = loss(res.u, nothing) + @test 10l2 < l1 end - -pa = [1.0] -u0 = [3.0] -opt = ADAM(0.1) -println("Starting to train") -l1 = loss() -Flux.@epochs 10 Flux.train!(loss, Flux.params([pa, u0]), data, opt; cb = cb) -l2 = loss() -@test 10l2 < l1 diff --git a/test/forward_remake.jl b/test/forward_remake.jl index ff439a254..54b3c7108 100644 --- a/test/forward_remake.jl +++ b/test/forward_remake.jl @@ -1,5 +1,4 @@ -using SciMLSensitivity, ForwardDiff, Distributions, OrdinaryDiffEq, - LinearAlgebra, Test +using SciMLSensitivity, ForwardDiff, Distributions, OrdinaryDiffEq, LinearAlgebra, Test function fiip(du, u, p, t) 
du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] diff --git a/test/gdp_regression_test.jl b/test/gdp_regression_test.jl index 828767392..2dd8fed45 100644 --- a/test/gdp_regression_test.jl +++ b/test/gdp_regression_test.jl @@ -1,4 +1,6 @@ -using SciMLSensitivity, Flux, OrdinaryDiffEq, LinearAlgebra, Test +using SciMLSensitivity, + OrdinaryDiffEq, LinearAlgebra, Test, Zygote, Optimization, + OptimizationOptimisers GDP = [ 11394358246872.6, @@ -76,76 +78,35 @@ if false else ## false crashes. that is when i am tracking the initial conditions prob = ODEProblem(monomial, u0, tspan, p) end -function predict_rd() # Our 1-layer neural network - Array(solve(prob, Tsit5(), p = p, saveat = 1.0:1.0:59.0, reltol = 1e-4, - sensealg = TrackerAdjoint())) -end -function loss_rd() ##L2 norm biases the newer times unfairly - ##Graph looks better if we minimize relative error squared - c = 0.0 - a = predict_rd() - d = 0.0 - for i in 1:59 - c += (a[i][1] / GDP[i] - 1)^2 ## L2 of relative error +@testset "sensealg: $(sensealg)" for sensealg in (TrackerAdjoint(), + InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) + function predict_rd(pu0) # Our 1-layer neural network + p = pu0[1:6] + u0 = pu0[7:7] + Array(solve(prob, Tsit5(); p = p, u0 = u0, saveat = 1.0:1.0:59.0, reltol = 1e-4, + sensealg)) end - c + 3 * d -end - -data = Iterators.repeated((), 100) -opt = ADAM(0.01) - -peek = function () #callback function to observe training - #reduces training speed by a lot - println("Loss: ", loss_rd()) -end - -peek() -Flux.train!(loss_rd, Flux.params(p, u0), data, opt, cb = peek) -peek() - -@test loss_rd() < 0.2 - -function monomial(dcGDP, cGDP, parameters, t) - α1, β1, nu1, nu2, δ, δ2 = parameters - dcGDP[1] = α1 * ((cGDP[1]))^β1 -end - -GDP0 = GDP[1] -tspan = (1.0, 59.0) -p = [474.8501513113645, 0.7036417845990167, 0.0, 1e-10, 1e-10, 1e-10] -u0 = [GDP0] -if false - prob = ODEProblem(monomial, [GDP0], tspan, p) -else ## false crashes. 
that is when i am tracking the initial conditions - prob = ODEProblem(monomial, u0, tspan, p) -end -function predict_adjoint() # Our 1-layer neural network - Array(solve(prob, Tsit5(), p = p, saveat = 1.0, reltol = 1e-4, - sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))) -end + function loss_rd(pu0, _) ##L2 norm biases the newer times unfairly + ##Graph looks better if we minimize relative error squared + c = 0.0 + a = predict_rd(pu0) + d = 0.0 + for i in 1:59 + c += (a[i][1] / GDP[i] - 1)^2 ## L2 of relative error + end + c + 3 * d + end -function loss_adjoint() ##L2 norm biases the newer times unfairly - ##Graph looks better if we minimize relative error squared - c = 0.0 - a = predict_adjoint() - d = 0.0 - for i in 1:59 - c += (a[i][1] / GDP[i] - 1)^2 ## L2 of relative error + peek = function (p, l) #callback function to observe training + #reduces training speed by a lot + println("$(sensealg) Loss: ", l) + return false end - c + 3 * d -end -data = Iterators.repeated((), 100) -opt = ADAM(0.01) + res = solve(OptimizationProblem(OptimizationFunction(loss_rd, AutoZygote()), + vcat(p, u0)), Adam(0.01); callback = peek, maxiters = 100) -peek = function () #callback function to observe training - #reduces training speed by a lot - println("Loss: ", loss_adjoint()) + @test loss_rd(res.u, nothing) < 0.2 end - -peek() -Flux.train!(loss_adjoint, Flux.params(p, u0), data, opt, cb = peek) -peek() -@test loss_adjoint() < 0.2 diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 4fae2fbbb..9e7a27ba9 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -2,10 +2,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" [compat] CUDA = "3.12, 4, 5" DiffEqCallbacks = "2.24" -DiffEqFlux = "1.52, 2" -Flux = "0.13, 0.14" +DiffEqFlux = "3" 
+LuxCUDA = "0.3.1" diff --git a/test/gpu/diffeqflux_standard_gpu.jl b/test/gpu/diffeqflux_standard_gpu.jl index 5bb264b4d..d82086ff2 100644 --- a/test/gpu/diffeqflux_standard_gpu.jl +++ b/test/gpu/diffeqflux_standard_gpu.jl @@ -1,6 +1,10 @@ -using SciMLSensitivity, OrdinaryDiffEq, Flux, DiffEqFlux, CUDA, Zygote +using SciMLSensitivity, OrdinaryDiffEq, Lux, DiffEqFlux, LuxCUDA, Zygote, Random +using ComponentArrays CUDA.allowscalar(false) # Makes sure no slow operations are occuring +const gdev = gpu_device() +const cdev = cpu_device() + # Generate Data u0 = Float32[2.0; 0.0] datasize = 30 @@ -12,28 +16,22 @@ function trueODEfunc(du, u, p, t) end prob_trueode = ODEProblem(trueODEfunc, u0, tspan) # Make the data into a GPU-based array if the user has a GPU -ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps)) - -dudt2 = Chain(x -> x .^ 3, - Dense(2, 50, tanh), - Dense(50, 2)) |> gpu -u0 = Float32[2.0; 0.0] |> gpu +ode_data = gdev(solve(prob_trueode, Tsit5(), saveat = tsteps)) -_p, re = Flux.destructure(dudt2) -p = gpu(_p) +dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2)) +u0 = Float32[2.0; 0.0] |> gdev prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +ps, st = Lux.setup(Random.default_rng(), dudt2) +ps = ComponentArray(ps) |> gdev function predict_neuralode(p) - gpu(prob_neuralode(u0, p)) + gdev(first(prob_neuralode(u0, p, st))) end function loss_neuralode(p) pred = predict_neuralode(p) loss = sum(abs2, ode_data .- pred) return loss end -# Callback function to observe training -list_plots = [] -iter = 0 -Zygote.gradient(loss_neuralode, p) +Zygote.gradient(loss_neuralode, ps) diff --git a/test/gpu/mixed_gpu_cpu_adjoint.jl b/test/gpu/mixed_gpu_cpu_adjoint.jl index ef943ec09..578e4e990 100644 --- a/test/gpu/mixed_gpu_cpu_adjoint.jl +++ b/test/gpu/mixed_gpu_cpu_adjoint.jl @@ -1,5 +1,5 @@ using SciMLSensitivity, OrdinaryDiffEq -using Lux, CUDA, Test, Zygote, Random, LinearAlgebra, ComponentArrays +using Lux, LuxCUDA, Test, Zygote, 
Random, LinearAlgebra, ComponentArrays CUDA.allowscalar(false) diff --git a/test/hybrid_de.jl b/test/hybrid_de.jl index 06cdc26a8..c4520f86d 100644 --- a/test/hybrid_de.jl +++ b/test/hybrid_de.jl @@ -1,4 +1,5 @@ -using Flux, SciMLSensitivity, DiffEqCallbacks, OrdinaryDiffEq, Test # , Plots +using Lux, ComponentArrays, SciMLSensitivity, DiffEqCallbacks, OrdinaryDiffEq, Test # , Plots +using Optimization, OptimizationOptimisers, Zygote, Random u0 = Float32[2.0; 0.0] datasize = 100 @@ -16,13 +17,13 @@ t = range(tspan[1], tspan[2], length = datasize) prob = ODEProblem(trueODEfunc, u0, tspan) ode_data = Array(solve(prob, Tsit5(), callback = cb_, saveat = t)) -dudt2 = Chain(Dense(2, 50, tanh), - Dense(50, 2)) -p, re = Flux.destructure(dudt2) # use this p as the initial condition! +dudt2 = Chain(Dense(2, 50, tanh), Dense(50, 2)) +ps, st = Lux.setup(Random.default_rng(), dudt2) +ps = ComponentArray(ps) function dudt(du, u, p, t) du[1:2] .= -u[1:2] - du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end]) + du[3:end] .= first(dudt2(u[1:2], p, st)) end z0 = Float32[u0; u0] prob = ODEProblem(dudt, z0, tspan) @@ -30,33 +31,23 @@ prob = ODEProblem(dudt, z0, tspan) affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end] cb = PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) -function predict_n_ode() - _prob = remake(prob, p = p) - Array(solve(_prob, Tsit5(), u0 = z0, p = p, callback = cb, saveat = t, - sensealg = ReverseDiffAdjoint()))[1:2, - :] - # Array(solve(prob,Tsit5(),u0=z0,p=p,saveat=t))[1:2,:] +function predict_n_ode(ps) + Array(solve(prob, Tsit5(), u0 = z0, p = ps, callback = cb, saveat = t, + sensealg = ReverseDiffAdjoint()))[1:2, :] end -function loss_n_ode() - pred = predict_n_ode() +function loss_n_ode(ps, _) + pred = predict_n_ode(ps) loss = sum(abs2, ode_data .- pred) loss end -loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE - -cba = function (; doplot = false) #callback function to observe training - pred = predict_n_ode() 
- display(sum(abs2, ode_data .- pred)) - # plot current prediction against data - # pl = scatter(t,ode_data[1,:],label="data") - # scatter!(pl,t,pred[1,:],label="prediction") - # display(plot(pl)) +loss_n_ode(ps, nothing) + +cba = function (p, l) #callback function to observe training + @show l return false end -cba() -ps = Flux.params(p) -data = Iterators.repeated((), 200) -Flux.train!(loss_n_ode, ps, data, ADAM(0.05), cb = cba) -loss_n_ode() < 1.0 +res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), + ps), Adam(0.05); callback = cba, maxiters = 200) +@test loss_n_ode(res.u, nothing) < 0.4 diff --git a/test/layers.jl b/test/layers.jl index 44eafb5f9..90fc97f83 100644 --- a/test/layers.jl +++ b/test/layers.jl @@ -1,4 +1,4 @@ -using SciMLSensitivity, Flux, Zygote, OrdinaryDiffEq, Test # , Plots +using SciMLSensitivity, Zygote, OrdinaryDiffEq, Test, Optimization, OptimizationOptimisers function lotka_volterra(du, u, p, t) x, y = u @@ -10,80 +10,27 @@ p = [2.2, 1.0, 2.0, 0.4] u0 = [1.0, 1.0] prob = ODEProblem(lotka_volterra, u0, (0.0, 10.0), p) -# Reverse-mode - -function predict_rd(p) - Array(solve(prob, Tsit5(), p = p, saveat = 0.1, reltol = 1e-4, - sensealg = TrackerAdjoint())) -end -loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p)) -loss_rd() = sum(abs2, x - 1 for x in predict_rd(p)) -loss_rd() - -grads = Zygote.gradient(loss_rd, p) -@test !iszero(grads[1]) - -opt = ADAM(0.1) -cb = function () - display(loss_rd()) - # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) -end - -# Display the ODE with the current parameter values. 
-loss1 = loss_rd() -Flux.train!(loss_rd, Flux.params(p), Iterators.repeated((), 100), opt, cb = cb) -loss2 = loss_rd() -@test 10loss2 < loss1 - -# Forward-mode, R^n -> R^m layer - -p = [2.2, 1.0, 2.0, 0.4] -function predict_fd() - vec(Array(solve(prob, Tsit5(), p = p, saveat = 0.0:0.1:1.0, reltol = 1e-4, - sensealg = ForwardDiffSensitivity()))) -end -loss_fd() = sum(abs2, x - 1 for x in predict_fd()) -loss_fd() - -ps = Flux.params(p) -grads = Zygote.gradient(loss_fd, ps) -@test !iszero(grads[p]) - -data = Iterators.repeated((), 100) -opt = ADAM(0.1) -cb = function () - display(loss_fd()) - # display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) -end - -# Display the ODE with the current parameter values. -loss1 = loss_fd() -Flux.train!(loss_fd, ps, data, opt, cb = cb) -loss2 = loss_fd() -@test 10loss2 < loss1 - -# Adjoint sensitivity -p = [2.2, 1.0, 2.0, 0.4] -ps = Flux.params(p) -function predict_adjoint() - solve(remake(prob, p = p), Tsit5(), saveat = 0.1, reltol = 1e-4) +@testset "sensealg: $(sensealg)" for sensealg in (TrackerAdjoint(), + ForwardDiffSensitivity(), nothing) + function predict(pu0) + p = pu0[1:4] + u0 = pu0[5:6] + vec(Array(solve(prob, Tsit5(); p, u0, saveat = 0.1, reltol = 1e-4, sensealg))) + end + loss(pu0, _) = sum(abs2, x .- 1 for x in predict(pu0)) + + grads = Zygote.gradient(loss, [p; u0], nothing) + @test !iszero(grads[1]) + + cb = function (p, l) + @info sensealg loss=l + return false + end + + l1 = loss([p; u0], nothing) + @show l1 + res = solve(OptimizationProblem(OptimizationFunction(loss, AutoZygote()), + [p; u0]), Adam(0.1); callback = cb, maxiters = 100) + l2 = loss(res.u, nothing) + @test 10l2 < l1 end -loss_reduction(sol) = sum(abs2, x - 1 for x in vec(sol)) -loss_adjoint() = loss_reduction(predict_adjoint()) -loss_adjoint() - -grads = Zygote.gradient(loss_adjoint, ps) -@test !iszero(grads[p]) - -data = Iterators.repeated((), 100) -opt = ADAM(0.1) -cb = function () - display(loss_adjoint()) - # 
display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6))) -end - -# Display the ODE with the current parameter values. -loss1 = loss_adjoint() -Flux.train!(loss_adjoint, ps, data, opt, cb = cb) -loss2 = loss_adjoint() -@test 10loss2 < loss1 diff --git a/test/layers_dde.jl b/test/layers_dde.jl index bec7bf4dc..064c04415 100644 --- a/test/layers_dde.jl +++ b/test/layers_dde.jl @@ -1,4 +1,4 @@ -using SciMLSensitivity, Flux, Zygote, DelayDiffEq, Test +using SciMLSensitivity, Zygote, DelayDiffEq, Test ## Setup DDE to optimize function delay_lotka_volterra(du, u, h, p, t) @@ -12,8 +12,7 @@ prob = DDEProblem(delay_lotka_volterra, [1.0, 1.0], h, (0.0, 10.0), constant_lag p = [2.2, 1.0, 2.0, 0.4] function predict_fd_dde(p) solve(prob, MethodOfSteps(Tsit5()), p = p, saveat = 0.0:0.1:10.0, reltol = 1e-4, - sensealg = ForwardDiffSensitivity())[1, - :] + sensealg = ForwardDiffSensitivity())[1, :] end loss_fd_dde(p) = sum(abs2, x - 1 for x in predict_fd_dde(p)) loss_fd_dde(p) diff --git a/test/layers_sde.jl b/test/layers_sde.jl index dd13b1517..44c2295e9 100644 --- a/test/layers_sde.jl +++ b/test/layers_sde.jl @@ -1,4 +1,4 @@ -using SciMLSensitivity, Flux, Zygote, StochasticDiffEq, Test +using SciMLSensitivity, Zygote, StochasticDiffEq, Test function lotka_volterra(du, u, p, t) x, y = u diff --git a/test/runtests.jl b/test/runtests.jl index 359c2d7f7..a25d8ab04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -168,6 +168,9 @@ end @time @safetestset "Parameter Handling" begin include("parameter_handling.jl") end + @time @safetestset "Quality Assurance" begin + include("aqua.jl") + end end if GROUP == "All" || GROUP == "SDE1" diff --git a/test/sde_neural.jl b/test/sde_neural.jl index 8e296f9af..8356114c3 100644 --- a/test/sde_neural.jl +++ b/test/sde_neural.jl @@ -1,11 +1,7 @@ -using SciMLSensitivity, Flux, LinearAlgebra -using DiffEqNoiseProcess -using StochasticDiffEq -using Statistics -using SciMLSensitivity +using SciMLSensitivity, Lux, ComponentArrays, 
LinearAlgebra, DiffEqNoiseProcess, Test +using StochasticDiffEq, Statistics, SciMLSensitivity, Zygote using DiffEqBase.EnsembleAnalysis -using Zygote -using Optimization, OptimizationFlux +using Optimization, OptimizationOptimisers using Random Random.seed!(238248735) @@ -48,13 +44,14 @@ Random.seed!(238248735) (truemean, truevar) = Array.(timeseries_steps_meanvar(solution)) ann = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 2)) - α, re = Flux.destructure(ann) + α, st = Lux.setup(Random.default_rng(), ann) + α = ComponentArray(α) α = Float64.(α) function dudt_(du, u, p, t) r, e, μ, h, ph, z, i = p_ - MM = re(p)(u) + MM = first(ann(u, p, st)) du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series @@ -66,7 +63,7 @@ Random.seed!(238248735) function dudt_op(u, p, t) r, e, μ, h, ph, z, i = p_ - MM = re(p)(u) + MM = first(ann(u, p, st)) [e * 0.5 * (5μ - u[1]), # nutrient input time series e * 0.05 * (10μ - u[2]), # grazer density time series @@ -132,14 +129,14 @@ Random.seed!(238248735) optf = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, α) - res1 = Optimization.solve(optprob, ADAM(0.001), callback = callback, maxiters = 200) + res1 = Optimization.solve(optprob, Adam(0.001), callback = callback, maxiters = 200) println("Test non-mutating form") optf = Optimization.OptimizationFunction((x, p) -> loss_op(x), Optimization.AutoZygote()) optprob = Optimization.OptimizationProblem(optf, α) - res2 = Optimization.solve(optprob, ADAM(0.001), callback = callback, maxiters = 200) + res2 = Optimization.solve(optprob, Adam(0.001), callback = callback, maxiters = 200) end @testset "Adaptive neural SDE" begin @@ -149,8 +146,9 @@ end # Define Neural Network for the control input input_size = x_size + 1 # size of the spatial dimensions PLUS one time dimensions nn_initial = Chain(Dense(input_size, v_size)) # The actual neural 
network - p_nn, model = Flux.destructure(nn_initial) - nn(x, p) = model(p)(x) + ps, st = Lux.setup(Random.default_rng(), nn_initial) + ps = ComponentArray(ps) + nn(x, p) = first(nn_initial(x, p, st)) # The neural network function # Define the right hand side of the SDE const_mat = zeros(Float64, (x_size, v_size)) @@ -171,10 +169,10 @@ end u0 = vec(rand(Float64, (x_size, 1))) tspan = (0.0, 1.0) ts = collect(0:0.1:1) - prob = SDEProblem{true}(f!, g!, u0, tspan, p_nn) + prob = SDEProblem{true}(f!, g!, u0, tspan, ps) W = WienerProcess(0.0, 0.0, 0.0) - probscalar = SDEProblem{true}(f!, g!, u0, tspan, p_nn, noise = W) + probscalar = SDEProblem{true}(f!, g!, u0, tspan, ps, noise = W) # Defining the loss function function loss(pars, prob, alg) @@ -189,8 +187,7 @@ end _sol = solve(ensembleprob, alg, EnsembleThreads(), sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()), - saveat = ts, trajectories = 10, - abstol = 1e-1, reltol = 1e-1) + saveat = ts, trajectories = 10, abstol = 1e-1, reltol = 1e-1) A = convert(Array, _sol) sum(abs2, A .- 1), mean(A) end @@ -209,16 +206,16 @@ end optf = Optimization.OptimizationFunction((p, _) -> loss(p, probscalar, LambaEM()), Optimization.AutoZygote()) - optprob = Optimization.OptimizationProblem(optf, p_nn) - res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) + optprob = Optimization.OptimizationProblem(optf, ps) + res1 = Optimization.solve(optprob, Adam(0.1), callback = callback, maxiters = 5) optf = Optimization.OptimizationFunction((p, _) -> loss(p, probscalar, SOSRI()), Optimization.AutoZygote()) - optprob = Optimization.OptimizationProblem(optf, p_nn) - res2 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) + optprob = Optimization.OptimizationProblem(optf, ps) + res2 = Optimization.solve(optprob, Adam(0.1), callback = callback, maxiters = 5) optf = Optimization.OptimizationFunction((p, _) -> loss(p, prob, LambaEM()), Optimization.AutoZygote()) - optprob = 
Optimization.OptimizationProblem(optf, p_nn) - res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5) + optprob = Optimization.OptimizationProblem(optf, ps) + res1 = Optimization.solve(optprob, Adam(0.1), callback = callback, maxiters = 5) end diff --git a/test/size_handling_adjoint.jl b/test/size_handling_adjoint.jl index d8c139ef7..29579343d 100644 --- a/test/size_handling_adjoint.jl +++ b/test/size_handling_adjoint.jl @@ -1,4 +1,5 @@ -using SciMLSensitivity, Zygote, Flux, OrdinaryDiffEq, Test # , Plots +using SciMLSensitivity, Zygote, OrdinaryDiffEq, Test +using Optimization, OptimizationOptimisers p = [1.5 1.0; 3.0 1.0] function lotka_volterra(du, u, p, t) @@ -14,28 +15,23 @@ sol = solve(prob, Tsit5()) # plot(sol) -p = [2.2 1.0; 2.0 0.4] # Tweaked Initial Parameter Array -ps = Flux.params(p) +ps = [2.2 1.0; 2.0 0.4] # Tweaked Initial Parameter Array -function predict_adjoint() # Our 1-layer neural network +function predict_adjoint(p) # Our 1-layer neural network Array(solve(prob, Tsit5(), p = p, saveat = 0.0:0.1:10.0)) end -loss_adjoint() = sum(abs2, x - 1 for x in predict_adjoint()) +loss_adjoint(p, _) = sum(abs2, x - 1 for x in predict_adjoint(p)) -data = Iterators.repeated((), 100) -opt = ADAM(0.1) -cb = function () #callback function to observe training - display(loss_adjoint()) +cb = function (p, loss) #callback function to observe training + @show loss + return false end -predict_adjoint() +res = solve(OptimizationProblem(OptimizationFunction(loss_adjoint, AutoZygote()), ps), + Adam(0.1); callback = cb, maxiters = 200) -# Display the ODE with the initial parameter values. 
-cb() -Flux.train!(loss_adjoint, ps, data, opt, cb = cb) - -@test loss_adjoint() < 1 +@test loss_adjoint(res.u, nothing) < 1 tspan = (0, 1) tran = collect(0:0.1:1) @@ -70,4 +66,4 @@ dp6 = Zygote.pullback(x -> loss(x; vjp = false), p0)[2](1)[1] @test dp1 ≈ dp3 @test dp1 ≈ dp4 @test dp1 ≈ dp5 -@test dp1 ≈ dp6 \ No newline at end of file +@test dp1 ≈ dp6 diff --git a/test/sparse_adjoint.jl b/test/sparse_adjoint.jl index 253c343bf..0ce874c13 100644 --- a/test/sparse_adjoint.jl +++ b/test/sparse_adjoint.jl @@ -1,5 +1,5 @@ using SciMLSensitivity, OrdinaryDiffEq, LinearAlgebra, SparseArrays, Zygote, LinearSolve -using AlgebraicMultigrid: AlgebraicMultigrid +using AlgebraicMultigrid using Test foop(u, p, t) = jac(u, p, t) * u diff --git a/test/steady_state.jl b/test/steady_state.jl index 90821e6ca..95576c380 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -1,10 +1,6 @@ using Test, LinearAlgebra using SciMLSensitivity, SteadyStateDiffEq, DiffEqBase, NLsolve -using OrdinaryDiffEq -using NonlinearSolve, SciMLNLSolve -using ForwardDiff, Calculus -using Zygote -using Random +using OrdinaryDiffEq, NonlinearSolve, ForwardDiff, Calculus, Zygote, Random Random.seed!(12345) @testset "Adjoint sensitivities of steady state solver" begin @@ -73,7 +69,7 @@ Random.seed!(12345) @info "Calculate adjoint sensitivities from autodiff & numerical diff" function G(p) tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) - sol = solve(tmp_prob, DynamicSS(Rodas5())) + sol = solve(tmp_prob, DynamicSS(Rodas5()); abstol = 1e-14, reltol = 1e-14) A = convert(Array, sol) g(A, p, nothing) end @@ -90,8 +86,7 @@ Random.seed!(12345) # with jac, param_jac f1 = ODEFunction(f!; jac = jac!, paramjac = paramjac!) 
prob1 = SteadyStateProblem(f1, u0, p) - sol1 = solve(prob1, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), - reltol = 1e-14, abstol = 1e-14) + sol1 = solve(prob1, DynamicSS(Rodas5()), reltol = 1e-14, abstol = 1e-14) res1a = adjoint_sensitivities(sol1, DynamicSS(Rodas5()), sensealg = SteadyStateAdjoint(), dgdu = dgdu!, @@ -151,8 +146,7 @@ Random.seed!(12345) # without jac, without param_jac f3 = ODEFunction(f!) prob3 = SteadyStateProblem(f3, u0, p) - sol3 = solve(prob3, DynamicSS(Rodas5(), reltol = 1e-14, abstol = 1e-14), - reltol = 1e-14, abstol = 1e-14) + sol3 = solve(prob3, DynamicSS(Rodas5()), reltol = 1e-14, abstol = 1e-14) res3a = adjoint_sensitivities(sol3, DynamicSS(Rodas5()), sensealg = SteadyStateAdjoint(), dgdu = dgdu!, dgdp = dgdp!, g = g) @@ -415,7 +409,7 @@ end dp5 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleNewtonRaphson()), p)[1] dp6 = Zygote.gradient(p -> test_loss(p, prob2, alg = Klement()), p)[1] dp7 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleTrustRegion()), p)[1] - dp8 = Zygote.gradient(p -> test_loss(p, prob2, alg = NLSolveJL()), p)[1] + dp8 = Zygote.gradient(p -> test_loss(p, prob2, alg = NLsolveJL()), p)[1] @test dp1≈dp2 rtol=1e-10 @test dp1≈dp3 rtol=1e-10