From 9972f366033e4b72651c4f69d23268fcdcb8838c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 24 Jan 2024 13:02:06 +0100 Subject: [PATCH 1/5] Use handwritten rules for `zero` and `one` --- src/rulesets/Base/base.jl | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 28cc11d19..453ba8172 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -2,11 +2,38 @@ # that also have FastMath versions. @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) - -@scalar_rule one(x) ZeroTangent() -@scalar_rule zero(x) ZeroTangent() @scalar_rule transpose(x) true +# `zero` + +function frule((_, Δ1), ::typeof(zero), x) + var"∂f/∂x" = ZeroTangent() + (zero(x), Δ1 * var"∂f/∂x") +end + +function rrule(::typeof(zero), x) + Ω = zero(x) + proj_x = ProjectTo(x) + var"∂f/∂x" = ZeroTangent() + pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) + (Ω, pullback) +end + +# `one` + +function frule((_, Δ1), ::typeof(one), x) + var"∂f/∂x" = ZeroTangent() + (one(x), Δ1 * var"∂f/∂x") +end + +function rrule(::typeof(one), x) + Ω = one(x) + proj_x = ProjectTo(x) + var"∂f/∂x" = ZeroTangent() + pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) + (Ω, pullback) +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') From 275b93b1f53c4037af467b954a0a9ea4e797161d Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 24 Jan 2024 19:53:29 +0100 Subject: [PATCH 2/5] formatting --- src/rulesets/Base/base.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 453ba8172..dd83e745d 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -8,7 +8,7 @@ function frule((_, Δ1), ::typeof(zero), x) var"∂f/∂x" = ZeroTangent() - (zero(x), Δ1 * var"∂f/∂x") + return (zero(x), Δ1 * var"∂f/∂x") end function rrule(::typeof(zero), x) @@ -16,14 +16,14 @@ function rrule(::typeof(zero), x) proj_x = ProjectTo(x) var"∂f/∂x" = ZeroTangent() pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) - (Ω, pullback) + return (Ω, pullback) end # `one` function frule((_, Δ1), ::typeof(one), x) var"∂f/∂x" = ZeroTangent() - (one(x), Δ1 * var"∂f/∂x") + return (one(x), Δ1 * var"∂f/∂x") end function rrule(::typeof(one), x) @@ -31,7 +31,7 @@ function rrule(::typeof(one), x) proj_x = ProjectTo(x) var"∂f/∂x" = ZeroTangent() pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) - (Ω, pullback) + return (Ω, pullback) end # `adjoint` From 0df040f9d59662afd7655afd2690479b6f8b4851 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 25 Jan 2024 10:23:01 +0100 Subject: [PATCH 3/5] simplify zero and one rules --- src/rulesets/Base/base.jl | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index dd83e745d..bf97318c7 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -6,32 +6,24 @@ # `zero` -function frule((_, Δ1), ::typeof(zero), x) - var"∂f/∂x" = ZeroTangent() - return (zero(x), Δ1 * var"∂f/∂x") +function frule((_, _), ::typeof(zero), x) + return (zero(x), ZeroTangent()) end function rrule(::typeof(zero), x) - Ω = zero(x) - proj_x = ProjectTo(x) - var"∂f/∂x" = ZeroTangent() - pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) - return (Ω, pullback) + zero_pullback(_) = (NoTangent(), ZeroTangent()) + return (zero(x), zero_pullback) end # `one` -function frule((_, Δ1), ::typeof(one), x) - var"∂f/∂x" = ZeroTangent() - return (one(x), Δ1 * var"∂f/∂x") +function frule((_, _), ::typeof(one), x) + return (one(x), ZeroTangent()) end function rrule(::typeof(one), x) - Ω = one(x) - proj_x = ProjectTo(x) - var"∂f/∂x" = ZeroTangent() - pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1)) - return (Ω, pullback) + one_pullback(_) = (NoTangent(), ZeroTangent()) + return (one(x), one_pullback) end # `adjoint` From 6e2f430c8caf5812e0da90f409f1ec64122e3856 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 25 Jan 2024 10:31:49 +0100 Subject: [PATCH 4/5] add some zero/one tests --- test/rulesets/Base/base.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9a5278747..80e418b82 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -1,4 +1,14 @@ @testset "base.jl" begin + @testset "zero/one" begin + for f in [zero, one] + for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]] + test_frule(f, x) + test_rrule(f, x) + end + end + test_frule(zero, [1.0, 2.0, 3.0]) + test_rrule(zero, [1.0, 2.0, 3.0]) + end @testset "copysign" begin # don't go too close to zero as the numerics may jump over it yielding wrong results @testset "at $y" for y in (-1.1, 0.1, 100.0) From cc934093bd53484bcc431c6e8f09eb702302c556 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 25 Jan 2024 10:32:45 +0100 Subject: [PATCH 5/5] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9311b00a2..ac039e57c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.59.0" +version = "1.59.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"