Move local_dimshuffle_rv_lift and naive_bcast_rv_lift to reshape.py
Ricardo committed May 15, 2022
1 parent 8d21e4a commit 9c76f62
Showing 6 changed files with 132 additions and 123 deletions.
1 change: 1 addition & 0 deletions aeppl/__init__.py
@@ -14,6 +14,7 @@
import aeppl.cumsum
import aeppl.mixture
import aeppl.scan
import aeppl.tensor
import aeppl.transforms
import aeppl.truncation

8 changes: 2 additions & 6 deletions aeppl/mixture.py
@@ -25,12 +25,8 @@

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logprob, logprob
from aeppl.opt import (
local_lift_DiracDelta,
logprob_rewrites_db,
naive_bcast_rv_lift,
subtensor_ops,
)
from aeppl.opt import local_lift_DiracDelta, logprob_rewrites_db, subtensor_ops
from aeppl.tensor import naive_bcast_rv_lift
from aeppl.utils import get_constant_value


81 changes: 1 addition & 80 deletions aeppl/opt.py
@@ -1,12 +1,10 @@
from typing import Dict, Optional, Tuple

import aesara
import aesara.tensor as at
from aesara.compile.mode import optdb
from aesara.graph.basic import Variable
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB, OptimizationQuery, SequenceDB
from aesara.tensor.basic_opt import (
@@ -16,12 +14,7 @@
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
local_subtensor_rv_lift,
)
from aesara.tensor.random.opt import local_subtensor_rv_lift
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -226,72 +219,6 @@ def incsubtensor_rv_replace(fgraph, node):
return [base_rv_var]


@local_optimizer([BroadcastTo])
def naive_bcast_rv_lift(fgraph, node):
"""Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
XXX: This implementation simply broadcasts the ``RandomVariable``'s
parameters, which won't always work (e.g. multivariate distributions).
TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to
determine which dimensions of each parameter need to be broadcasted.
Also, this doesn't need to remove ``size`` to perform the lifting, like it
currently does.
"""

if not (
isinstance(node.op, BroadcastTo)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, RandomVariable)
):
return None # pragma: no cover

bcast_shape = node.inputs[1:]

rv_var = node.inputs[0]
rv_node = rv_var.owner

if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
return None # pragma: no cover

# Do not replace RV if it is associated with a value variable
rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
return None

if not bcast_shape:
# The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
assert rv_var.ndim == 0
return [rv_var]

size_lift_res = local_rv_size_lift.transform(fgraph, rv_node)
if size_lift_res is None:
lifted_node = rv_node
else:
_, lifted_rv = size_lift_res
lifted_node = lifted_rv.owner

rng, size, dtype, *dist_params = lifted_node.inputs

new_dist_params = [
at.broadcast_to(
param,
at.broadcast_shape(
tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True
),
)
for param in dist_params
]
bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

if aesara.config.compute_test_value != "off":
compute_test_value(bcasted_node)

return [bcasted_node.outputs[1]]


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
@@ -311,15 +238,9 @@ def naive_bcast_rv_lift(fgraph, node):
# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
measurable_ir_rewrites_db.register(
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic"
)
measurable_ir_rewrites_db.register(
"subtensor_lift", local_subtensor_rv_lift, -5, "basic"
)
measurable_ir_rewrites_db.register(
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic"
)
measurable_ir_rewrites_db.register(
"incsubtensor_lift", incsubtensor_rv_replace, -5, "basic"
)
86 changes: 86 additions & 0 deletions aeppl/tensor.py
@@ -0,0 +1,86 @@
from typing import Optional

import aesara
from aesara import tensor as at
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_rv_size_lift

from aeppl.opt import PreserveRVMappings, measurable_ir_rewrites_db


@local_optimizer([BroadcastTo])
def naive_bcast_rv_lift(fgraph, node):
"""Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
XXX: This implementation simply broadcasts the ``RandomVariable``'s
parameters, which won't always work (e.g. multivariate distributions).
TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to
determine which dimensions of each parameter need to be broadcasted.
Also, this doesn't need to remove ``size`` to perform the lifting, like it
currently does.
"""

if not (
isinstance(node.op, BroadcastTo)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, RandomVariable)
):
return None # pragma: no cover

bcast_shape = node.inputs[1:]

rv_var = node.inputs[0]
rv_node = rv_var.owner

if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
return None # pragma: no cover

# Do not replace RV if it is associated with a value variable
rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
return None

if not bcast_shape:
# The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
assert rv_var.ndim == 0
return [rv_var]

size_lift_res = local_rv_size_lift.transform(fgraph, rv_node)
if size_lift_res is None:
lifted_node = rv_node
else:
_, lifted_rv = size_lift_res
lifted_node = lifted_rv.owner

rng, size, dtype, *dist_params = lifted_node.inputs

new_dist_params = [
at.broadcast_to(
param,
at.broadcast_shape(
tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True
),
)
for param in dist_params
]
bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

if aesara.config.compute_test_value != "off":
compute_test_value(bcasted_node)

return [bcasted_node.outputs[1]]


measurable_ir_rewrites_db.register(
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
)

measurable_ir_rewrites_db.register(
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
)
38 changes: 1 addition & 37 deletions tests/test_opt.py
@@ -1,48 +1,12 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st
from aesara.graph.opt import in2out
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.subtensor import Subtensor

from aeppl import factorized_joint_logprob
from aeppl.dists import DiracDelta, dirac_delta
from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
X_rv = at.random.normal()
Z_at = BroadcastTo()(X_rv, ())

# Make sure we're testing what we intend to test
assert isinstance(Z_at.owner.op, BroadcastTo)

res = optimize_graph(Z_at, custom_opt=in2out(naive_bcast_rv_lift), clone=False)
assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""

x_rv = at.random.normal(name="x")
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))

y_rv = at.random.normal(broadcasted_x_rv, name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert x_vv in logp_map
assert y_vv in logp_map
assert len(logp_map) == 2
assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
)
from aeppl.opt import local_lift_DiracDelta


def test_local_lift_DiracDelta():
41 changes: 41 additions & 0 deletions tests/test_tensor.py
@@ -0,0 +1,41 @@
import numpy as np
from aesara import tensor as at
from aesara.graph import optimize_graph
from aesara.graph.opt import in2out
from aesara.tensor.extra_ops import BroadcastTo
from scipy import stats as st

from aeppl import factorized_joint_logprob
from aeppl.tensor import naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
X_rv = at.random.normal()
Z_at = BroadcastTo()(X_rv, ())

# Make sure we're testing what we intend to test
assert isinstance(Z_at.owner.op, BroadcastTo)

res = optimize_graph(Z_at, custom_opt=in2out(naive_bcast_rv_lift), clone=False)
assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""

x_rv = at.random.normal(name="x")
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))

y_rv = at.random.normal(broadcasted_x_rv, name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert x_vv in logp_map
assert y_vv in logp_map
assert len(logp_map) == 2
assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
)
