Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Learn white_vec parameter of AutoGaussian guide #2946

Merged
merged 38 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
17d05fb
Learn white_vec in AutoGaussian
fritzo Oct 14, 2021
c87a383
Fix bugs
fritzo Oct 14, 2021
73e6037
Link to issue
fritzo Oct 14, 2021
2f89668
Attempt to fix AutoGaussian dispatch
fritzo Oct 14, 2021
d0ee4f4
Merge branch 'fix-auto-gaussian' into learn-white-vec
fritzo Oct 14, 2021
5dd3dd6
Fix some tests
fritzo Oct 20, 2021
c91a71b
Speed up test_median
fritzo Oct 20, 2021
e2bb3c3
Support more temperatures
fritzo Oct 20, 2021
152332a
Merge branch 'dev' into learn-white-vec
fritzo Oct 20, 2021
88d303f
Merge branch 'dev' into fix-auto-gaussian
fritzo Oct 20, 2021
532f2f1
Add xfailing tests
fritzo Oct 20, 2021
0c4c4ac
Fix bug excluding obs sites from prototype_trace
fritzo Oct 21, 2021
c271a92
Fix more bugs
fritzo Oct 22, 2021
4693933
Fix more tests
fritzo Oct 22, 2021
e606e61
Add failing test of elbo gradient
fritzo Oct 22, 2021
f377149
lint
fritzo Oct 22, 2021
f378036
Make test less trivial
fritzo Oct 26, 2021
34ce03c
Strengthen tests, make AutoGaussian abstract
fritzo Oct 27, 2021
dab6d95
Add has_rsample kwarg to pyro.factor
fritzo Oct 27, 2021
722581f
Fix tests
fritzo Oct 27, 2021
7110f8a
Add has_rsample kwarg to pyro.factor
fritzo Oct 27, 2021
60925d7
Require specification of has_rsample for pyro.factor in guides
fritzo Oct 27, 2021
70b316e
Merge branch 'factor-in-guide' into fix-auto-gaussian
fritzo Oct 27, 2021
c817cac
Remove debug statement
fritzo Oct 27, 2021
882d36e
Update AutoGaussian
fritzo Oct 27, 2021
de2ed46
Merge branch 'factor-in-guide' into fix-auto-gaussian
fritzo Oct 27, 2021
4f3361d
Fix scanvi example
fritzo Oct 27, 2021
81ae3d4
Merge branch 'fix-auto-gaussian' into learn-white-vec
fritzo Oct 27, 2021
8f5b095
Fix tests
fritzo Oct 27, 2021
943d71e
Fix profiling test
fritzo Oct 27, 2021
9b5e35e
Merge branch 'dev' into fix-auto-gaussian
fritzo Oct 27, 2021
6dc1e2c
Merge branch 'fix-auto-gaussian' into learn-white-vec
fritzo Oct 27, 2021
f488679
Merge branch 'dev' into learn-white-vec
fritzo Oct 27, 2021
9dbd458
Merge branch 'dev' into learn-white-vec
fritzo Nov 12, 2021
5c99961
Remove experimental code
fritzo Nov 12, 2021
6368cf6
Bump funsor version
fritzo Nov 18, 2021
c0e9c93
Merge branch 'dev' into learn-white-vec
fritzo Dec 13, 2021
d21be94
Pin to Funsor 0.4.2
fritzo Dec 13, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 108 additions & 50 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def _setup_prototype(self, *args, **kwargs) -> None:

self.locs = PyroModule()
self.scales = PyroModule()
self.factors = PyroModule()
self.white_vecs = PyroModule()
self.prec_sqrts = PyroModule()
self._factors = OrderedDict()
self._plates = OrderedDict()
self._event_numel = OrderedDict()
Expand Down Expand Up @@ -211,18 +212,20 @@ def _setup_prototype(self, *args, **kwargs) -> None:
d_size = min(d_size, u_size) # just an optimization
batch_shape = _plates_to_shape(self._plates[d])

