Skip to content

Implement Kve Op and Kv helper #1081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GammaIncInv,
Iv,
Ive,
Kve,
Log1mexp,
Psi,
TriGamma,
Expand Down Expand Up @@ -288,9 +289,12 @@

@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
    """Lower the scalar ``Ive`` op to TFP-on-JAX's ``bessel_ive``.

    ``try_import_tfp_jax_op`` resolves the op lazily so tensorflow-probability
    is only required when the op is actually compiled for the JAX backend.
    """
    # Single call/return: the previous version bound the result to a local
    # and then called the helper a second time before an unreachable return.
    return try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
@jax_funcify.register(Kve)
def jax_funcify_Kve(op, **kwargs):
    """Lower the scalar ``Kve`` op to TFP-on-JAX's ``bessel_kve``."""
    return try_import_tfp_jax_op(op, jax_op_name="bessel_kve")


@jax_funcify.register(Log1mexp)
Expand Down
32 changes: 32 additions & 0 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,38 @@
ive = Ive(upgrade_to_float, name="ive")


class Kve(BinaryScalarOp):
    """Exponentially scaled modified Bessel function of the second kind of real order v.

    Computes ``kv(v, x) * exp(x)``, delegating to ``scipy.special.kve``.
    """

    nfunc_spec = ("scipy.special.kve", 2, 1)

    @staticmethod
    def st_impl(v, x):
        """Numeric evaluation via scipy."""
        return scipy.special.kve(v, x)

    def impl(self, v, x):
        return self.st_impl(v, x)

    def L_op(self, inputs, outputs, output_grads):
        """Gradient wrt ``x`` only; the gradient wrt the order ``v`` is not implemented.

        Uses the identity d/dx kve(v, x) = (1 - v/x) * kve(v, x) - kve(v - 1, x).
        """
        v, x = inputs
        [kve_vx] = outputs
        [g_out] = output_grads

        dx = (1 - v / x) * kve_vx - self(v - 1, x)
        return [grad_not_implemented(self, 0, v), g_out * dx]

    def c_code(self, *args, **kwargs):
        # No C implementation; only the Python (scipy) path is available.
        raise NotImplementedError()


# Shared scalar Op instance; `upgrade_to_float` promotes integer inputs to float.
kve = Kve(upgrade_to_float, name="kve")


class Sigmoid(UnaryScalarOp):
"""
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
Expand Down
12 changes: 12 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,16 @@ def ive(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def kve(v, x):
    """Exponentially scaled modified Bessel function of the second kind of real order v.

    Equals ``kv(v, x) * exp(x)``, which remains finite for large ``x`` where
    ``kv`` itself underflows.
    """


def kv(v, x):
    """Modified Bessel function of the second kind of real order v."""
    # Undo the exponential scaling of kve: kve(v, x) == kv(v, x) * exp(x).
    return exp(-x) * kve(v, x)


@scalar_elemwise
def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
Expand Down Expand Up @@ -3040,6 +3050,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"i1",
"iv",
"ive",
"kv",
"kve",
"sigmoid",
"expit",
"softplus",
Expand Down
16 changes: 16 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ge,
int_div,
isinf,
kve,
le,
log,
log1mexp,
Expand Down Expand Up @@ -3494,3 +3495,18 @@ def local_useless_conj(fgraph, node):
)

register_specialize(local_polygamma_to_tri_gamma)


# Stabilization: log(kv(v, x)) underflows to -inf for large x because
# kv -> 0; rewriting in terms of kve (which stays finite) avoids this.
local_log_kv = PatternNodeRewriter(
    # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
    # During stabilize -x is converted to -1.0 * x
    (log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))),
    (sub, (log, (kve, "v", "x")), "x"),
    allow_multiple_clients=True,
    name="local_log_kv",
    # Start the rewrite from the less likely kve node
    tracks=[kve],
    get_nodes=get_clients_at_depth2,
)

register_stabilize(local_log_kv)
2 changes: 2 additions & 0 deletions tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
gammainccinv,
gammaincinv,
iv,
kve,
log,
log1mexp,
polygamma,
Expand Down Expand Up @@ -157,6 +158,7 @@ def test_erfinv():
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
(kve, (-2.5, 2.0)),
],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
Expand Down
15 changes: 15 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ge,
gt,
int_div,
kv,
le,
log,
log1mexp,
Expand Down Expand Up @@ -4578,3 +4579,17 @@ def test_local_batched_matmul_to_core_matmul():
x_test = rng.normal(size=(5, 3, 2))
y_test = rng.normal(size=(5, 2, 2))
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)


def test_log_kv_stabilization():
    """log(kv(v, x)) should be rewritten to log(kve(v, x)) - x during stabilize."""
    x = pt.scalar("x")
    expr = log(kv(4.5, x))

    stabilize_mode = get_default_mode().including("stabilize")

    # Without the rewrite kv(4.5, 1000.0) underflows, making log(...) -inf.
    # Reference value from mpmath: mpmath.log(mpmath.besselk(4.5, 1000.0))
    np.testing.assert_allclose(
        expr.eval({x: 1000.0}, mode=stabilize_mode),
        -1003.2180912984705,
    )
38 changes: 36 additions & 2 deletions tests/tensor/test_math_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from pytensor.gradient import verify_grad
from pytensor.gradient import NullTypeGradError, verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise

Expand All @@ -18,7 +18,7 @@
from pytensor import tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import gammaincc, inplace, vector
from pytensor.tensor import gammaincc, inplace, kv, kve, vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_good_broadcast_unary_chi2sf,
Expand Down Expand Up @@ -1196,3 +1196,37 @@ def test_unused_grad_loop_opt(self, wrt):
[dd for i, dd in enumerate(expected_dds) if i in wrt],
rtol=rtol,
)


def test_kve():
    """Forward pass against scipy, plus gradient behavior of kve/kv."""
    rng = np.random.default_rng(3772)
    v = vector("v")
    x = vector("x")

    kve_out = kve(v[:, None], x[None, :])
    v_test = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
    x_test = np.linspace(0, 1005, 10, dtype=x.type.dtype)

    expected = scipy.special.kve(v_test[:, None], x_test[None, :])
    np.testing.assert_allclose(kve_out.eval({v: v_test, x: x_test}), expected)

    # No gradient is implemented with respect to the order v.
    with pytest.raises(NullTypeGradError):
        grad(kve_out.sum(), v)

    # Gradient wrt x is checked through kv; shifting avoids the x == 0 singularity.
    verify_grad(lambda x: kv(4.5, x), [x_test + 0.5], rng=rng)


def test_kv():
    """kv should match scipy.special.kv on a broadcasted grid of orders/arguments."""
    v = vector("v")
    x = vector("x")

    kv_out = kv(v[:, None], x[None, :])
    v_test = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
    x_test = np.linspace(0, 512, 10, dtype=x.type.dtype)

    expected = scipy.special.kv(v_test[:, None], x_test[None, :])
    np.testing.assert_allclose(kv_out.eval({v: v_test, x: x_test}), expected)
Loading