Skip to content

Commit 0d88da3

Browse files
committed
Implement Kv Op
1 parent fdbf3aa commit 0d88da3

File tree

5 files changed

+71
-7
lines changed

5 files changed

+71
-7
lines changed

pytensor/link/jax/dispatch/scalar.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
GammaIncInv,
3232
Iv,
3333
Ive,
34+
Kv,
3435
Log1mexp,
3536
Psi,
3637
TriGamma,
@@ -293,6 +294,16 @@ def jax_funcify_Ive(op, **kwargs):
293294
return ive
294295

295296

297+
@jax_funcify.register(Kv)
298+
def jax_funcify_Kv(op, **kwargs):
299+
kve = try_import_tfp_jax_op(op, jax_op_name="bessel_kve")
300+
301+
def kv(v, x):
302+
return kve(v, x) / jnp.exp(jnp.abs(x))
303+
304+
return kv
305+
306+
296307
@jax_funcify.register(Log1mexp)
297308
def jax_funcify_Log1mexp(op, node, **kwargs):
298309
def log1mexp(x):

pytensor/scalar/math.py

+30
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,36 @@ def c_code(self, *args, **kwargs):
12811281
ive = Ive(upgrade_to_float, name="ive")
12821282

12831283

1284+
class Kv(BinaryScalarOp):
1285+
"""Modified Bessel function of the second kind of real order v."""
1286+
1287+
nfunc_spec = ("scipy.special.kv", 2, 1)
1288+
1289+
@staticmethod
1290+
def st_impl(v, x):
1291+
return scipy.special.kv(v, x)
1292+
1293+
def impl(self, v, x):
1294+
return self.st_impl(v, x)
1295+
1296+
def L_op(self, inputs, outputs, output_grads):
1297+
v, x = inputs
1298+
[out] = outputs
1299+
[g_out] = output_grads
1300+
# -(v / x) * kv(v, x) - kv(v - 1, x)
1301+
dx = -(v / x) * out - self(v - 1, x)
1302+
return [
1303+
grad_not_implemented(self, 0, v),
1304+
g_out * dx,
1305+
]
1306+
1307+
def c_code(self, *args, **kwargs):
1308+
raise NotImplementedError()
1309+
1310+
1311+
kv = Kv(upgrade_to_float, name="kv")
1312+
1313+
12841314
class Sigmoid(UnaryScalarOp):
12851315
"""
12861316
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit

pytensor/tensor/math.py

+6
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,11 @@ def ive(v, x):
12291229
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
12301230

12311231

1232+
@scalar_elemwise
1233+
def kv(v, x):
1234+
"""Modified Bessel function of the second kind of real order v."""
1235+
1236+
12321237
@scalar_elemwise
12331238
def sigmoid(x):
12341239
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@@ -3040,6 +3045,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30403045
"i1",
30413046
"iv",
30423047
"ive",
3048+
"kv",
30433049
"sigmoid",
30443050
"expit",
30453051
"softplus",

tests/link/jax/test_scalar.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
gammainccinv,
2222
gammaincinv,
2323
iv,
24+
kv,
2425
log,
2526
log1mexp,
2627
polygamma,
@@ -153,11 +154,7 @@ def test_erfinv():
153154

154155
@pytest.mark.parametrize(
155156
"op, test_values",
156-
[
157-
(erfcx, (0.7,)),
158-
(erfcinv, (0.7,)),
159-
(iv, (0.3, 0.7)),
160-
],
157+
[(erfcx, (0.7,)), (erfcinv, (0.7,)), (iv, (0.3, 0.7)), (kv, (-2.5, 2.0))],
161158
)
162159
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
163160
def test_tfp_ops(op, test_values):

tests/tensor/test_math_scipy.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from pytensor.gradient import verify_grad
6+
from pytensor.gradient import NullTypeGradError, verify_grad
77
from pytensor.scalar import ScalarLoop
88
from pytensor.tensor.elemwise import Elemwise
99

@@ -18,7 +18,7 @@
1818
from pytensor import tensor as pt
1919
from pytensor.compile.mode import get_default_mode
2020
from pytensor.configdefaults import config
21-
from pytensor.tensor import gammaincc, inplace, vector
21+
from pytensor.tensor import gammaincc, inplace, kv, vector
2222
from tests import unittest_tools as utt
2323
from tests.tensor.utils import (
2424
_good_broadcast_unary_chi2sf,
@@ -1196,3 +1196,23 @@ def test_unused_grad_loop_opt(self, wrt):
11961196
[dd for i, dd in enumerate(expected_dds) if i in wrt],
11971197
rtol=rtol,
11981198
)
1199+
1200+
1201+
def test_kv():
1202+
rng = np.random.default_rng(3772)
1203+
v = vector("v")
1204+
x = vector("x")
1205+
1206+
out = kv(v[:, None], x[None, :])
1207+
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
1208+
test_x = np.linspace(0, 5, 10, dtype=x.type.dtype)
1209+
1210+
np.testing.assert_allclose(
1211+
out.eval({v: test_v, x: test_x}),
1212+
scipy.special.kv(test_v[:, None], test_x[None, :]),
1213+
)
1214+
1215+
with pytest.raises(NullTypeGradError):
1216+
grad(out.sum(), v)
1217+
1218+
verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng)

0 commit comments

Comments
 (0)