Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ext/StaticArraysChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ end

# Project SArray to SArray
function ProjectTo(x::SArray{S, T}) where {S, T}
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x))
# We have a axes field because it is expected by other ProjectTo's like the one for Transpose
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x),
size = Size(x))
end

@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)

(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx)
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx)

# Adjoint for SArray constructor
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
Expand Down
21 changes: 19 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test

@testset "Chain Rules Integration" begin
@testset "ChainRules Integration" begin
@testset "Projection" begin
# There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
# implies a check, and reshape will wrap a Vector into a static SizedVector:
pstat = ProjectTo(SA[1, 2, 3])
@test axes(pstat(rand(3))) === (SOneTo(3),)

# This recurses into structured arrays:
pst = ProjectTo(transpose(SA[1, 2, 3]))
@test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3))
@test pst(rand(1,3)) isa Transpose

# When the argument is an ordinary Array, static gradients are allowed to pass,
# like FillArrays. Collecting to an Array would cost a copy.
pvec3 = ProjectTo([1, 2, 3])
@test pvec3(SA[1, 2, 3]) isa StaticArray
end

@testset "Constructor rrules" begin
test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0))
test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0))
test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))
Expand Down