@@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))
# Length of the distribution: the number of entries in the parameter vector `d.alpha`.
function length(d::Dirichlet)
    return length(d.alpha)
end
# Mean vector: every entry of `d.alpha` multiplied by the reciprocal of `d.alpha0`.
function mean(d::Dirichlet)
    # The reciprocal is computed once and then broadcast over `alpha`.
    inv_alpha0 = inv(d.alpha0)
    return d.alpha .* inv_alpha0
end
# Parameters of the distribution, returned as a 1-tuple holding `d.alpha`.
function params(d::Dirichlet)
    return (d.alpha,)
end
# Parameter type of a `Dirichlet{T}`: the type parameter `T` itself.
# The argument is used only for dispatch, so it is left unnamed.
@inline function partype(::Dirichlet{T}) where {T<:Real}
    return T
end
7676
7777function var (d:: Dirichlet )
7878 α0 = d. alpha0
@@ -375,3 +375,62 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
375375 elogp = mean_logp (suffstats (Dirichlet, P, w))
376376 fit_dirichlet! (elogp, α; maxiter= maxiter, tol= tol, debug= debug)
377377end
378+
379+ # # Differentiation
# Forward-mode rule for the `Dirichlet` constructor.
#
# Given the input tangent `Δalpha` of `alpha`, returns the constructed distribution
# `d = DT(alpha; check_args)` together with a structural tangent whose fields are
#   alpha  = Δalpha
#   alpha0 = sum(Δalpha)
#   lmnB   = Σᵢ Δalphaᵢ * (digamma(alphaᵢ) - digamma(d.alpha0))
function ChainRulesCore.frule(
    (_, Δalpha)::Tuple{Any,Any},
    ::Type{DT},
    alpha::AbstractVector{T};
    check_args::Bool=true,
) where {T<:Real,DT<:Union{Dirichlet{T},Dirichlet}}
    d = DT(alpha; check_args=check_args)

    # digamma at the total concentration is shared by every summand below.
    ψ0 = SpecialFunctions.digamma(d.alpha0)
    lmnB_term(Δαi, αi) = Δαi * (SpecialFunctions.digamma(αi) - ψ0)

    # Sum over a lazy broadcast so that no intermediate array is materialized.
    lazy_terms = Broadcast.broadcasted(lmnB_term, Δalpha, alpha)
    ∂lmnB = sum(Broadcast.instantiate(lazy_terms))
    ∂alpha0 = sum(Δalpha)

    Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
    return d, Δd
end
390+
# Reverse-mode rule for the `Dirichlet` constructor.
#
# Returns the constructed distribution and a pullback. The pullback maps a cotangent
# with fields `alpha`, `alpha0` and `lmnB` to `(NoTangent(), Δalpha)`, where
#   Δalphaᵢ = alphaᵢ-cotangent + alpha0-cotangent
#             + lmnB-cotangent * (digamma(alphaᵢ) - digamma(d.alpha0))
function ChainRulesCore.rrule(
    ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool=true
) where {T<:Real,DT<:Union{Dirichlet{T},Dirichlet}}
    d = DT(alpha; check_args=check_args)
    # Evaluated once in the primal pass and captured by the pullback.
    ψ0 = SpecialFunctions.digamma(d.alpha0)

    function Dirichlet_pullback(_Δd)
        Δd = ChainRulesCore.unthunk(_Δd)
        alpha_bar = Δd.alpha
        alpha0_bar = Δd.alpha0
        lmnB_bar = Δd.lmnB
        # Single fused broadcast over the entries of `alpha`.
        Δalpha = alpha_bar .+ alpha0_bar .+
                 lmnB_bar .* (SpecialFunctions.digamma.(alpha) .- ψ0)
        return ChainRulesCore.NoTangent(), Δalpha
    end

    return d, Dirichlet_pullback
end
401+
# Forward-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Returns the primal `Ω = _logpdf(d, x)` and its tangent
#   ΔΩ = Σᵢ [xlogy(Δd.alphaᵢ, xᵢ) + (d.alphaᵢ - 1) * Δxᵢ / xᵢ] - Δd.lmnB.
# When `Ω` is not finite the tangent is replaced by `NaN` of the same type.
function ChainRulesCore.frule(
    (_, Δd, Δx)::Tuple{Any,Any,Any},
    ::typeof(_logpdf),
    d::Dirichlet,
    x::AbstractVector{<:Real},
)
    Ω = _logpdf(d, x)

    # Per-component contribution from the tangents of `alpha` and of `x`.
    component(Δαi, Δxi, αi, xi) = xlogy(Δαi, xi) + (αi - 1) * Δxi / xi
    lazy_terms = Broadcast.broadcasted(component, Δd.alpha, Δx, d.alpha, x)
    ∂alpha = sum(Broadcast.instantiate(lazy_terms))

    # `lmnB` enters the tangent with a negative sign.
    ΔΩ = ∂alpha + (-Δd.lmnB)

    return Ω, (isfinite(Ω) ? ΔΩ : oftype(ΔΩ, NaN))
end
414+
# Reverse-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Returns the primal `Ω = _logpdf(d, x)` and a pullback mapping the cotangent `ΔΩ` to
# `(NoTangent(), Δd, Δx)`, where `Δd` is a structural tangent with fields `alpha` and
# `lmnB`. If `Ω` is not finite, every returned derivative is `NaN`.
function ChainRulesCore.rrule(
    ::typeof(_logpdf), d::T, x::AbstractVector{<:Real}
) where {T<:Dirichlet}
    Ω = _logpdf(d, x)
    # Finiteness of the primal is decided once and captured by the pullback.
    Ω_is_finite = isfinite(Ω)
    α = d.alpha

    function _logpdf_Dirichlet_pullback(_ΔΩ)
        ΔΩ = ChainRulesCore.unthunk(_ΔΩ)

        # Cotangent with respect to `x`, computed element-wise.
        Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, α, x, Ω_is_finite)

        # Cotangent with respect to the fields of `d`.
        ∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, Ω_is_finite)
        ΔΩ_float = float(ΔΩ)
        ∂lmnB = Ω_is_finite ? -ΔΩ_float : oftype(ΔΩ_float, NaN)
        Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)

        return ChainRulesCore.NoTangent(), Δd, Δx
    end

    return Ω, _logpdf_Dirichlet_pullback
end
# Scalar kernel for the `alpha` cotangent of `_logpdf(::Dirichlet, x)`.
#
# Arguments:
#   xi       - one component of `x`
#   ΔΩi      - cotangent of the log-density
#   isfinite_Ω - whether the primal log-density was finite
#
# Returns `xlogy(ΔΩi, xi)`, or `NaN` of the same type when the primal was not finite.
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite_Ω::Bool)
    # The kernel is applied element-wise by its caller, so a plain scalar call suffices
    # here (no nested broadcast). The flag is also named so that it does not shadow
    # `Base.isfinite` inside this function.
    ∂alphai = xlogy(ΔΩi, xi)
    return isfinite_Ω ? ∂alphai : oftype(∂alphai, NaN)
end
# Scalar kernel for the `x` cotangent of `_logpdf(::Dirichlet, x)`.
#
# Returns `ΔΩi * (alphai - 1) / xi`, or `NaN` of the same type when `finite_Ω` is false.
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, finite_Ω::Bool)
    Δxi = ΔΩi * (alphai - 1) / xi
    if finite_Ω
        return Δxi
    end
    # Keep the result type identical to the finite branch.
    return oftype(Δxi, NaN)
end
0 commit comments