From 72dc64f1a59bd94ef5fa7182648fbfd01dfc7849 Mon Sep 17 00:00:00 2001 From: Nikos Ignatiadis Date: Fri, 2 Dec 2022 16:02:43 -0500 Subject: [PATCH 1/2] add invtrigamma --- src/SpecialFunctions.jl | 4 +++- src/chainrules.jl | 5 ++++- src/gamma.jl | 43 +++++++++++++++++++++++++++++++++++++++++ test/chainrules.jl | 4 ++++ test/gamma.jl | 15 ++++++++++++++ test/other_tests.jl | 4 ++-- 6 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/SpecialFunctions.jl b/src/SpecialFunctions.jl index c4fe6120..ab456fb1 100644 --- a/src/SpecialFunctions.jl +++ b/src/SpecialFunctions.jl @@ -62,6 +62,7 @@ export invdigamma, polygamma, trigamma, + invtrigamma, gamma_inc, beta_inc, beta_inc_inv, @@ -93,7 +94,8 @@ include("chainrules.jl") include("deprecated.jl") for f in (:digamma, :erf, :erfc, :erfcinv, :erfcx, :erfi, :erfinv, :logerfc, :logerfcx, - :eta, :gamma, :invdigamma, :logfactorial, :lgamma, :trigamma, :ellipk, :ellipe) + :eta, :gamma, :invdigamma, :invtrigamma, :logfactorial, :lgamma, :trigamma, + :ellipk, :ellipe) @eval $(f)(::Missing) = missing end for f in (:beta, :lbeta) diff --git a/src/chainrules.jl b/src/chainrules.jl index 76a8f07d..4adec228 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -72,7 +72,10 @@ ChainRulesCore.@scalar_rule( inv(trigamma(invdigamma(x))), ) ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x)) - +ChainRulesCore.@scalar_rule( + invtrigamma(x), + inv(polygamma(2, invtrigamma(x))), +) # Bessel functions ChainRulesCore.@scalar_rule( besselj(ν, x), diff --git a/src/gamma.jl b/src/gamma.jl index 9fbb96d5..a5a16357 100644 --- a/src/gamma.jl +++ b/src/gamma.jl @@ -398,6 +398,49 @@ function _invdigamma(y::Float64) return x_new end +""" + invtrigamma(x) +Compute the inverse [`trigamma`](@ref) function of `x`. +""" +invtrigamma(y::Number) = _invtrigamma(float(y)) + +function _invtrigamma(y::Float64) + # Implementation of Newton algorithm described in + # "Linear Models and Empirical Bayes Methods for Assessing + # Differential Expression in Microarray Experiments" + # (Appendix "Inversion of Trigamma Function") + # by Gordon K. Smyth, 2004 + + if y <= 0 + throw(DomainError(y, "Only positive `y` supported.")) + end + + if y > 1e7 + return inv(sqrt(y)) + elseif y < 1e-6 + return inv(y) + end + + x_old = inv(y) + 0.5 + x_new = x_old + + # Newton iteration + δ = Inf + iteration = 0 + while δ > 1e-8 && iteration <= 25 + iteration += 1 + f_x_old = trigamma(x_old) + δx = f_x_old*(1-f_x_old/y) / polygamma(2, x_old) + x_new = x_old + δx + δ = - δx / x_new + x_old = x_new + end + + return x_new +end + + + """ zeta(s) diff --git a/test/chainrules.jl b/test/chainrules.jl index 1754d591..ef008da6 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -22,6 +22,10 @@ test_scalar(invdigamma, x) end + if x isa Real && x > 0 + test_scalar(invtrigamma, x) + end + if x isa Real && 0 < x < 1 test_scalar(erfinv, x) test_scalar(erfcinv, x) diff --git a/test/gamma.jl b/test/gamma.jl index 4509a23c..35702fe7 100644 --- a/test/gamma.jl +++ b/test/gamma.jl @@ -47,6 +47,21 @@ @test abs(invdigamma(2)) == abs(invdigamma(2.)) end + @testset "invtrigamma" begin + for val in [0.001, 0.01, 0.1, 1.0, 10.0] + @test invtrigamma(trigamma(val)) ≈ val + end + + for val in [1e-8, 0.001, 0.01, 0.1, 1.0, 10.0, 1e7, 1e9] + @test trigamma(invtrigamma(val)) ≈ val + end + + @test_throws DomainError invtrigamma(-1.0) + @test invtrigamma(2) == invtrigamma(2.) + end + + #@test "invtrigamma" begin + @testset "polygamma" begin @test polygamma(20, 7.) ≈ -4.644616027240543262561198814998587152547 @test polygamma(20, Float16(7.)) ≈ -4.644616027240543262561198814998587152547 diff --git a/test/other_tests.jl b/test/other_tests.jl index b7ab0694..7b3bcf01 100644 --- a/test/other_tests.jl +++ b/test/other_tests.jl @@ -76,7 +76,7 @@ end @testset "missing data" begin for f in (digamma, erf, erfc, erfcinv, erfcx, erfi, erfinv, eta, gamma, - invdigamma, logfactorial, trigamma) + invdigamma, invtrigamma, logfactorial, trigamma) @test f(missing) === missing end @test beta(1.0, missing) === missing @@ -90,7 +90,7 @@ end for n in numbers @test abs(n) == SpecialFunctions.fastabs(n) end - + numbers = [1im, 2 + 2im, 0 + 100im, 1e3 + 1e-10im] for n in numbers @test abs(real(n)) + abs(imag(n)) == SpecialFunctions.fastabs(n) From cc445ea24146ae1c77cb857ceb06ab10889475e1 Mon Sep 17 00:00:00 2001 From: Nikos Ignatiadis Date: Fri, 2 Dec 2022 16:04:51 -0500 Subject: [PATCH 2/2] remove floating comment --- test/gamma.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/gamma.jl b/test/gamma.jl index 35702fe7..d94d4e24 100644 --- a/test/gamma.jl +++ b/test/gamma.jl @@ -60,8 +60,6 @@ @test invtrigamma(2) == invtrigamma(2.) end - #@test "invtrigamma" begin - @testset "polygamma" begin @test polygamma(20, 7.) ≈ -4.644616027240543262561198814998587152547 @test polygamma(20, Float16(7.)) ≈ -4.644616027240543262561198814998587152547