Skip to content

Commit

Permalink
Support Hessian of gamma-distributed samples
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Jan 18, 2025
1 parent 9fb2976 commit c37bebd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
63 changes: 52 additions & 11 deletions jax/_src/lax/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

from enum import Enum
from typing import Any
import numpy as np
from functools import partial

Expand All @@ -29,14 +30,30 @@
standard_naryop, standard_unop, sub,
_const, _dtype,
_float, _nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
from jax._src.lax.control_flow import cond, scan, while_loop

from jax._src import api
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import chlo
from jax._src.typing import Array, ArrayLike

def _while_loop_scan(cond_fun, body_fun, init_val, max_iter):
"""Scan-based implementation (jit ok, reverse-mode autodiff ok)."""
def _iter(val):
next_val = body_fun(val)
next_cond = cond_fun(next_val)
return next_val, next_cond

def _fun(tup, it):
val, _cond = tup
# When _cond is met, we start doing no-ops.
return cond(_cond, _iter, lambda x: (x, False), val), it

init = (init_val, cond_fun(init_val))
return scan(_fun, init, None, length=max_iter)[0][0]

def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
  r"""Elementwise regularized incomplete beta integral.

  Thin wrapper that binds ``regularized_incomplete_beta_p`` to the three
  operands in the order ``(a, b, x)``.

  Args:
    a: first operand of the primitive (presumably the first shape parameter
      of the beta integral -- confirm against the primitive's lowering).
    b: second operand of the primitive (presumably the second shape
      parameter -- confirm against the primitive's lowering).
    x: third operand of the primitive (presumably the upper integration
      limit -- confirm against the primitive's lowering).

  Returns:
    The result of binding ``regularized_incomplete_beta_p``.
  """
  return regularized_incomplete_beta_p.bind(a, b, x)
Expand Down Expand Up @@ -250,7 +267,7 @@ def _any(predicates: Array) -> Array:
all_dimensions = tuple(range(len(predicates_shape)))
return reduce(predicates, f, bitwise_or, all_dimensions)

def _igamma_series(ax, x, a, enabled, dtype, mode):
def _igamma_series(ax, x, a, enabled, dtype, mode, *, hessian: bool = False):
def cond_fn(vals):
return _any(vals[0])

Expand Down Expand Up @@ -285,7 +302,9 @@ def body_fn(vals):
full_like(a, 0),
)

vals = while_loop(cond_fn, body_fn, init_vals)
vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256)
if hessian
else while_loop(cond_fn, body_fn, init_vals))
ans = vals[3]
dans_da = vals[6]

Expand Down Expand Up @@ -327,7 +346,9 @@ def igamma_impl(a, x, *, dtype):
full_like(a, float('nan')), output)
return output

def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode,
*,
hessian: bool = False):
eps = dtypes.finfo(dtype).eps

def cond_fn(vals):
Expand Down Expand Up @@ -418,7 +439,9 @@ def body_fn(vals):
c, pkm1, qkm1, pkm2, qkm2,
dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)

vals = while_loop(cond_fn, body_fn, init_vals)
vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256)
if hessian
else while_loop(cond_fn, body_fn, init_vals))
ans = vals[1]
if mode == IgammaMode.VALUE:
return ans * ax
Expand Down Expand Up @@ -470,7 +493,12 @@ def igamma_grad_a_impl(a, x, *, dtype):
full_like(a, float('nan')), output)
return output

