Skip to content

Commit b9530c7

Browse files
authored
Merge pull request #1273 from mzgubic/mz/number_rrules
number adjoints to rrules
2 parents c822e9e + cdadaff commit b9530c7

File tree

3 files changed

+99
-22
lines changed

3 files changed

+99
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.42"
3+
version = "0.6.43"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/lib/number.jl

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,69 @@
1-
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
2-
Base.literal_pow(^,x,Val(p)),
3-
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
1+
function ChainRulesCore.rrule(
2+
::ZygoteRuleConfig, ::typeof(convert), T::Type{<:Real}, x::Real
3+
)
4+
convert_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
5+
return convert(T, x), convert_pullback
6+
end
7+
8+
function ChainRulesCore.rrule(
9+
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}
10+
) where {p}
11+
function literal_pow_pullback(Δ)
12+
dx = Δ * conj(p * Base.literal_pow(^,x,Val(p-1)))
13+
return (NoTangent(), NoTangent(), dx, NoTangent())
14+
end
15+
return Base.literal_pow(^,x,Val(p)), literal_pow_pullback
16+
end
417

5-
@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
6-
@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)
18+
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real)
19+
Real_pullback(Δ) = (NoTangent(), Δ)
20+
return T(x), Real_pullback
21+
end
722

823
for T in Base.uniontypes(Core.BuiltinInts)
9-
@adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,)
24+
@eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts)
25+
IntX_pullback(Δ) = (NoTangent(), Δ)
26+
return $T(x), IntX_pullback
27+
end
1028
end
1129

12-
@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)
30+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...)
31+
plus_pullback(Δ) = (NoTangent(), map(_ -> Δ, xs)...)
32+
return +(xs...), plus_pullback
33+
end
1334

14-
@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -c̄ * a // b // b))
35+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b)
36+
divide_pullback(r̄) = (NoTangent(), r̄ * 1//b, -r̄ * a // b // b)
37+
return a // b, divide_pullback
38+
end
1539

1640
# Complex Numbers
1741

18-
@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))
42+
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i)
43+
Complex_pullback(c̄) = (NoTangent(), real(c̄), imag(c̄))
44+
return T(r, i), Complex_pullback
45+
end
1946

2047
# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}
2148

22-
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
23-
@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),)
24-
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
25-
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)
49+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number)
50+
abs2_pullback(Δ) = (NoTangent(), real(Δ)*(x + x))
51+
return abs2(x), abs2_pullback
52+
end
53+
54+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number)
55+
real_pullback(r̄) = (NoTangent(), real(r̄))
56+
return real(x), real_pullback
57+
end
58+
59+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number)
60+
conj_pullback(c̄) = (NoTangent(), conj(c̄))
61+
return conj(x), conj_pullback
62+
end
2663

2764
# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
2865
# as embedded in the complex numbers and pull back a pure imaginary adjoint
29-
@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)
66+
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number)
67+
imag_pullback(ī) = (NoTangent(), real(ī)*im)
68+
return imag(x), imag_pullback
69+
end

test/lib/number.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1-
@testset "nograds" begin
2-
@test gradient(floor, 1) === (0.0,)
3-
@test gradient(ceil, 1) === (0.0,)
4-
@test gradient(round, 1) === (0.0,)
5-
@test gradient(hash, 1) === nothing
6-
@test gradient(div, 1, 2) === nothing
7-
end #testset
1+
@testset "number.jl" begin
2+
@testset "nograds" begin
3+
@test gradient(floor, 1) === (0.0,)
4+
@test gradient(ceil, 1) === (0.0,)
5+
@test gradient(round, 1) === (0.0,)
6+
@test gradient(hash, 1) === nothing
7+
@test gradient(div, 1, 2) === nothing
8+
end
9+
10+
@testset "basics" begin
11+
@test gradient(Base.literal_pow, ^, 3//2, Val(-5))[2] isa Rational
12+
13+
@test gradient(convert, Rational, 3.14) == (nothing, 1.0)
14+
@test gradient(convert, Rational, 2.3) == (nothing, 1.0)
15+
@test gradient(convert, UInt64, 2) == (nothing, 1.0)
16+
@test gradient(convert, BigFloat, π) == (nothing, 1.0)
17+
18+
@test gradient(Rational, 2) == (1//1,)
19+
20+
@test gradient(Bool, 1) == (1.0,)
21+
@test gradient(Int32, 2) == (1.0,)
22+
@test gradient(UInt16, 2) == (1.0,)
23+
24+
@test gradient(+, 2.0, 3, 4.0, 5.0) == (1.0, 1.0, 1.0, 1.0)
25+
26+
@test gradient(//, 3, 2) == (1//2, -3//4)
27+
end
28+
29+
@testset "Complex numbers" begin
30+
@test gradient(imag, 3.0) == (0.0,)
31+
@test gradient(imag, 3.0 + 3.0im) == (0.0 + 1.0im,)
32+
33+
@test gradient(conj, 3.0) == (1.0,)
34+
@test gradient(real ∘ conj, 3.0 + 1im) == (1.0 + 0im,)
35+
36+
@test gradient(real, 3.0) == (1.0,)
37+
@test gradient(real, 3.0 + 1im) == (1.0 + 0im,)
38+
39+
@test gradient(abs2, 3.0) == (2*3.0,)
40+
@test gradient(abs2, 3.0+2im) == (2*3.0 + 2*2.0im,)
41+
42+
@test gradient(real ∘ Complex, 3.0, 2.0) == (1.0, 0.0)
43+
end
44+
end

0 commit comments

Comments
 (0)