# Create a square root parameter (full, not lower triangular).
# Create parameters of each Gaussian factor.
white_vec = init_loc.new_zeros(batch_shape + (d_size,))
# We initialize with noise to avoid singular gradient.
sqrt = torch.rand(
prec_sqrt = torch.rand(
batch_shape + (u_size, d_size),
dtype=init_loc.dtype,
device=init_loc.device,
)
sqrt.sub_(0.5).mul_(self._init_scale)
prec_sqrt.sub_(0.5).mul_(self._init_scale)
if not site["is_observed"]:
# Initialize the [d,d] block to the identity matrix.
sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
deep_setattr(self.factors, d, PyroParam(sqrt, event_dim=2))
prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))

@staticmethod
def _compress_site(site):
Expand All @@ -243,7 +246,7 @@ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
if self.prototype_trace is None:
self._setup_prototype(*args, **kwargs)

aux_values = self._sample_aux_values()
aux_values = self._sample_aux_values(temperature=1.0)
values, log_densities = self._transform_values(aux_values)

# Replay via Pyro primitives.
Expand All @@ -268,7 +271,7 @@ def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
:rtype: dict
"""
with torch.no_grad(), poutine.mask(mask=False):
aux_values = {name: 0.0 for name in self._factors}
aux_values = self._sample_aux_values(temperature=0.0)
values, _ = self._transform_values(aux_values)
return values

Expand Down Expand Up @@ -299,7 +302,7 @@ def _transform_values(
return values, log_densities

@abstractmethod
def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
raise NotImplementedError


Expand Down Expand Up @@ -331,11 +334,13 @@ def _setup_prototype(self, *args, **kwargs):
# Create sparse -> dense precision scatter indices.
self._dense_scatter = {}
for d, site in self._factors.items():
sqrt_shape = deep_getattr(self.factors, d).shape
precision_shape = sqrt_shape[:-1] + sqrt_shape[-2:-1]
index = torch.zeros(precision_shape, dtype=torch.long)
prec_sqrt_shape = deep_getattr(self.prec_sqrts, d).shape
info_vec_shape = prec_sqrt_shape[:-1]
precision_shape = prec_sqrt_shape[:-1] + prec_sqrt_shape[-2:-1]
index1 = torch.zeros(info_vec_shape, dtype=torch.long)
index2 = torch.zeros(precision_shape, dtype=torch.long)

# Collect local offsets.
# Collect local offsets and create index1 for info_vec blockwise.
upstreams = [
u for u in self.dependencies[d] if not self._factors[u]["is_observed"]
]
Expand All @@ -345,8 +350,17 @@ def _setup_prototype(self, *args, **kwargs):
local_offsets[u] = pos
broken_plates = self._plates[u] - self._plates[d]
pos += self._event_numel[u] * _plates_to_shape(broken_plates).numel()
u_index = global_indices[u]

# Permute broken plates to the right of preserved plates.
u_index = _break_plates(u_index, self._plates[u], self._plates[d])

# Create indices blockwise.
# Scatter global indices into the [u] block.
u_start = local_offsets[u]
u_stop = u_start + u_index.size(-1)
index1[..., u_start:u_stop] = u_index

# Create index2 for precision blockwise.
for u, v in itertools.product(upstreams, upstreams):
u_index = global_indices[u]
v_index = global_indices[v]
Expand All @@ -360,18 +374,24 @@ def _setup_prototype(self, *args, **kwargs):
u_stop = u_start + u_index.size(-1)
v_start = local_offsets[v]
v_stop = v_start + v_index.size(-1)
index[
index2[
..., u_start:u_stop, v_start:v_stop
] = self._dense_size * u_index.unsqueeze(-1) + v_index.unsqueeze(-2)

self._dense_scatter[d] = index.reshape(-1)

def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
flat_samples = pyro.sample(
f"_{self._pyro_name}_latent",
self._dense_get_mvn(),
infer={"is_auxiliary": True},
)
self._dense_scatter[d] = index1.reshape(-1), index2.reshape(-1)

def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
    """Draw (or deterministically pick) the auxiliary values of the dense guide.

    :param float temperature: ``1`` draws a sample from the dense joint
        Gaussian built by ``self._dense_get_mvn()``; ``0`` skips sampling and
        uses that Gaussian's mean instead.
    :returns: the flat vector passed through ``self._dense_unflatten()``.
    :rtype: dict
    :raises NotImplementedError: for any other ``temperature``.
    """
    joint = self._dense_get_mvn()
    if temperature == 1:
        # Sample from a dense joint Gaussian over flattened variables.
        flat = pyro.sample(
            f"_{self._pyro_name}_latent",
            joint,
            infer={"is_auxiliary": True},
        )
    elif temperature == 0:
        # No noise requested: the mean of a Gaussian is also its mode.
        flat = joint.mean
    else:
        raise NotImplementedError(f"Invalid temperature: {temperature}")
    return self._dense_unflatten(flat)

Expand Down Expand Up @@ -401,14 +421,22 @@ def _dense_flatten(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:

def _dense_get_mvn(self):
# Create a dense joint Gaussian over flattened variables.
flat_info_vec = torch.zeros(self._dense_size)
flat_precision = torch.zeros(self._dense_size ** 2)
for d, index in self._dense_scatter.items():
sqrt = deep_getattr(self.factors, d)
precision = sqrt @ sqrt.transpose(-1, -2)
flat_precision.scatter_add_(0, index, precision.reshape(-1))
for d, (index1, index2) in self._dense_scatter.items():
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
info_vec = (prec_sqrt @ white_vec[..., None])[..., 0]
precision = prec_sqrt @ prec_sqrt.transpose(-1, -2)
flat_info_vec.scatter_add_(0, index1, info_vec.reshape(-1))
flat_precision.scatter_add_(0, index2, precision.reshape(-1))
info_vec = flat_info_vec
precision = flat_precision.reshape(self._dense_size, self._dense_size)
loc = precision.new_zeros(self._dense_size)
return dist.MultivariateNormal(loc, precision_matrix=precision)
scale_tril = _precision_to_scale_tril(precision)
loc = (
scale_tril @ (scale_tril.transpose(-1, -2) @ info_vec.unsqueeze(-1))
).squeeze(-1)
return dist.MultivariateNormal(loc, scale_tril=scale_tril)


class AutoGaussianFunsor(AutoGaussian):
Expand Down Expand Up @@ -464,7 +492,7 @@ def _setup_prototype(self, *args, **kwargs):
self._funsor_plate_to_dim = plate_to_dim
self._funsor_plates = frozenset(plate_to_dim)

def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
funsor = _import_funsor()

# Convert torch to funsor.
Expand All @@ -473,38 +501,43 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
plate_to_dim.update({f.name: f.dim for f in particle_plates})
factors = {}
for d, inputs in self._funsor_factor_inputs.items():
prec_sqrt = deep_getattr(self.factors, d)
batch_shape = torch.Size(
p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
)
prec_sqrt = prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:])
# TODO Make white_vec learnable once .median() can be computed via
# funsor.recipes.forward_filter_backward_precondition()
# https://github.com/pyro-ppl/funsor/pull/553
white_vec = prec_sqrt.new_zeros(()).expand(
prec_sqrt.shape[:-2] + prec_sqrt.shape[-1:]
)
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
factors[d] = funsor.gaussian.Gaussian(
white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
inputs=inputs,
)

# Perform Gaussian tensor variable elimination.
try: # Convert ValueError into NotImplementedError.
samples, log_prob = funsor.recipes.forward_filter_backward_rsample(
if temperature == 1:
samples, log_prob = _try_possibly_intractable(
funsor.recipes.forward_filter_backward_rsample,
factors=factors,
eliminate=self._funsor_eliminate,
plates=frozenset(plate_to_dim),
sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
)
except ValueError as e:
if str(e) != "intractable!":
raise e from None
raise NotImplementedError(
"Funsor backend found intractable plate nesting. "
'Consider using AutoGaussian(..., backend="dense"), '
"splitting into multiple guides via AutoGuideList, or "
"replacing some plates in the model by .to_event()."
) from e

else:
samples, log_prob = _try_possibly_intractable(
funsor.recipes.forward_filter_backward_precondition,
factors=factors,
eliminate=self._funsor_eliminate,
plates=frozenset(plate_to_dim),
)

# Substitute noise.
sample_shape = torch.Size(f.size for f in particle_plates)
noise = torch.randn(sample_shape + log_prob.inputs["aux"].shape)
noise.mul_(temperature)
aux = funsor.Tensor(noise)[tuple(f.name for f in particle_plates)]
with funsor.interpretations.memoize():
samples = {k: v(aux=aux) for k, v in samples.items()}
log_prob = log_prob(aux=aux)

# Convert funsor to torch.
if am_i_wrapped() and poutine.get_mask() is not False:
Expand All @@ -516,6 +549,31 @@ def _sample_aux_values(self) -> Dict[str, torch.Tensor]:
return samples


def _precision_to_scale_tril(P):
    """Convert a precision matrix into the ``scale_tril`` of its inverse.

    Given a symmetric positive definite ``P`` (optionally batched over
    leading dimensions), returns the lower-triangular ``L`` such that
    ``L @ L.T`` equals ``P^{-1}``, without explicitly inverting ``P``.

    Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril

    :param torch.Tensor P: precision matrix whose last two dims are square.
    :returns: lower Cholesky factor of the covariance ``P^{-1}``.
    :rtype: torch.Tensor
    """
    # Cholesky of the row/column-reversed matrix, reversed back, gives an
    # upper-triangular U with P = U @ U.T; its transpose is lower triangular.
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
    identity = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    # Solve L_inv @ L = I for L, i.e. invert the triangular factor.
    solve_triangular = getattr(torch.linalg, "solve_triangular", None)
    if solve_triangular is not None:
        # Preferred API; torch.triangular_solve is deprecated in newer PyTorch.
        return solve_triangular(L_inv, identity, upper=False)
    # Fallback for PyTorch versions that predate torch.linalg.solve_triangular.
    return torch.triangular_solve(identity, L_inv, upper=False)[0]


def _try_possibly_intractable(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``, translating the "intractable!" error.

    A ``ValueError`` whose message is exactly ``"intractable!"`` is converted
    into a ``NotImplementedError`` carrying user-facing advice (chained to the
    original error). Any other ``ValueError`` is re-raised unchanged.

    :param callable fn: the function to invoke.
    :returns: whatever ``fn`` returns.
    :raises NotImplementedError: if ``fn`` raised ``ValueError("intractable!")``.
    :raises ValueError: any other ``ValueError`` raised by ``fn``.
    """
    try:
        result = fn(*args, **kwargs)
    except ValueError as err:
        if str(err) == "intractable!":
            raise NotImplementedError(
                "Funsor backend found intractable plate nesting. "
                'Consider using AutoGaussian(..., backend="dense"), '
                "splitting into multiple guides via AutoGuideList, or "
                "replacing some plates in the model by .to_event()."
            ) from err
        # Unrelated ValueError: propagate it as-is.
        raise err from None
    else:
        return result