def random_gamma_grad_impl(a, x, *, dtype):
def random_gamma_grad_impl(a: Array,
x: Array,
*,
dtype: Any,
hessian: bool = False
) -> Array:
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, full_like(x,0))
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
Expand All @@ -480,11 +508,13 @@ def random_gamma_grad_impl(a, x, *, dtype):
ax = exp(ax)
enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or
(x_is_zero, domain_error), underflow), is_nan))
output = select(use_igammac,
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.SAMPLE_DERIVATIVE),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.SAMPLE_DERIVATIVE))
output = select(
use_igammac,
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.SAMPLE_DERIVATIVE,
hessian=hessian),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.SAMPLE_DERIVATIVE, hessian=hessian))
output = select(x_is_zero, full_like(output,0), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
Expand Down Expand Up @@ -653,10 +683,21 @@ def bessel_i0e_impl(x):

ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

def random_gamma_hessian_a(g, a, x, *, dtype):
  """JVP rule of ``random_gamma_grad_p`` with respect to its first operand.

  Args:
    g: tangent of ``a``.
    a: first primal operand of ``random_gamma_grad_impl``.
    x: second primal operand of ``random_gamma_grad_impl``.
    dtype: computation dtype forwarded to ``random_gamma_grad_impl``.

  Returns:
    The output tangent: the derivative of ``random_gamma_grad_impl(a, x)``
    with respect to ``a``, scaled by ``g``.
  """
  # hessian=True makes the impl use the scan-based loops instead of
  # while_loop, so that it can be differentiated in reverse mode.
  # NOTE(review): api.grad requires a scalar output, so this rule only
  # handles scalar operands (batched use has to go through vmap).
  dgrad_da = api.grad(random_gamma_grad_impl, argnums=0)(
      a, x, dtype=dtype, hessian=True)
  # A JVP rule must return derivative * tangent; dropping ``g`` is only
  # correct when the incoming tangent happens to be exactly one.
  return g * dgrad_da

def random_gamma_hessian_x(g, a, x, *, dtype):
  """JVP rule of ``random_gamma_grad_p`` with respect to its second operand.

  Args:
    g: tangent of ``x``.
    a: first primal operand of ``random_gamma_grad_impl``.
    x: second primal operand of ``random_gamma_grad_impl``.
    dtype: computation dtype forwarded to ``random_gamma_grad_impl``.

  Returns:
    The output tangent: the derivative of ``random_gamma_grad_impl(a, x)``
    with respect to ``x``, scaled by ``g``.
  """
  # hessian=True makes the impl use the scan-based loops instead of
  # while_loop, so that it can be differentiated in reverse mode.
  # NOTE(review): api.grad requires a scalar output, so this rule only
  # handles scalar operands (batched use has to go through vmap).
  dgrad_dx = api.grad(random_gamma_grad_impl, argnums=1)(
      a, x, dtype=dtype, hessian=True)
  # A JVP rule must return derivative * tangent; dropping ``g`` is only
  # correct when the incoming tangent happens to be exactly one.
  return g * dgrad_dx

# Primitive taking two floating-point operands (a, x).
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
# Lower the primitive by tracing the Python implementation, wrapped in
# _up_and_broadcast, rather than through a dedicated HLO op.
mlir.register_lowering(random_gamma_grad_p,
                       mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
                                      multiple_results=False))
# One JVP rule per operand, in operand order: first `a`, then `x`.
ad.defjvp(random_gamma_grad_p,
          _up_and_broadcast(random_gamma_hessian_a),
          _up_and_broadcast(random_gamma_hessian_x))

zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))
Expand Down
13 changes: 13 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,5 +1484,18 @@ def f():
jax.random.normal(jax.random.key(0), 1000)
f() # don't crash

class SamplingDerivativeTest(jtu.JaxTestCase):
  """Checks derivatives of random samples with respect to their parameters."""

  def test_gamma_hessian(self):
    """Second derivative of a gamma sample w.r.t. its shape parameter."""
    # Regression test for https://github.com/google/jax/issues/16076
    shape_param = 0.8
    second_derivative = jax.hessian(random.gamma, argnums=(1,))

    def hessian_sample(key: jax.Array) -> jax.Array:
      # argnums=(1,) yields a 1x1 nested tuple; unpack its single entry.
      ((entry,),) = second_derivative(key, shape_param)
      return entry

    sample_keys = random.split(random.key(0), 300)
    per_key_hessians = jax.vmap(hessian_sample)(sample_keys)
    average = jnp.mean(per_key_hessians, axis=-1)
    expected = jnp.asarray(0.61)
    # Loose tolerances: the mean is estimated from 300 random samples.
    self.assertArraysAllClose(average, expected, atol=0.1, rtol=0.4)


# Run the tests in this module through absltest, using JAX's test loader,
# when the file is executed as a script.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c37bebd

Please sign in to comment.