diff --git a/Project.toml b/Project.toml index 9311b00a2..ac039e57c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.59.0" +version = "1.59.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 28cc11d19..bf97318c7 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -2,11 +2,30 @@ # that also have FastMath versions. @scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent()) - -@scalar_rule one(x) ZeroTangent() -@scalar_rule zero(x) ZeroTangent() @scalar_rule transpose(x) true +# `zero` + +function frule((_, _), ::typeof(zero), x) + return (zero(x), ZeroTangent()) +end + +function rrule(::typeof(zero), x) + zero_pullback(_) = (NoTangent(), ZeroTangent()) + return (zero(x), zero_pullback) +end + +# `one` + +function frule((_, _), ::typeof(one), x) + return (one(x), ZeroTangent()) +end + +function rrule(::typeof(one), x) + one_pullback(_) = (NoTangent(), ZeroTangent()) + return (one(x), one_pullback) +end + # `adjoint` frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9a5278747..80e418b82 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -1,4 +1,14 @@ @testset "base.jl" begin + @testset "zero/one" begin + for f in [zero, one] + for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]] + test_frule(f, x) + test_rrule(f, x) + end + end + test_frule(zero, [1.0, 2.0, 3.0]) + test_rrule(zero, [1.0, 2.0, 3.0]) + end @testset "copysign" begin # don't go too close to zero as the numerics may jump over it yielding wrong results @testset "at $y" for y in (-1.1, 0.1, 100.0)