diff --git a/Project.toml b/Project.toml index 891e4bda..5ac9dcb8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.7.0" +version = "1.8.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -10,13 +10,16 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [extensions] +StaticArraysChainRulesCoreExt = "ChainRulesCore" StaticArraysStatisticsExt = "Statistics" [compat] Aqua = "0.7" +ChainRulesCore = "1" PrecompileTools = "1" StaticArraysCore = "~1.4.0" julia = "1.6" @@ -24,6 +27,8 @@ julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -31,4 +36,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua"] +test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"] diff --git a/ext/StaticArraysChainRulesCoreExt.jl b/ext/StaticArraysChainRulesCoreExt.jl new file mode 100644 index 00000000..09c62ac8 --- /dev/null +++ b/ext/StaticArraysChainRulesCoreExt.jl @@ -0,0 +1,32 @@ +module StaticArraysChainRulesCoreExt + +using StaticArrays +# ChainRulesCore imports +import ChainRulesCore: NoTangent, ProjectTo, Tangent, project_type, rrule +import ChainRulesCore as CRC + +# Projecting a tuple to SMatrix leads to CRC._projection_mismatch by default, so +# overloaded here +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArraysCore.SArray) + dy = reshape(dx, axes(project.elements)) + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return project_type(project)(dz...) +end + +# Project SArray to SArray +function ProjectTo(x::SArray{S, T}) where {S, T} + return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = S) +end + +function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M} + return SArray{project.axes}(dx) +end + +# Adjoint for SArray constructor +function rrule(::Type{T}, x::Tuple) where {T <: SArray} + project_x = ProjectTo(x) + ∇Array(∂y) = (NoTangent(), project_x(∂y)) + return T(x), ∇Array +end + +end diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 00000000..d5394ef1 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,10 @@ +using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test + +@testset "Chain Rules Integration" begin + @testset "Projection" begin + test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0)) + test_rrule(SVector{4}, (1.0, 1.0, 1.0, 1.0)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 334f716c..bf47ff05 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -88,4 +88,9 @@ if TEST_GROUP ∈ ["", "all", "group-B"] addtests("io.jl") addtests("svd.jl") addtests("unitful.jl") + + # chain rules integration via pkg extensions is available only in Julia 1.9+ + if VERSION ≥ v"1.9-" + addtests("chainrules.jl") + end end