From caf932de6ed2e6651f792300204776f483be99a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 14 Dec 2023 10:55:47 -0500 Subject: [PATCH] Properly support StaticArrays construction --- Project.toml | 4 +++- src/ChainRules.jl | 3 +++ src/rulesets/StaticArrays/staticarrays.jl | 23 ++++++++++++++++++++++ test/rulesets/StaticArrays/staticarrays.jl | 8 ++++++++ test/runtests.jl | 3 +++ 5 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 src/rulesets/StaticArrays/staticarrays.jl create mode 100644 test/rulesets/StaticArrays/staticarrays.jl diff --git a/Project.toml b/Project.toml index 77f514046..bfd5dd5b7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -32,9 +33,10 @@ JuliaInterpreter = "0.8,0.9" LinearAlgebra = "1" Random = "1" RealDot = "0.1" -SparseInverseSubset = "0.1" SparseArrays = "1" +SparseInverseSubset = "0.1" StaticArrays = "1.2" +StaticArraysCore = "1" Statistics = "1" StructArrays = "0.6.11" SuiteSparse = "1" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..879558c87 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -14,6 +14,7 @@ using RealDot: realdot using SparseArrays using Statistics using StructArrays +using StaticArraysCore # Basically everything this package does is overloading these, so we make an exception # to the normal rule of only overload via `ChainRulesCore.rrule`. @@ -65,4 +66,6 @@ include("rulesets/SparseArrays/sparsematrix.jl") include("rulesets/Random/random.jl") +include("rulesets/StaticArrays/staticarrays.jl") + end # module diff --git a/src/rulesets/StaticArrays/staticarrays.jl b/src/rulesets/StaticArrays/staticarrays.jl new file mode 100644 index 000000000..b55344f85 --- /dev/null +++ b/src/rulesets/StaticArrays/staticarrays.jl @@ -0,0 +1,23 @@ +# Projecting a tuple to SMatrix leads to CRC._projection_mismatch by default, so +# overloaded here +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) + dy = reshape(dx, axes(project.elements)) + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return ChainRulesCore.project_type(project)(dz...) +end + +# Project SArray to SArray +function ProjectTo(x::SArray{S, T}) where {S, T} + return ProjectTo{SArray}(; element = ChainRulesCore._eltype_projectto(T), axes = S) +end + +function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M} + return SArray{project.axes}(dx) +end + +# Adjoint for SArray constructor +function rrule(::Type{T}, x::Tuple) where {T <: SArray} + project_x = ProjectTo(x) + ∇Array(∂y) = (NoTangent(), project_x(∂y)) + return T(x), ∇Array +end \ No newline at end of file diff --git a/test/rulesets/StaticArrays/staticarrays.jl b/test/rulesets/StaticArrays/staticarrays.jl new file mode 100644 index 000000000..54ea24f13 --- /dev/null +++ b/test/rulesets/StaticArrays/staticarrays.jl @@ -0,0 +1,8 @@ +@testset "StaticArrays Constructors" begin + @testset "Projection" begin + test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SVector{4}, (1.0, 1.0, 1.0, 1.0)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..c19ce2eb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,4 +90,7 @@ end include_test("rulesets/Random/random.jl") println() + + include_test("rulesets/StaticArrays/staticarrays.jl") + println() end