Commit

Implement Measurable Stacks

Ricardo committed May 15, 2022
1 parent 9c76f62 commit cc78f30

Showing 2 changed files with 285 additions and 2 deletions.
119 changes: 118 additions & 1 deletion aeppl/tensor.py
@@ -1,13 +1,16 @@
from typing import Optional
from typing import List, Optional, Union

import aesara
from aesara import tensor as at
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_rv_size_lift

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logprob, logprob
from aeppl.opt import PreserveRVMappings, measurable_ir_rewrites_db


@@ -77,10 +80,124 @@ def naive_bcast_rv_lift(fgraph, node):
return [bcasted_node.outputs[1]]


class MeasurableMakeVector(MakeVector):
"""A placeholder used to specify a log-likelihood for a `MakeVector` sub-graph."""


MeasurableVariable.register(MeasurableMakeVector)


@_logprob.register(MeasurableMakeVector)
def logprob_make_vector(op, values, *base_vars, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
(value,) = values

return at.stack(
[logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)]
)
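
A minimal usage sketch (mirroring test_measurable_make_vector below; the variable names are illustrative): stacking unvalued scalar random variables yields an elementwise log-density over the stacked value.

from aesara import tensor as at
from aeppl import joint_logprob

x_rv = at.stack([at.random.normal(), at.random.halfnormal()])
x_vv = x_rv.clone()
# One log-density per stacked element: normal logpdf at x_vv[0],
# half-normal logpdf at x_vv[1]
x_logp = joint_logprob({x_rv: x_vv}, sum=False)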


class MeasurableJoin(Join):
"""A placeholder used to specify a log-likelihood for a `Join` sub-graph."""


MeasurableVariable.register(MeasurableJoin)


@_logprob.register(MeasurableJoin)
def logprob_join(op, values, axis, *base_vars, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableJoin`."""
(value,) = values

split_values = at.split(
value,
splits_size=[base_var.shape[axis] for base_var in base_vars],
n_splits=len(base_vars),
axis=axis,
)

logps = [
logprob(base_var, split_value)
for base_var, split_value in zip(base_vars, split_values)
]

if len(set(logp.ndim for logp in logps)) != 1:
raise ValueError(
"Joined logps have different number of dimensions, this can happen when "
"joining univariate and multivariate distributions",
)

base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
join_logprob = at.concatenate(
[at.atleast_1d(logp) for logp in logps],
axis=axis - base_vars_ndim_supp,
)

return join_logprob
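
A minimal usage sketch (mirroring test_measurable_join_univariate below; the variable names are illustrative): concatenating two unvalued univariate random variables yields a log-density whose segments follow the respective base distributions.

from aesara import tensor as at
from aeppl import joint_logprob

y_rv = at.concatenate([at.random.normal(size=5), at.random.exponential(size=3)], axis=0)
y_vv = y_rv.clone()
# Shape (8,): normal logpdf for the first five entries,
# exponential logpdf for the last three
y_logp = joint_logprob({y_rv: y_vv}, sum=False)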


@local_optimizer([MakeVector, Join])
def find_measurable_stacks(
fgraph, node
) -> Optional[List[Union[MeasurableMakeVector, MeasurableJoin]]]:
r"""Find `Join`\s and `MakeVector`\s for which a `logprob` can be computed."""

if isinstance(node.op, (MeasurableMakeVector, MeasurableJoin)):
return None # pragma: no cover

rv_map_feature: PreserveRVMappings = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

stack_out = node.outputs[0]

is_join = isinstance(node.op, Join)

if is_join:
axis, *base_vars = node.inputs
else:
base_vars = node.inputs

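# The rewrite only applies when every stacked input is itself a measurable
# variable that has not been given a value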
if not all(
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
for base_var in base_vars
):
return None # pragma: no cover

# Make base_vars unmeasurable
base_vars = [
assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars
]

if is_join:
measurable_stack = MeasurableJoin()(axis, *base_vars)
else:
measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)

measurable_stack.name = stack_out.name

return [measurable_stack]


measurable_ir_rewrites_db.register(
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
)

measurable_ir_rewrites_db.register(
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
)


measurable_ir_rewrites_db.register(
"find_measurable_stacks",
find_measurable_stacks,
0,
"basic",
"tensor",
)
168 changes: 167 additions & 1 deletion tests/test_tensor.py
@@ -1,11 +1,12 @@
import numpy as np
import pytest
from aesara import tensor as at
from aesara.graph import optimize_graph
from aesara.graph.opt import in2out
from aesara.tensor.extra_ops import BroadcastTo
from scipy import stats as st

from aeppl import factorized_joint_logprob
from aeppl import factorized_joint_logprob, joint_logprob
from aeppl.tensor import naive_bcast_rv_lift


@@ -39,3 +40,168 @@ def test_naive_bcast_rv_lift_valued_var():
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
)


def test_measurable_make_vector():
base1_rv = at.random.normal(name="base1")
base2_rv = at.random.halfnormal(name="base2")
base3_rv = at.random.exponential(name="base3")
y_rv = at.stack((base1_rv, base2_rv, base3_rv))
y_rv.name = "y"

base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
base3_vv = base3_rv.clone()
y_vv = y_rv.clone()

ref_logp = joint_logprob(
{base1_rv: base1_vv, base2_rv: base2_vv, base3_rv: base3_vv}
)
make_vector_logp = joint_logprob({y_rv: y_vv}, sum=False)

base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
base3_testval = base3_rv.eval()
y_testval = np.stack((base1_testval, base2_testval, base3_testval))

ref_logp_eval = ref_logp.eval(
{base1_vv: base1_testval, base2_vv: base2_testval, base3_vv: base3_testval}
)
make_vector_logp_eval = make_vector_logp.eval({y_vv: y_testval})

assert make_vector_logp_eval.shape == y_testval.shape
assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval)


@pytest.mark.parametrize(
"size1, size2, axis, concatenate",
[
((5,), (3,), 0, True),
((5,), (3,), -1, True),
((5, 2), (3, 2), 0, True),
((2, 5), (2, 3), 1, True),
((2, 5), (2, 5), 0, False),
((2, 5), (2, 5), 1, False),
((2, 5), (2, 5), 2, False),
],
)
def test_measurable_join_univariate(size1, size2, axis, concatenate):
base1_rv = at.random.normal(size=size1, name="base1")
base2_rv = at.random.exponential(size=size2, name="base2")
if concatenate:
y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
else:
y_rv = at.stack((base1_rv, base2_rv), axis=axis)
y_rv.name = "y"

base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
y_vv = y_rv.clone()

base_logps = list(
factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values()
)
if concatenate:
base_logps = at.concatenate(base_logps, axis=axis)
else:
base_logps = at.stack(base_logps, axis=axis)
y_logp = joint_logprob({y_rv: y_vv}, sum=False)

base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
if concatenate:
y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
else:
y_testval = np.stack((base1_testval, base2_testval), axis=axis)
np.testing.assert_allclose(
base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
y_logp.eval({y_vv: y_testval}),
)


@pytest.mark.parametrize(
"size1, supp_size1, size2, supp_size2, axis, concatenate",
[
(None, 2, None, 2, 0, True),
(None, 2, None, 2, -1, True),
((5,), 2, (3,), 2, 0, True),
((5,), 2, (3,), 2, -2, True),
((2,), 5, (2,), 3, 1, True),
pytest.param(
(2,),
5,
(2,),
5,
0,
False,
marks=pytest.mark.xfail(
reason="cannot measure dimshuffled multivariate RVs"
),
),
pytest.param(
(2,),
5,
(2,),
5,
1,
False,
marks=pytest.mark.xfail(
reason="cannot measure dimshuffled multivariate RVs"
),
),
],
)
def test_measurable_join_multivariate(
size1, supp_size1, size2, supp_size2, axis, concatenate
):
base1_rv = at.random.multivariate_normal(
np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
)
base2_rv = at.random.dirichlet(np.ones(supp_size2), size=size2, name="base2")
if concatenate:
y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
else:
y_rv = at.stack((base1_rv, base2_rv), axis=axis)
y_rv.name = "y"

base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
y_vv = y_rv.clone()
base_logps = [
at.atleast_1d(logp)
for logp in factorized_joint_logprob(
{base1_rv: base1_vv, base2_rv: base2_vv}
).values()
]

if concatenate:
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
base_logps = at.concatenate(base_logps, axis=axis_norm - 1)
else:
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
base_logps = at.stack(base_logps, axis=axis_norm - 1)
y_logp = joint_logprob({y_rv: y_vv}, sum=False)

base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
if concatenate:
y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
else:
y_testval = np.stack((base1_testval, base2_testval), axis=axis)
np.testing.assert_allclose(
base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
y_logp.eval({y_vv: y_testval}),
)


def test_join_mixed_ndim_supp():
base1_rv = at.random.normal(size=3, name="base1")
base2_rv = at.random.dirichlet(np.ones(3), name="base2")
y_rv = at.concatenate((base1_rv, base2_rv), axis=0)

y_vv = y_rv.clone()
with pytest.raises(
ValueError, match="Joined logps have different number of dimensions"
):
joint_logprob({y_rv: y_vv})