def _plates_to_shape(plates):
shape = [1] * max([0] + [-f.dim for f in plates])
for f in plates:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@
"horovod": ["horovod[pytorch]>=0.19"],
"funsor": [
# This must be a released version when Pyro is released.
"funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
# "funsor[torch]==0.4.1",
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
"funsor[torch]==0.4.2",
],
},
python_requires=">=3.6",
Expand Down
26 changes: 15 additions & 11 deletions tests/infer/autoguide/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_break_plates,
)
from pyro.infer.reparam import LocScaleReparam
from pyro.optim import Adam
from pyro.optim import ClippedAdam
from tests.common import assert_close, assert_equal, xfail_if_not_implemented

BACKENDS = [
Expand Down Expand Up @@ -131,27 +131,27 @@ def check_backends_agree(model):
params2 = dict(guide2.named_parameters())
assert set(params1) == set(params2)
for k, v in params1.items():
v.data.normal_()
v.data.add_(torch.zeros_like(v).normal_())
params2[k].data.copy_(v.data)
names = sorted(params1)

# Check densities agree between backends.
with torch.no_grad(), poutine.trace() as tr:
aux = guide2._sample_aux_values()
aux = guide2._sample_aux_values(temperature=1.0)
flat = guide1._dense_flatten(aux)
tr.trace.compute_log_prob()
log_prob_funsor = tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"]
with torch.no_grad(), poutine.trace() as tr:
with poutine.condition(data={"_AutoGaussianDense_latent": flat}):
guide1._sample_aux_values()
guide1._sample_aux_values(temperature=1.0)
tr.trace.compute_log_prob()
log_prob_dense = tr.trace.nodes["_AutoGaussianDense_latent"]["log_prob"]
assert_equal(log_prob_funsor, log_prob_dense)

# Check Monte Carlo estimate of entropy.
entropy1 = guide1._dense_get_mvn().entropy()
with pyro.plate("particle", 100000, dim=-3), poutine.trace() as tr:
guide2._sample_aux_values()
guide2._sample_aux_values(temperature=1.0)
tr.trace.compute_log_prob()
entropy2 = -tr.trace.nodes["_AutoGaussianFunsor_latent"]["log_prob"].mean()
assert_close(entropy1, entropy2, atol=1e-2)
Expand All @@ -163,10 +163,14 @@ def check_backends_agree(model):
)
for name, grad1, grad2 in zip(names, grads1, grads2):
# Gradients should agree to very high precision.
if grad1 is None and grad2 is not None:
grad1 = torch.zeros_like(grad2)
elif grad2 is None and grad1 is not None:
grad2 = torch.zeros_like(grad1)
assert_close(grad1, grad2, msg=f"{name}:\n{grad1} vs {grad2}")

