diff --git a/Project.toml b/Project.toml
index 673df031..eb8e049f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,20 +3,24 @@ uuid = "90137ffa-7385-5640-81b9-e52037218182"
 version = "1.5.2"
 
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-julia = "1.6"
+ChainRulesCore = "1"
 StaticArraysCore = "1"
+julia = "1.6"
 
 [extras]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 
 [targets]
-test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"]
+test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Zygote", "ForwardDiff"]
diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl
index e2e3fd83..7b353f84 100644
--- a/src/StaticArrays.jl
+++ b/src/StaticArrays.jl
@@ -12,6 +12,7 @@ import Statistics: mean
 
 using Random
 import Random: rand, randn, randexp, rand!, randn!, randexp!
+using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, Tangent
 using Core.Compiler: return_type
 import Base: sqrt, exp, log, float, real
 using LinearAlgebra
@@ -129,6 +130,7 @@ include("io.jl")
 include("pinv.jl")
 
 include("precompile.jl")
+include("chainrule.jl")
 _precompile_()
 
 end # module
diff --git a/src/chainrule.jl b/src/chainrule.jl
new file mode 100644
index 00000000..f177081c
--- /dev/null
+++ b/src/chainrule.jl
@@ -0,0 +1,37 @@
+### ChainRulesCore integration: projection and constructor rules for `SArray`.
+
+### Projecting an `SArray` cotangent onto a `Tuple` primal leads to
+### `ChainRulesCore._projection_mismatch` by default, so it is overloaded here.
+function (project::ChainRulesCore.ProjectTo{<:Tangent{<:Tuple}})(dx::SArray)
+    elements = project.elements
+    n = length(elements)
+    if length(dx) != n
+        throw(DimensionMismatch(
+            "cannot project $(length(dx)) elements onto a tuple of length $n"
+        ))
+    end
+    # `Tuple(dx)` lists the entries of `dx` in linear (column-major) order.
+    dz = map((proj, d) -> proj(d), elements, Tuple(dx))
+    return ChainRulesCore.project_type(project)(dz...)
+end
+
+### Project SArray to SArray. The static size parameter `S` is stored under `axes`.
+function ChainRulesCore.ProjectTo(x::SArray{S,T}) where {S,T}
+    element = ChainRulesCore._eltype_projectto(T)
+    return ChainRulesCore.ProjectTo{SArray}(; element=element, axes=S)
+end
+
+function (project::ChainRulesCore.ProjectTo{SArray})(dx::AbstractArray{T}) where {T}
+    # Apply the stored element projector unless the cotangent's element type already
+    # matches it (the same shortcut ChainRulesCore takes for `Array`).
+    E = ChainRulesCore.project_type(project.element)
+    dy = T <: E ? dx : map(project.element, dx)
+    return SArray{project.axes}(dy)
+end
+
+### Adjoint for SArray constructor
+function ChainRulesCore.rrule(::Type{T}, x::Tuple) where {T<:SArray}
+    project_x = ProjectTo(x)
+    SArray_pullback(ȳ) = (NoTangent(), project_x(ȳ))
+    return T(x), SArray_pullback
+end
diff --git a/test/abstractarray.jl b/test/abstractarray.jl
index d8aff256..8fc6c917 100644
--- a/test/abstractarray.jl
+++ b/test/abstractarray.jl
@@ -1,4 +1,4 @@
-using StaticArrays, Test, LinearAlgebra
+using StaticArrays, Test, LinearAlgebra, Zygote, ForwardDiff
 
 @testset "AbstractArray interface" begin
     @testset "size and length" begin
@@ -243,6 +243,43 @@ using StaticArrays, Test, LinearAlgebra
             @test rs == Base.reduced_indices(axes(a), i)
         end
     end
+
+    @testset "AutoDiff" begin
+        u0 = @SVector rand(2)
+        p = @SVector rand(4)
+
+        function lotka(u, p, svec=true)
+            du1 = p[1]*u[1] - p[2]*u[1]*u[2]
+            du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
+            if svec
+                @SVector [du1, du2]
+            else
+                @SMatrix [du1 du2 du1; du2 du1 du1]
+            end
+        end
+
+        #SVector constructor adjoint
+        function loss(p)
+            u = lotka(u0, p)
+            sum(1 .- u)
+        end
+
+        grad = Zygote.gradient(loss, p)
+        @test typeof(grad[1]) <: SArray
+        grad2 = ForwardDiff.gradient(loss, p)
+        @test grad[1] ≈ grad2 rtol=1e-12
+
+        #SMatrix constructor adjoint
+        function loss_mat(p)
+            u = lotka(u0, p, false)
+            sum(1 .- u)
+        end
+
+        grad = Zygote.gradient(loss_mat, p)
+        @test typeof(grad[1]) <: SArray
+        grad2 = ForwardDiff.gradient(loss_mat, p)
+        @test grad[1] ≈ grad2 rtol=1e-12
+    end
 end
 
 @testset "permutedims" begin