diff --git a/aeppl/__init__.py b/aeppl/__init__.py index bad797b3..a3328dcf 100644 --- a/aeppl/__init__.py +++ b/aeppl/__init__.py @@ -14,6 +14,7 @@ import aeppl.cumsum import aeppl.mixture import aeppl.scan +import aeppl.tensor import aeppl.transforms import aeppl.truncation diff --git a/aeppl/mixture.py b/aeppl/mixture.py index e0e899c9..262dcf28 100644 --- a/aeppl/mixture.py +++ b/aeppl/mixture.py @@ -25,12 +25,8 @@ from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs from aeppl.logprob import _logprob, logprob -from aeppl.opt import ( - local_lift_DiracDelta, - logprob_rewrites_db, - naive_bcast_rv_lift, - subtensor_ops, -) +from aeppl.opt import local_lift_DiracDelta, logprob_rewrites_db, subtensor_ops +from aeppl.tensor import naive_bcast_rv_lift from aeppl.utils import get_constant_value diff --git a/aeppl/opt.py b/aeppl/opt.py index c52bc5f0..9dffd6ad 100644 --- a/aeppl/opt.py +++ b/aeppl/opt.py @@ -1,12 +1,10 @@ from typing import Dict, Optional, Tuple -import aesara import aesara.tensor as at from aesara.compile.mode import optdb from aesara.graph.basic import Variable from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph -from aesara.graph.op import compute_test_value from aesara.graph.opt import local_optimizer from aesara.graph.optdb import EquilibriumDB, OptimizationQuery, SequenceDB from aesara.tensor.basic_opt import ( @@ -16,12 +14,7 @@ ) from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.extra_ops import BroadcastTo -from aesara.tensor.random.op import RandomVariable -from aesara.tensor.random.opt import ( - local_dimshuffle_rv_lift, - local_rv_size_lift, - local_subtensor_rv_lift, -) +from aesara.tensor.random.opt import local_subtensor_rv_lift from aesara.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -226,72 +219,6 @@ def incsubtensor_rv_replace(fgraph, node): return [base_rv_var] -@local_optimizer([BroadcastTo]) -def naive_bcast_rv_lift(fgraph, node): - """Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``. - - XXX: This implementation simply broadcasts the ``RandomVariable``'s - parameters, which won't always work (e.g. multivariate distributions). - - TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to - determine which dimensions of each parameter need to be broadcasted. - Also, this doesn't need to remove ``size`` to perform the lifting, like it - currently does. - """ - - if not ( - isinstance(node.op, BroadcastTo) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, RandomVariable) - ): - return None # pragma: no cover - - bcast_shape = node.inputs[1:] - - rv_var = node.inputs[0] - rv_node = rv_var.owner - - if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars: - return None # pragma: no cover - - # Do not replace RV if it is associated with a value variable - rv_map_feature: Optional[PreserveRVMappings] = getattr( - fgraph, "preserve_rv_mappings", None - ) - if rv_map_feature is not None and rv_var in rv_map_feature.rv_values: - return None - - if not bcast_shape: - # The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. 
doing nothing) - assert rv_var.ndim == 0 - return [rv_var] - - size_lift_res = local_rv_size_lift.transform(fgraph, rv_node) - if size_lift_res is None: - lifted_node = rv_node - else: - _, lifted_rv = size_lift_res - lifted_node = lifted_rv.owner - - rng, size, dtype, *dist_params = lifted_node.inputs - - new_dist_params = [ - at.broadcast_to( - param, - at.broadcast_shape( - tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True - ), - ) - for param in dist_params - ] - bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params) - - if aesara.config.compute_test_value != "off": - compute_test_value(bcasted_node) - - return [bcasted_node.outputs[1]] - - logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" logprob_rewrites_db.register( @@ -311,15 +238,9 @@ def naive_bcast_rv_lift(fgraph, node): # These rewrites push random/measurable variables "down", making them closer to # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. -measurable_ir_rewrites_db.register( - "dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic" -) measurable_ir_rewrites_db.register( "subtensor_lift", local_subtensor_rv_lift, -5, "basic" ) -measurable_ir_rewrites_db.register( - "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic" -) measurable_ir_rewrites_db.register( "incsubtensor_lift", incsubtensor_rv_replace, -5, "basic" ) diff --git a/aeppl/tensor.py b/aeppl/tensor.py new file mode 100644 index 00000000..77b33cae --- /dev/null +++ b/aeppl/tensor.py @@ -0,0 +1,86 @@ +from typing import Optional + +import aesara +from aesara import tensor as at +from aesara.graph.op import compute_test_value +from aesara.graph.opt import local_optimizer +from aesara.tensor.extra_ops import BroadcastTo +from aesara.tensor.random.op import RandomVariable +from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_rv_size_lift + +from aeppl.opt import PreserveRVMappings, measurable_ir_rewrites_db + + +@local_optimizer([BroadcastTo]) +def naive_bcast_rv_lift(fgraph, node): + """Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``. + + XXX: This implementation simply broadcasts the ``RandomVariable``'s + parameters, which won't always work (e.g. multivariate distributions). + + TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to + determine which dimensions of each parameter need to be broadcasted. + Also, this doesn't need to remove ``size`` to perform the lifting, like it + currently does. + """ + + if not ( + isinstance(node.op, BroadcastTo) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, RandomVariable) + ): + return None # pragma: no cover + + bcast_shape = node.inputs[1:] + + rv_var = node.inputs[0] + rv_node = rv_var.owner + + if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars: + return None # pragma: no cover + + # Do not replace RV if it is associated with a value variable + rv_map_feature: Optional[PreserveRVMappings] = getattr( + fgraph, "preserve_rv_mappings", None + ) + if rv_map_feature is not None and rv_var in rv_map_feature.rv_values: + return None + + if not bcast_shape: + # The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. 
doing nothing) + assert rv_var.ndim == 0 + return [rv_var] + + size_lift_res = local_rv_size_lift.transform(fgraph, rv_node) + if size_lift_res is None: + lifted_node = rv_node + else: + _, lifted_rv = size_lift_res + lifted_node = lifted_rv.owner + + rng, size, dtype, *dist_params = lifted_node.inputs + + new_dist_params = [ + at.broadcast_to( + param, + at.broadcast_shape( + tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True + ), + ) + for param in dist_params + ] + bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params) + + if aesara.config.compute_test_value != "off": + compute_test_value(bcasted_node) + + return [bcasted_node.outputs[1]] + + +measurable_ir_rewrites_db.register( + "dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor" +) + +measurable_ir_rewrites_db.register( + "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor" +) diff --git a/tests/test_opt.py b/tests/test_opt.py index 24836717..feb0e448 100644 --- a/tests/test_opt.py +++ b/tests/test_opt.py @@ -1,48 +1,12 @@ import aesara import aesara.tensor as at -import numpy as np -import scipy.stats as st from aesara.graph.opt import in2out from aesara.graph.opt_utils import optimize_graph from aesara.tensor.elemwise import DimShuffle, Elemwise -from aesara.tensor.extra_ops import BroadcastTo from aesara.tensor.subtensor import Subtensor -from aeppl import factorized_joint_logprob from aeppl.dists import DiracDelta, dirac_delta -from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift - - -def test_naive_bcast_rv_lift(): - r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s.""" - X_rv = at.random.normal() - Z_at = BroadcastTo()(X_rv, ()) - - # Make sure we're testing what we intend to test - assert isinstance(Z_at.owner.op, BroadcastTo) - - res = optimize_graph(Z_at, custom_opt=in2out(naive_bcast_rv_lift), clone=False) - assert res is X_rv - - -def test_naive_bcast_rv_lift_valued_var(): - r"""Check that `naive_bcast_rv_lift` won't touch valued variables""" - - x_rv = at.random.normal(name="x") - broadcasted_x_rv = at.broadcast_to(x_rv, (2,)) - - y_rv = at.random.normal(broadcasted_x_rv, name="y") - - x_vv = x_rv.clone() - y_vv = y_rv.clone() - logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv}) - assert x_vv in logp_map - assert y_vv in logp_map - assert len(logp_map) == 2 - assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0)) - assert np.allclose( - logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]) - ) +from aeppl.opt import local_lift_DiracDelta def test_local_lift_DiracDelta(): diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 00000000..15e27c3e --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,41 @@ +import numpy as np +from aesara import tensor as at +from aesara.graph import optimize_graph +from aesara.graph.opt import in2out +from aesara.tensor.extra_ops import BroadcastTo +from scipy import stats as st + +from aeppl import factorized_joint_logprob +from aeppl.tensor import naive_bcast_rv_lift + + +def test_naive_bcast_rv_lift(): + r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s.""" + X_rv = at.random.normal() + Z_at = BroadcastTo()(X_rv, ()) + + # Make sure we're testing what we intend to test + assert isinstance(Z_at.owner.op, BroadcastTo) + + res = optimize_graph(Z_at, custom_opt=in2out(naive_bcast_rv_lift), clone=False) + assert res is X_rv + + +def test_naive_bcast_rv_lift_valued_var(): + r"""Check that 
`naive_bcast_rv_lift` won't touch valued variables""" + + x_rv = at.random.normal(name="x") + broadcasted_x_rv = at.broadcast_to(x_rv, (2,)) + + y_rv = at.random.normal(broadcasted_x_rv, name="y") + + x_vv = x_rv.clone() + y_vv = y_rv.clone() + logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv}) + assert x_vv in logp_map + assert y_vv in logp_map + assert len(logp_map) == 2 + assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0)) + assert np.allclose( + logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]) + )
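
The registrations moved into `aeppl/tensor.py` carry an extra "tensor" tag alongside "basic". The following is a minimal, illustrative sketch only (not part of this patch) of how that tag could be used to pull just the relocated rewrites out of `measurable_ir_rewrites_db` and apply them to a graph. It assumes the standard Aesara rewrite-database query API (`OptimizationQuery`, `optimize_graph`) already imported elsewhere in this diff; the variable names are hypothetical.

import aesara.tensor as at
from aesara.graph import optimize_graph
from aesara.graph.optdb import OptimizationQuery

from aeppl.opt import measurable_ir_rewrites_db

# A `BroadcastTo` sitting directly on top of a `RandomVariable`.
x_rv = at.random.normal(size=(3,))
z_at = at.broadcast_to(x_rv, (2, 3))

# Query only the rewrites registered under the new "tensor" tag
# (i.e. `dimshuffle_lift` and `broadcast_to_lift` from `aeppl/tensor.py`).
tensor_rewrites = measurable_ir_rewrites_db.query(
    OptimizationQuery(include=["tensor"])
)

# The `BroadcastTo` is lifted into the normal's parameters, so the
# rewritten graph's output is again a `RandomVariable`.
res = optimize_graph(z_at, custom_opt=tensor_rewrites, clone=False)

For a single rewrite, applying it directly with `in2out(naive_bcast_rv_lift)`, as the tests above do, should give an equivalent result for this example.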