diff --git a/test/chainrules.jl b/test/chainrules.jl index 7dbf8e8a..ee3c02cc 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,7 +1,24 @@ using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test -@testset "Chain Rules Integration" begin +@testset "ChainRules Integration" begin @testset "Projection" begin + # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) + # implies a check, and reshape will wrap a Vector into a static SizedVector: + pstat = ProjectTo(SA[1, 2, 3]) + @test axes(pstat(rand(3))) === (SOneTo(3),) + + # This recurses into structured arrays: + pst = ProjectTo(transpose(SA[1, 2, 3])) + @test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3)) + @test pst(rand(1,3)) isa Transpose + + # When the argument is an ordinary Array, static gradients are allowed to pass, + # like FillArrays. Collecting to an Array would cost a copy. + pvec3 = ProjectTo([1, 2, 3]) + @test pvec3(SA[1, 2, 3]) isa StaticArray + end + + @testset "Constructor rrules" begin test_rrule(SMatrix{1, 4}, (1.0, 1.0, 1.0, 1.0)) test_rrule(SMatrix{4, 1}, (1.0, 1.0, 1.0, 1.0)) test_rrule(SMatrix{2, 2}, (1.0, 1.0, 1.0, 1.0))