Skip to content

Commit 39afd21

Browse files
committed
Add stabilization rewrite for log of kv
1 parent d1e33fe commit 39afd21

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

pytensor/tensor/rewriting/math.py

+16
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
ge,
5757
int_div,
5858
isinf,
59+
kve,
5960
le,
6061
log,
6162
log1mexp,
@@ -3494,3 +3495,18 @@ def local_useless_conj(fgraph, node):
34943495
)
34953496

34963497
register_specialize(local_polygamma_to_tri_gamma)
3498+
3499+
3500+
local_log_kv = PatternNodeRewriter(
3501+
# Rewrite log(kv(v, x) = log(kve(v, x) * exp(-x) -> log(kve(v, x)) - x
3502+
# During stabilize -x is converted to -1.0 * x
3503+
(log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))),
3504+
(sub, (log, (kve, "v", "x")), "x"),
3505+
allow_multiple_clients=True,
3506+
name="local_log_kv",
3507+
# Start the rewrite from the less likely kve node
3508+
tracks=[kve],
3509+
get_nodes=get_clients_at_depth2,
3510+
)
3511+
3512+
register_stabilize(local_log_kv)

tests/tensor/rewriting/test_math.py

+10
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
ge,
6262
gt,
6363
int_div,
64+
kv,
6465
le,
6566
log,
6667
log1mexp,
@@ -4578,3 +4579,12 @@ def test_local_batched_matmul_to_core_matmul():
45784579
x_test = rng.normal(size=(5, 3, 2))
45794580
y_test = rng.normal(size=(5, 2, 2))
45804581
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
4582+
4583+
4584+
def test_log_kv_stabilization():
4585+
x = pt.scalar("x")
4586+
out = log(kv(4.5, x))
4587+
4588+
# Reference value from mpmath
4589+
# mpmath.log(mpmath.besselk(4.5, 1000.0))
4590+
np.testing.assert_allclose(out.eval({x: 1000.0}), -1003.2180912984705)

0 commit comments

Comments
 (0)