|
1 | | -@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} = |
2 | | - Base.literal_pow(^,x,Val(p)), |
3 | | - Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing) |
| 1 | +function ChainRulesCore.rrule( |
| 2 | + ::ZygoteRuleConfig, ::typeof(convert), T::Type{<:Real}, x::Real |
| 3 | +) |
| 4 | + convert_pullback(Δ) = (NoTangent(), NoTangent(), Δ) |
| 5 | + return convert(T, x), convert_pullback |
| 6 | +end |
| 7 | + |
| 8 | +function ChainRulesCore.rrule( |
| 9 | + ::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p} |
| 10 | +) where {p} |
| 11 | + function literal_pow_pullback(Δ) |
| 12 | + dx = Δ * conj(p * Base.literal_pow(^,x,Val(p-1))) |
| 13 | + return (NoTangent(), NoTangent(), dx, NoTangent()) |
| 14 | + end |
| 15 | + return Base.literal_pow(^,x,Val(p)), literal_pow_pullback |
| 16 | +end |
4 | 17 |
|
5 | | -@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ) |
6 | | -@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ) |
| 18 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real) |
| 19 | + Real_pullback(Δ) = (NoTangent(), Δ) |
| 20 | + return T(x), Real_pullback |
| 21 | +end |
7 | 22 |
|
8 | 23 | for T in Base.uniontypes(Core.BuiltinInts) |
9 | | - @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) |
| 24 | + @eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts) |
| 25 | + IntX_pullback(Δ) = (NoTangent(), Δ) |
| 26 | + return $T(x), IntX_pullback |
| 27 | + end |
10 | 28 | end |
11 | 29 |
|
12 | | -@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs) |
| 30 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...) |
| 31 | + plus_pullback(Δ) = (NoTangent(), map(_ -> Δ, xs)...) |
| 32 | + return +(xs...), plus_pullback |
| 33 | +end |
13 | 34 |
|
14 | | -@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) |
| 35 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b) |
| 36 | + divide_pullback(r̄) = (NoTangent(), r̄ * 1//b, - r̄ * a // b // b) |
| 37 | + return a // b, divide_pullback |
| 38 | +end |
15 | 39 |
|
16 | 40 | # Complex Numbers |
17 | 41 |
|
18 | | -@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄)) |
| 42 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i) |
| 43 | + Complex_pullback(c̄) = (NoTangent(), real(c̄), imag(c̄)) |
| 44 | + return T(r, i), Complex_pullback |
| 45 | +end |
19 | 46 |
|
20 | 47 | # we define these here because ChainRules.jl only defines them for x::Union{Real,Complex} |
21 | 48 |
|
22 | | -@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) |
23 | | -@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),) |
24 | | -@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) |
25 | | -@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) |
| 49 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number) |
| 50 | + abs2_pullback(Δ) = (NoTangent(), real(Δ)*(x + x)) |
| 51 | + return abs2(x), abs2_pullback |
| 52 | +end |
| 53 | + |
| 54 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number) |
| 55 | + real_pullback(r̄) = (NoTangent(), real(r̄)) |
| 56 | + return real(x), real_pullback |
| 57 | +end |
| 58 | + |
| 59 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number) |
| 60 | + conj_pullback(c̄) = (NoTangent(), conj(c̄)) |
| 61 | + return conj(x), conj_pullback |
| 62 | +end |
26 | 63 |
|
27 | 64 | # for real x, ChainRules pulls back a zero real adjoint, whereas we treat x |
28 | 65 | # as embedded in the complex numbers and pull back a pure imaginary adjoint |
29 | | -@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,) |
| 66 | +function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number) |
| 67 | + imag_pullback(ī) = (NoTangent(), real(ī)*im) |
| 68 | + return imag(x), imag_pullback |
| 69 | +end |
0 commit comments