Skip to content

Commit

Permalink
Add Precondition interpretation for Gaussian TVE (#553)
Browse files Browse the repository at this point in the history
* Add Precondition interpretation for Gaussian TVE

* Fix bugs; add a test

* Add an ops.getslice for more complex eager indexing

* Fix bugs, add patterns, add more tests

* fix is_affine()

* Add more tests

* Fix eager_getslice_lambda

* Switch to sqrt(precision) representation in Gaussian

* Fix some bugs

* Fix more math

* Add GaussianMeta conversions; fix broadcasting bug

* Fix some distribution tests

* Refactor from info_vec to white_vec

* Fix more tests

* Flesh our matrix_and_mvn_to_funsor()

* Work our marginalization

* fix more tests

* Fix more tests

* Fix test_gaussian.py

* Fix distribution patterns

* Fix argmax approximation

* Remove Gaussian.negate attribute

* Fix matrix_and_mvn_to_funsor diag (full still broken)

* Fix old uses of info_vec

* Add a test

* Fix shape bug in matrix_and_mvn_to_funsor()

* Enable pprint for funsors

* Revert pp property

* Fix matrix_and_mvn_to_funsor()

* Relax rank condition

* Fix ._sample()

* Fix eager_contraction_to_binary

* Fix test_joint.py

* Fix comparisons in sequential sum product

* Fix saarka bilmes test

* Add and xfail tests of singular matrices

* Fix rank deficiency issues

* Add gaussian integrate patterns

* Fix comment

* Add a set_compression_threshold context manager

* Update docstring

* Fix backward sampling support bug

* Xfail test_elbo.py::test_complex

* Relax test thresholds

* Fix ops.qr numpy backend

* Fix jax tests

* Fix bugs

* Tweak sensor example

* Fix bugs

* Add more precondition approximate patterns

* Address review comments

* Add Sub[Gaussian, tuple] pattern

* Sketch implementation of partial sampling from Gaussians

* Fix bug

* Fix a bug in partial sampling

* Get partial sampling working

* Reorder Gaussians in cnf

* Fix batch shape computation

* Add pattern to fuse nested Subs

* Relax tolerance

* Fix eager_finitary_cat

* Increase sample count

* Fix jax backend for ops.randn

* Revert Gaussian - Gaussian pattern

* Relax tolerance

* Remove obsolete test
  • Loading branch information
fritzo authored Nov 12, 2021
1 parent c1a7258 commit 93250f9
Show file tree
Hide file tree
Showing 16 changed files with 631 additions and 134 deletions.
7 changes: 7 additions & 0 deletions docs/source/interpretations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ Monte Carlo
:show-inheritance:
:member-order: bysource

Preconditioning
---------------
.. automodule:: funsor.precondition
:members:
:show-inheritance:
:member-order: bysource

Approximations
--------------
.. automodule:: funsor.approximations
Expand Down
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
joint,
montecarlo,
ops,
precondition,
recipes,
sum_product,
terms,
Expand Down Expand Up @@ -102,6 +103,7 @@
"montecarlo",
"of_shape",
"ops",
"precondition",
"pretty",
"quote",
"reals",
Expand Down
7 changes: 2 additions & 5 deletions funsor/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def __enter__(self):
self._old_interpretation = interpreter.get_interpretation()
return super().__enter__()

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
# TODO Replace this with root + Constant(...) after #548 merges.
root_vars = root.input_vars | batch_vars

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()):
zero = to_funsor(ops.UNITS[sum_op])
one = to_funsor(ops.UNITS[bin_op])
adjoint_values = defaultdict(lambda: zero)
Expand Down Expand Up @@ -118,7 +115,7 @@ def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset())
in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
for v, adjv in in_adjs:
# Marginalize out message variables that don't appear in recipients.
agg_vars = adjv.input_vars - v.input_vars - root_vars
agg_vars = adjv.input_vars - v.input_vars - root.input_vars - batch_vars
assert "particle" not in {var.name for var in agg_vars} # DEBUG FIXME
old_value = adjoint_values[v]
adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))
Expand Down
21 changes: 15 additions & 6 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
for term in self.terms:
if isinstance(term, Delta):
sampled_vars -= term.fresh
if not sampled_vars:
return self

if self.red_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
Expand All @@ -116,8 +121,8 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
rng_keys = [None] * len(self.terms)

if self.bin_op in (ops.null, ops.logaddexp):
# Design choice: we sample over logaddexp reductions, but leave logaddexp
# binary choices symbolic.
# Design choice: we sample over logaddexp reductions, but leave
# logaddexp binary choices symbolic.
terms = [
term._sample(
sampled_vars.intersection(term.inputs), sample_inputs, rng_key
Expand All @@ -132,11 +137,15 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
greedy_vars = sampled_vars.intersection(term.inputs)
if greedy_vars:
break
assert greedy_vars
greedy_terms, terms = [], []
for term in self.terms:
(
terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms
).append(term)
if greedy_vars.isdisjoint(term.inputs):
terms.append(term)
elif isinstance(term, Delta) and greedy_vars.isdisjoint(term.fresh):
terms.append(term)
else:
greedy_terms.append(term)
if len(greedy_terms) == 1:
term = greedy_terms[0]
terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
Expand Down Expand Up @@ -392,7 +401,7 @@ def _(fn):
# Normalizing Contractions
##########################################

ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4}
ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4, Unary[ops.NegOp, Gaussian]: 5}
GROUND_TERMS = tuple(ORDERING)


Expand Down
16 changes: 16 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,22 @@ def _find_domain_stack(op, parts):
return output


@find_domain.register(ops.CatOp)
def _find_domain_cat(op, parts):
dim = op.defaults["axis"]
if dim >= 0:
event_dims = {len(x.shape) for x in parts}
assert len(event_dims) == 1, "undefined"
dim = dim - next(iter(event_dims))
assert dim < 0
shape = broadcast_shape(*(x.shape[:dim] for x in parts))
shape += (sum(x.shape[dim] for x in parts),)
if dim < -1:
shape += broadcast_shape(*(x.shape[dim + 1 :] for x in parts))
output = Array[parts[0].dtype, shape]
return output


@find_domain.register(ops.EinsumOp)
def _find_domain_einsum(op, operands):
equation = op.defaults["equation"]
Expand Down
Loading

0 comments on commit 93250f9

Please sign in to comment.