diff --git a/Project.toml b/Project.toml index 185b08a..37e58cd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.43" +version = "0.6.44" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 9be245a..fad23d0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -70,6 +70,13 @@ include("zygote.jl") end end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + using .Zygote: Zygote + # HACK: Make Zygote (correctly) recognize that it should use `ForwardDiff` for broadcasting. + # See `is_diff_safe` for more information. + Zygote._dual_purefun(::Type{C}) where {C<:Closure} = is_diff_safe(C) + end + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin using DiffRules using SpecialFunctions @@ -80,45 +87,7 @@ include("zygote.jl") end @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin - using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray - - const LazyVectorOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastVector{T}, - } = VectorOfUnivariate{S,T,Tdists} - - function Distributions._logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractVector{<:Real}, - ) - return sum(copy(logpdf.(dist.v, x))) - end - - function Distributions.logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractMatrix{<:Real}, - ) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return vec(sum(copy(logpdf.(dists, x)), dims = 1)) - end - - const LazyMatrixOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastArray{T,2}, - } = MatrixOfUnivariate{S,T,Tdists} - - function Distributions._logpdf( - dist::LazyMatrixOfUnivariate, - x::AbstractMatrix{<:Real}, - ) - return sum(copy(logpdf.(dist.dists, x))) - end - - lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...)) - export lazyarray + include("lazyarrays.jl") end end diff --git a/src/common.jl b/src/common.jl index ee094ba..74d82f3 100644 --- a/src/common.jl +++ b/src/common.jl @@ -48,3 +48,93 @@ parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @non_differentiable adapt_randn(::Any...) + + +""" + Closure{F,G} + +A callable of the form `(x, args...) -> F(G(args...), x)`. + +# Examples + +This is particularly useful when one wants to avoid broadcasting over constructors +which can sometimes cause issues with type-inference, in particular when combined +with reverse-mode AD frameworks. + +```juliarepl +julia> using DistributionsAD, Distributions, ReverseDiff, BenchmarkTools + +julia> const data = randn(1000); + +julia> x = randn(length(data)); + +julia> f(x) = sum(logpdf.(Normal.(x), data)) +f (generic function with 2 methods) + +julia> @btime ReverseDiff.gradient(\$f, \$x); + 848.759 μs (14605 allocations: 521.84 KiB) + +julia> # Much faster with ReverseDiff.jl. + g(x) = sum(DistributionsAD.Closure(logpdf, Normal).(data, x)) +g (generic function with 1 method) + +julia> @btime ReverseDiff.gradient(\$g, \$x); + 17.460 μs (17 allocations: 71.52 KiB) +``` + +See https://github.com/TuringLang/Turing.jl/issues/1934 more further discussion. +""" +struct Closure{F,G} end + +Closure(::F, ::G) where {F,G} = Closure{F,G}() +Closure(::F, ::Type{G}) where {F,G} = Closure{F,G}() +Closure(::Type{F}, ::G) where {F,G} = Closure{F,G}() +Closure(::Type{F}, ::Type{G}) where {F,G} = Closure{F,G}() + +""" + is_diff_safe(f) + +Return `true` if it's safe to ignore gradients wrt. `f` when computing `f`. + +Useful for checking it's okay to take faster paths in pullbacks for certain AD backends. + +# Examples + +```jldoctest +julia> using Distributions + +julia> using DistributionsAD: is_diff_safe, Closure + +julia> is_diff_safe(typeof(logpdf)) +true + +julia> is_diff_safe(typeof(x -> 2x)) +true + +julia> # But it fails if we make a closure over a variable, which we might want to compute + # the gradient with respect to. + makef(x) = y -> x + y +makef (generic function with 1 method) + +julia> is_diff_safe(typeof(makef([1.0]))) +false + +julia> # Also works on `Closure`s from `DistributionsAD`. + is_diff_safe(typeof(Closure(logpdf, Normal))) +true + +julia> is_diff_safe(typeof(Closure(logpdf, makef([1.0])))) +false +""" +@inline is_diff_safe(_) = false +@inline is_diff_safe(::Type) = true +@inline is_diff_safe(::Type{F}) where {F<:Function} = Base.issingletontype(F) +@inline is_diff_safe(::Type{Closure{F,G}}) where {F,G} = is_diff_safe(F) && is_diff_safe(G) + +@generated function (closure::Closure{F,G})(x, args...) where {F,G} + f = Base.issingletontype(F) ? F.instance : F + g = Base.issingletontype(G) ? G.instance : G + return :($f($g(args...), x)) +end + + diff --git a/src/lazyarrays.jl b/src/lazyarrays.jl new file mode 100644 index 0000000..f6db1eb --- /dev/null +++ b/src/lazyarrays.jl @@ -0,0 +1,99 @@ +using .LazyArrays: BroadcastArray, BroadcastVector, LazyArray + +const LazyVectorOfUnivariate{ + S<:ValueSupport, + T<:UnivariateDistribution{S}, + Tdists<:BroadcastVector{T}, +} = VectorOfUnivariate{S,T,Tdists} + +_inner_constructor(::Type{<:BroadcastVector{<:Any,Type{D}}}) where {D} = D + +function Distributions._logpdf( + dist::LazyVectorOfUnivariate, + x::AbstractVector{<:Real}, +) + # TODO: Make use of `sum(Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)))` once + # we've addressed performance issues in ReverseDiff.jl. + constructor = _inner_constructor(typeof(dist.v)) + return sum(Closure(logpdf, constructor).(x, dist.v.args...)) +end + +function Distributions.logpdf( + dist::LazyVectorOfUnivariate, + x::AbstractMatrix{<:Real}, +) + size(x, 1) == length(dist) || + throw(DimensionMismatch("Inconsistent array dimensions.")) + constructor = _inner_constructor(typeof(dist.v)) + return vec(sum(Closure(logpdf, constructor).(x, dist.v.args...), dims = 1)) +end + +const LazyMatrixOfUnivariate{ + S<:ValueSupport, + T<:UnivariateDistribution{S}, + Tdists<:BroadcastArray{T,2}, +} = MatrixOfUnivariate{S,T,Tdists} + +function Distributions._logpdf( + dist::LazyMatrixOfUnivariate, + x::AbstractMatrix{<:Real}, +) + + constructor = _inner_constructor(typeof(dist.v)) + return sum(Closure(logpdf, constructor).(x, dist.v.args)) +end + +lazyarray(f, x...) = BroadcastArray(f, x...) +export lazyarray + +# HACK: All of the below probably shouldn't be here. +function ChainRulesCore.rrule(::Type{BroadcastArray}, f, args...) + function BroadcastArray_pullback(Δ::ChainRulesCore.Tangent) + return (ChainRulesCore.NoTangent(), Δ.f, Δ.args...) + end + return BroadcastArray(f, args...), BroadcastArray_pullback +end + +ChainRulesCore.ProjectTo(ba::BroadcastArray) = ProjectTo{typeof(ba)}((f=ba.f,)) +function (p::ChainRulesCore.ProjectTo{BA})(args...) where {BA<:BroadcastArray} + return ChainRulesCore.Tangent{BA}(f=p.f, args=args) +end + +function ChainRulesCore.rrule( + config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(logpdf), + dist::LazyVectorOfUnivariate, + x::AbstractVector{<:Real} +) + # Extract the constructor used in the `BroadcastArray`. + constructor = DistributionsAD._inner_constructor(typeof(dist.v)) + + # If it's not safe to ignore the `constructor` in the pullback, then we fall back + # to the default implementation. + is_diff_safe(constructor) || return ChainRulesCore.rrule_via_ad(config, (d,x) -> sum(logpdf.(d.v, x)), dist, x) + + # Otherwise, we use `Closure`. + cl = DistributionsAD.Closure(logpdf, constructor) + + # Construct pullbacks manually to avoid the constructor of `BroadcastArray`. + y, dy = ChainRulesCore.rrule_via_ad(config, broadcast, cl, x, dist.v.args...) + z, dz = ChainRulesCore.rrule_via_ad(config, sum, y) + + project_broadcastarray = ChainRulesCore.ProjectTo(dist.v) + function logpdf_adjoint(Δ...) + # 1st argument is `sum` -> nothing. + (_, sum_Δ...) = dz(Δ...) + # 1st argument is `broadcast` -> nothing. + # 2nd argument is `cl` -> `nothing`. + # 3rd argument is `x` -> something. + # Rest is `dist` arguments -> something + (_, _, x_Δ, args_Δ...) = dy(sum_Δ...) + # Construct the structural tangents. + ba_tangent = project_broadcastarray(args_Δ...) + dist_tangent = ChainRulesCore.Tangent{typeof(dist)}(v=ba_tangent) + + return (ChainRulesCore.NoTangent(), dist_tangent, x_Δ) + end + + return z, logpdf_adjoint +end