Move local_dimshuffle_rv_lift and naive_bcast_rv_lift to reshape.py
Ricardo committed on May 15, 2022 · 1 parent 8d21e4a · commit 9c76f62
Showing 6 changed files with 132 additions and 123 deletions.
@@ -0,0 +1,86 @@
from typing import Optional

import aesara
from aesara import tensor as at
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_rv_size_lift

from aeppl.opt import PreserveRVMappings, measurable_ir_rewrites_db


@local_optimizer([BroadcastTo])
def naive_bcast_rv_lift(fgraph, node):
    """Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.

    XXX: This implementation simply broadcasts the ``RandomVariable``'s
    parameters, which won't always work (e.g. multivariate distributions).

    TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to
    determine which dimensions of each parameter need to be broadcasted.
    Also, this doesn't need to remove ``size`` to perform the lifting, like it
    currently does.
    """

    if not (
        isinstance(node.op, BroadcastTo)
        and node.inputs[0].owner
        and isinstance(node.inputs[0].owner.op, RandomVariable)
    ):
        return None  # pragma: no cover

    bcast_shape = node.inputs[1:]

    rv_var = node.inputs[0]
    rv_node = rv_var.owner

    if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
        return None  # pragma: no cover

    # Do not replace RV if it is associated with a value variable
    rv_map_feature: Optional[PreserveRVMappings] = getattr(
        fgraph, "preserve_rv_mappings", None
    )
    if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
        return None

    if not bcast_shape:
        # The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
        assert rv_var.ndim == 0
        return [rv_var]

    size_lift_res = local_rv_size_lift.transform(fgraph, rv_node)
    if size_lift_res is None:
        lifted_node = rv_node
    else:
        _, lifted_rv = size_lift_res
        lifted_node = lifted_rv.owner

    rng, size, dtype, *dist_params = lifted_node.inputs

    new_dist_params = [
        at.broadcast_to(
            param,
            at.broadcast_shape(
                tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True
            ),
        )
        for param in dist_params
    ]
    bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

    if aesara.config.compute_test_value != "off":
        compute_test_value(bcasted_node)

    return [bcasted_node.outputs[1]]


measurable_ir_rewrites_db.register(
    "dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
)

measurable_ir_rewrites_db.register(
    "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
)
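For readers unfamiliar with the rewrite, here is a rough before/after sketch of what `naive_bcast_rv_lift` is meant to do. The variable names and constants below are illustrative and not part of the commit, and the "after" graph is only a conceptual equivalent of what the rewrite builds via `make_node`:

import aesara.tensor as at

# "Before": a BroadcastTo applied directly to a RandomVariable's output.
x_rv = at.random.normal(0.0, 1.0)      # scalar normal RV
y_at = at.broadcast_to(x_rv, (3, 2))   # BroadcastTo(normal_rv, (3, 2))

# "After" (conceptually): the rewrite broadcasts the RV's *parameters* to the
# target shape and rebuilds the RandomVariable, so no BroadcastTo node remains
# above it. Roughly equivalent to:
y_lifted = at.random.normal(at.zeros((3, 2)), at.ones((3, 2)))

As the docstring notes, this parameter-broadcasting approach is naive and is not expected to handle multivariate distributions correctly.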
@@ -0,0 +1,41 @@
import numpy as np
from aesara import tensor as at
from aesara.graph import optimize_graph
from aesara.graph.opt import in2out
from aesara.tensor.extra_ops import BroadcastTo
from scipy import stats as st

from aeppl import factorized_joint_logprob
from aeppl.tensor import naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
    r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
    X_rv = at.random.normal()
    Z_at = BroadcastTo()(X_rv, ())

    # Make sure we're testing what we intend to test
    assert isinstance(Z_at.owner.op, BroadcastTo)

    res = optimize_graph(Z_at, custom_opt=in2out(naive_bcast_rv_lift), clone=False)
    assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
    r"""Check that `naive_bcast_rv_lift` won't touch valued variables."""
    x_rv = at.random.normal(name="x")
    broadcasted_x_rv = at.broadcast_to(x_rv, (2,))

    y_rv = at.random.normal(broadcasted_x_rv, name="y")

    x_vv = x_rv.clone()
    y_vv = y_rv.clone()
    logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
    assert x_vv in logp_map
    assert y_vv in logp_map
    assert len(logp_map) == 2
    assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
    assert np.allclose(
        logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
    )
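The other rewrite registered by this commit, `local_dimshuffle_rv_lift`, is imported from `aesara.tensor.random.opt` rather than defined here. As a rough sketch of its intent, reusing the `optimize_graph`/`in2out` pattern from the test above (whether the rewrite actually fires depends on its own applicability checks; the example variables are illustrative only):

from aesara import tensor as at
from aesara.graph import optimize_graph
from aesara.graph.opt import in2out
from aesara.tensor.random.opt import local_dimshuffle_rv_lift

x_rv = at.random.normal(at.zeros((2, 3)), 1.0)
y_at = x_rv.T  # a DimShuffle (transpose) applied to the RV's output

# Applying only the dimshuffle lift: when its conditions are met, the transpose
# is pushed into the RandomVariable's parameters/size, leaving a RandomVariable
# output (with no DimShuffle on top) in the rewritten graph.
res = optimize_graph(y_at, custom_opt=in2out(local_dimshuffle_rv_lift), clone=False)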