# Check elbos agree between backends.
elbo = Trace_ELBO(num_particles=100000, vectorize_particles=True)
elbo = Trace_ELBO(num_particles=1000000, vectorize_particles=True)
loss1 = elbo.differentiable_loss(model, guide1)
loss2 = elbo.differentiable_loss(model, guide2)
assert_close(loss1, loss2, atol=1e-2, rtol=0.05)
Expand Down Expand Up @@ -422,7 +426,7 @@ def model():
pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step()
Expand All @@ -445,7 +449,7 @@ def model():
pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2))

guide = AutoGaussian(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step()
Expand Down Expand Up @@ -674,7 +678,7 @@ def test_pyrocov_smoke(model, Guide, backend):
}

guide = Guide(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step(dataset)
Expand Down Expand Up @@ -703,7 +707,7 @@ def test_pyrocov_reparam(model, Guide, backend):
}
model = poutine.reparam(model, config)
guide = Guide(model, backend=backend)
svi = SVI(model, guide, Adam({"lr": 1e-8}), Trace_ELBO())
svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
for step in range(2):
with xfail_if_not_implemented():
svi.step(dataset)
Expand Down Expand Up @@ -825,7 +829,7 @@ def test_profile(backend, jit, n=1, num_steps=1, log_every=1):
print("Training")
Elbo = JitTrace_ELBO if jit else Trace_ELBO
elbo = Elbo(max_plate_nesting=3, ignore_jit_warnings=True)
svi = SVI(model, guide, Adam({"lr": 1e-8}), elbo)
svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), elbo)
for step in range(num_steps):
loss = svi.step(dataset)
if log_every and step % log_every == 0:
Expand Down
2 changes: 1 addition & 1 deletion tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,7 +1571,7 @@ def model(data):
guide.requires_grad_(False)
with torch.no_grad():
# Check moments.
vectorize = pyro.plate("particles", 10000, dim=-2)
vectorize = pyro.plate("particles", 50000, dim=-2)
guide_trace = poutine.trace(vectorize(guide)).get_trace(data)
samples = poutine.replay(vectorize(model), guide_trace)(data)
for name in ["x", "y"]:
Expand Down