From 623299a373a536a7025f274d46164241acd74917 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 17 Jul 2021 22:12:03 -0400 Subject: [PATCH 1/2] add ChainRulesCore projections --- Project.toml | 2 ++ src/FillArrays.jl | 1 + src/chainrules.jl | 30 ++++++++++++++++++++++++++++++ test/runtests.jl | 14 +++++++++++++- 4 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 src/chainrules.jl diff --git a/Project.toml b/Project.toml index 16e8e213..699dd109 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,14 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "0.12.1" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +ChainRulesCore = "1" julia = "1" [extras] diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 9e6b0a7d..7b33922d 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -626,6 +626,7 @@ end include("fillalgebra.jl") include("fillbroadcast.jl") include("trues.jl") +include("chainrules.jl") ## # print diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 00000000..f76ae56b --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,30 @@ +import ChainRulesCore: ProjectTo, NoTangent + +""" + ProjectTo(::Fill) -> ProjectTo{Fill} + ProjectTo(::Ones) -> ProjectTo{NoTangent} + +Most FillArrays arrays store one number, and so their gradients under automatic +differentiation represent the variation of this one number. + +The exception is those like `Ones` and `Zeros` whose type fixes their value, +which have no gradient. 
+""" +ProjectTo(x::Fill{<:Number}) = ProjectTo{Fill}(; element = ProjectTo(getindex_value(x)), axes = axes(x)) + +ProjectTo(x::AbstractFill{Bool}) = ProjectTo{NoTangent}() # Bool is always regarded as categorical + +ProjectTo(x::Zeros) = ProjectTo{NoTangent}() +ProjectTo(x::Ones) = ProjectTo{NoTangent}() + +function (project::ProjectTo{Fill})(dx::AbstractArray) + for d in 1:max(ndims(dx), length(project.axes)) + size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(project.axes, size(dx))) # dims beyond either rank must have length 1 + end + Fill(mean(dx), project.axes) # Note that mean(dx::Fill) is optimised +end + +function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) + size_x = map(length, axes_x) + DimensionMismatch("variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx") +end diff --git a/test/runtests.jl b/test/runtests.jl index a688ef54..d1e02959 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,7 @@ -using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, Random, Base64, Test, Statistics + +using FillArrays, StaticArrays, ChainRulesCore, Base64 +using LinearAlgebra, SparseArrays, Random, Statistics, Test # standard libraries + import FillArrays: AbstractFill, RectDiagonal, SquareEye @testset "fill array constructors and convert" begin @@ -1323,3 +1326,12 @@ end @test cor(Fill(3, 4, 5)) ≈ cor(fill(3, 4, 5)) nans=true @test cor(Fill(3, 4, 5), dims=2) ≈ cor(fill(3, 4, 5), dims=2) nans=true end + +@testset "ChainRules integration" begin + @test ProjectTo(Fill(1,2,3))(ones(2,3)) === Fill(1.0, 2, 3) + @test ProjectTo(Fill(1,2,3))(ones(2,3,1) .+ im) === Fill(1.0, 2, 3) + @test ProjectTo(Fill(1,2,3))(Fill(1+im, 2,3)) === Fill(1.0, 2, 3) + + @test ProjectTo(Eye(3))(rand(3,3)) === NoTangent() + @test ProjectTo(Zeros(3))(rand(3)) === NoTangent() +end From 32edf8e5fbb21d303376daa9faf12f1a4b3e0ede Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 20 Aug 2021 12:06:05 -0400 Subject: [PATCH 2/2] 
reconstruct from Tangent --- src/chainrules.jl | 10 +++++++++- test/runtests.jl | 8 +++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index f76ae56b..01f35580 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,4 +1,4 @@ -import ChainRulesCore: ProjectTo, NoTangent +import ChainRulesCore: ProjectTo, NoTangent, Tangent """ ProjectTo(::Fill) -> ProjectTo{Fill} @@ -24,6 +24,14 @@ function (project::ProjectTo{Fill})(dx::AbstractArray) Fill(mean(dx), project.axes) # Note that mean(dx::Fill) is optimised end +function (project::ProjectTo{Fill})(dx::Tangent{<:Fill}) + # Returns a Fill of dx.value divided by the element count of project.axes. The axes of dx are not checked: the loop below would need a definition for length(::NoTangent) to be safe: + # for d in 1:max(length(dx.axes), length(project.axes)) + # length(get(dx.axes, d, 1)) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(dx.axes, size(dx))) + # end + Fill(dx.value / prod(length, project.axes), project.axes) +end + function _projection_mismatch(axes_x::Tuple, size_dx::Tuple) size_x = map(length, axes_x) DimensionMismatch("variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx") diff --git a/test/runtests.jl b/test/runtests.jl index d1e02959..0255bdf2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1328,9 +1328,11 @@ end end @testset "ChainRules integration" begin - @test ProjectTo(Fill(1,2,3))(ones(2,3)) === Fill(1.0, 2, 3) - @test ProjectTo(Fill(1,2,3))(ones(2,3,1) .+ im) === Fill(1.0, 2, 3) - @test ProjectTo(Fill(1,2,3))(Fill(1+im, 2,3)) === Fill(1.0, 2, 3) + x = Fill(1,2,3) + @test ProjectTo(x)(ones(2,3)) === Fill(1.0, 2, 3) + @test ProjectTo(x)(ones(2,3,1) .+ im) === Fill(1.0, 2, 3) + @test ProjectTo(x)(Fill(1+im, 2,3)) === Fill(1.0, 2, 3) + @test ProjectTo(x)(Tangent{typeof(x)}(; value=6)) === Fill(1.0, 2, 3) @test ProjectTo(Eye(3))(rand(3,3)) === NoTangent() @test ProjectTo(Zeros(3))(rand(3)) === NoTangent()