From cc78f30b5ed89e5b247b88d697467a76cc1e424e Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Wed, 11 May 2022 16:54:40 +0200
Subject: [PATCH] Implement Measurable Stacks

---
 aeppl/tensor.py      | 119 +++++++++++++++++++++++++++++-
 tests/test_tensor.py | 168 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 285 insertions(+), 2 deletions(-)

diff --git a/aeppl/tensor.py b/aeppl/tensor.py
index 77b33cae..1a448a07 100644
--- a/aeppl/tensor.py
+++ b/aeppl/tensor.py
@@ -1,13 +1,16 @@
-from typing import Optional
+from typing import List, Optional, Union

 import aesara
 from aesara import tensor as at
 from aesara.graph.op import compute_test_value
 from aesara.graph.opt import local_optimizer
+from aesara.tensor.basic import Join, MakeVector
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_rv_size_lift

+from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
+from aeppl.logprob import _logprob, logprob
 from aeppl.opt import PreserveRVMappings, measurable_ir_rewrites_db

@@ -77,6 +80,111 @@ def naive_bcast_rv_lift(fgraph, node):
     return [bcasted_node.outputs[1]]


+class MeasurableMakeVector(MakeVector):
+    """A placeholder used to specify a log-likelihood for a `MakeVector` sub-graph."""
+
+
+MeasurableVariable.register(MeasurableMakeVector)
+
+
+@_logprob.register(MeasurableMakeVector)
+def logprob_make_vector(op, values, *base_vars, **kwargs):
+    """Compute the log-likelihood graph for a `MeasurableMakeVector`."""
+    (value,) = values
+
+    return at.stack(
+        [logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)]
+    )
+
+
+class MeasurableJoin(Join):
+    """A placeholder used to specify a log-likelihood for a `Join` sub-graph."""
+
+
+MeasurableVariable.register(MeasurableJoin)
+
+
+@_logprob.register(MeasurableJoin)
+def logprob_join(op, values, axis, *base_vars, **kwargs):
+    """Compute the log-likelihood graph for a `MeasurableJoin`."""
+    (value,) = values
+
+    split_values = at.split(
+        value,
+        splits_size=[base_var.shape[axis] for base_var in base_vars],
+        n_splits=len(base_vars),
+        axis=axis,
+    )
+
+    logps = [
+        logprob(base_var, split_value)
+        for base_var, split_value in zip(base_vars, split_values)
+    ]
+
+    if len(set(logp.ndim for logp in logps)) != 1:
+        raise ValueError(
+            "Joined logps have different number of dimensions, this can happen when "
+            "joining univariate and multivariate distributions",
+        )
+
+    # The logp of a multivariate RV has `ndim_supp` fewer dimensions than its
+    # value, so the join axis must be shifted accordingly.
+    base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
+    join_logprob = at.concatenate(
+        [at.atleast_1d(logp) for logp in logps],
+        axis=axis - base_vars_ndim_supp,
+    )
+
+    return join_logprob
+
+
+@local_optimizer([MakeVector, Join])
+def find_measurable_stacks(
+    fgraph, node
+) -> Optional[List[Union[MeasurableMakeVector, MeasurableJoin]]]:
+    r"""Finds `Join`\s and `MakeVector`\s for which a `logprob` can be computed."""
+
+    if isinstance(node.op, (MeasurableMakeVector, MeasurableJoin)):
+        return None  # pragma: no cover
+
+    rv_map_feature: PreserveRVMappings = getattr(fgraph, "preserve_rv_mappings", None)
+
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    stack_out = node.outputs[0]
+
+    is_join = isinstance(node.op, Join)
+
+    if is_join:
+        axis, *base_vars = node.inputs
+    else:
+        base_vars = node.inputs
+
+    if not all(
+        base_var.owner
+        and isinstance(base_var.owner.op, MeasurableVariable)
+        and base_var not in rv_map_feature.rv_values
+        for base_var in base_vars
+    ):
+        return None  # pragma: no cover
+
+    # Make the base_vars unmeasurable so they are not matched again by other
+    # measurable rewrites
+    base_vars = [
+        assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars
+    ]
+
+    if is_join:
+        measurable_stack = MeasurableJoin()(axis, *base_vars)
+    else:
+        measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
+
+    measurable_stack.name = stack_out.name
+
+    return [measurable_stack]
+
+
 measurable_ir_rewrites_db.register(
     "dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
 )
@@ -84,3 +192,12 @@ def naive_bcast_rv_lift(fgraph, node):
 measurable_ir_rewrites_db.register(
     "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
 )
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_stacks",
+    find_measurable_stacks,
+    0,
+    "basic",
+    "tensor",
+)
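(A quick usage sketch, not part of the patch itself: assuming the `joint_logprob`
API exercised by the tests below, the new rewrite lets a stack of heterogeneous
random variables be given a log-probability as a single measurable variable;
the variable names here are illustrative.)

    import aesara.tensor as at
    from aeppl import joint_logprob

    # Stack two RVs with different distributions into one vector RV.
    x_rv = at.random.normal(0, 1, name="x")
    y_rv = at.random.halfnormal(name="y")
    stack_rv = at.stack((x_rv, y_rv))
    stack_rv.name = "stack"

    stack_vv = stack_rv.clone()
    # `find_measurable_stacks` rewrites the underlying MakeVector node so that
    # a logprob can be derived for the stacked value as a whole.
    stack_logp = joint_logprob({stack_rv: stack_vv}, sum=False)
    # Evaluates to [normal(0, 1).logpdf(0.1), halfnormal().logpdf(0.5)]
    print(stack_logp.eval({stack_vv: [0.1, 0.5]}))

diff --git a/tests/test_tensor.py b/tests/test_tensor.py (continued below)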
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 15e27c3e..546cbe52 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -1,11 +1,12 @@
 import numpy as np
+import pytest
 from aesara import tensor as at
 from aesara.graph import optimize_graph
 from aesara.graph.opt import in2out
 from aesara.tensor.extra_ops import BroadcastTo
 from scipy import stats as st

-from aeppl import factorized_joint_logprob
+from aeppl import factorized_joint_logprob, joint_logprob
 from aeppl.tensor import naive_bcast_rv_lift

@@ -39,3 +40,168 @@ def test_naive_bcast_rv_lift_valued_var():
     assert np.allclose(
         logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
     )
+
+
+def test_measurable_make_vector():
+    base1_rv = at.random.normal(name="base1")
+    base2_rv = at.random.halfnormal(name="base2")
+    base3_rv = at.random.exponential(name="base3")
+    y_rv = at.stack((base1_rv, base2_rv, base3_rv))
+    y_rv.name = "y"
+
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    base3_vv = base3_rv.clone()
+    y_vv = y_rv.clone()
+
+    ref_logp = joint_logprob(
+        {base1_rv: base1_vv, base2_rv: base2_vv, base3_rv: base3_vv}
+    )
+    make_vector_logp = joint_logprob({y_rv: y_vv}, sum=False)
+
+    base1_testval = base1_rv.eval()
+    base2_testval = base2_rv.eval()
+    base3_testval = base3_rv.eval()
+    y_testval = np.stack((base1_testval, base2_testval, base3_testval))
+
+    ref_logp_eval = ref_logp.eval(
+        {base1_vv: base1_testval, base2_vv: base2_testval, base3_vv: base3_testval}
+    )
+    make_vector_logp_eval = make_vector_logp.eval({y_vv: y_testval})
+
+    assert make_vector_logp_eval.shape == y_testval.shape
+    assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval)
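(For orientation, also outside the patch: the join tests below build their
reference values from `factorized_joint_logprob`, which returns a dict mapping
each value variable to its elementwise logp term. A minimal sketch of that
contract, with illustrative names:)

    import aesara.tensor as at
    from aeppl import factorized_joint_logprob

    a_rv = at.random.normal(size=(3,), name="a")
    b_rv = at.random.exponential(size=(2,), name="b")
    a_vv, b_vv = a_rv.clone(), b_rv.clone()

    # One logp term per value variable, each with the shape of its value.
    logp_terms = factorized_joint_logprob({a_rv: a_vv, b_rv: b_vv})
    a_logp = logp_terms[a_vv]  # shape (3,)
    b_logp = logp_terms[b_vv]  # shape (2,)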
+@pytest.mark.parametrize(
+    "size1, size2, axis, concatenate",
+    [
+        ((5,), (3,), 0, True),
+        ((5,), (3,), -1, True),
+        ((5, 2), (3, 2), 0, True),
+        ((2, 5), (2, 3), 1, True),
+        ((2, 5), (2, 5), 0, False),
+        ((2, 5), (2, 5), 1, False),
+        ((2, 5), (2, 5), 2, False),
+    ],
+)
+def test_measurable_join_univariate(size1, size2, axis, concatenate):
+    base1_rv = at.random.normal(size=size1, name="base1")
+    base2_rv = at.random.exponential(size=size2, name="base2")
+    if concatenate:
+        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
+    else:
+        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
+    y_rv.name = "y"
+
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    y_vv = y_rv.clone()
+
+    base_logps = list(
+        factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values()
+    )
+    if concatenate:
+        base_logps = at.concatenate(base_logps, axis=axis)
+    else:
+        base_logps = at.stack(base_logps, axis=axis)
+    y_logp = joint_logprob({y_rv: y_vv}, sum=False)
+
+    base1_testval = base1_rv.eval()
+    base2_testval = base2_rv.eval()
+    if concatenate:
+        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
+    else:
+        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
+    np.testing.assert_allclose(
+        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
+        y_logp.eval({y_vv: y_testval}),
+    )
+
+
+@pytest.mark.parametrize(
+    "size1, supp_size1, size2, supp_size2, axis, concatenate",
+    [
+        (None, 2, None, 2, 0, True),
+        (None, 2, None, 2, -1, True),
+        ((5,), 2, (3,), 2, 0, True),
+        ((5,), 2, (3,), 2, -2, True),
+        ((2,), 5, (2,), 3, 1, True),
+        pytest.param(
+            (2,),
+            5,
+            (2,),
+            5,
+            0,
+            False,
+            marks=pytest.mark.xfail(
+                reason="cannot measure dimshuffled multivariate RVs"
+            ),
+        ),
+        pytest.param(
+            (2,),
+            5,
+            (2,),
+            5,
+            1,
+            False,
+            marks=pytest.mark.xfail(
+                reason="cannot measure dimshuffled multivariate RVs"
+            ),
+        ),
+    ],
+)
+def test_measurable_join_multivariate(
+    size1, supp_size1, size2, supp_size2, axis, concatenate
+):
+    base1_rv = at.random.multivariate_normal(
+        np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
+    )
+    base2_rv = at.random.dirichlet(np.ones(supp_size2), size=size2, name="base2")
+    if concatenate:
+        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
+    else:
+        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
+    y_rv.name = "y"
+
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    y_vv = y_rv.clone()
+    base_logps = [
+        at.atleast_1d(logp)
+        for logp in factorized_joint_logprob(
+            {base1_rv: base1_vv, base2_rv: base2_vv}
+        ).values()
+    ]
+
+    # The reference logps lack the support dimension, so the join axis is
+    # normalized and shifted left by one.
+    if concatenate:
+        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
+        base_logps = at.concatenate(base_logps, axis=axis_norm - 1)
+    else:
+        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
+        base_logps = at.stack(base_logps, axis=axis_norm - 1)
+    y_logp = joint_logprob({y_rv: y_vv}, sum=False)
+
+    base1_testval = base1_rv.eval()
+    base2_testval = base2_rv.eval()
+    if concatenate:
+        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
+    else:
+        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
+    np.testing.assert_allclose(
+        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
+        y_logp.eval({y_vv: y_testval}),
+    )
+
+
+def test_join_mixed_ndim_supp():
+    base1_rv = at.random.normal(size=3, name="base1")
+    base2_rv = at.random.dirichlet(np.ones(3), name="base2")
+    y_rv = at.concatenate((base1_rv, base2_rv), axis=0)
+
+    y_vv = y_rv.clone()
+    with pytest.raises(
+        ValueError, match="Joined logps have different number of dimensions"
+    ):
+        joint_logprob({y_rv: y_vv})
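(A closing note on the trickiest step, outside the patch: for multivariate
components, `logprob_join` concatenates the per-component logps along
`axis - base_vars_ndim_supp`, since each logp has `ndim_supp` fewer dimensions
than its value. A small sketch under that rule, with illustrative sizes:)

    import numpy as np
    import aesara.tensor as at
    from aeppl import joint_logprob

    # Batches of 3-dimensional Dirichlet draws, shapes (5, 3) and (2, 3),
    # joined along batch axis 0.
    base1_rv = at.random.dirichlet(np.ones(3), size=(5,), name="base1")
    base2_rv = at.random.dirichlet(np.ones(3), size=(2,), name="base2")
    y_rv = at.concatenate((base1_rv, base2_rv), axis=0)

    y_vv = y_rv.clone()
    y_logp = joint_logprob({y_rv: y_vv}, sum=False)

    # Each Dirichlet logp drops the support dimension (ndim_supp == 1), so the
    # component logps have shapes (5,) and (2,) and are concatenated along
    # axis 0 - 1 = -1, i.e. the remaining batch axis.
    value = np.full((7, 3), 1 / 3)  # every row is a point on the simplex
    assert y_logp.eval({y_vv: value}).shape == (7,)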