Fix pyro-ppl#3255 (draft)
gui11aume committed Aug 30, 2023
1 parent cc8e545 commit f88a16e
Showing 5 changed files with 99 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ def scale_and_mask(self, scale=1.0, mask=None):
:type mask: torch.BoolTensor or None
"""
log_prob = scale_and_mask(self.log_prob, scale, mask)
- score_function = self.score_function # not scaled
+ score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled
entropy_term = scale_and_mask(self.entropy_term, scale, mask)
return ScoreParts(log_prob, score_function, entropy_term)
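
Note: scale_and_mask comes from pyro.distributions.util; it multiplies a tensor by scale and zeroes out the entries where mask is False. The change above now applies the mask (with scale=1.0, i.e. no scaling) to the score function as well, so masked-out sites no longer contribute. A minimal sketch of the effect, with made-up tensor values:

import torch
from pyro.distributions.util import scale_and_mask

# Hypothetical per-element score function for a plate of size 2.
score_function = torch.tensor([0.3, -1.2])
mask = torch.tensor([True, False])  # the second element is masked out

# scale=1.0 means "mask only, do not scale"; entries where mask is False become 0.
masked = scale_and_mask(score_function, 1.0, mask)
print(masked)  # tensor([0.3000, 0.0000])
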
2 changes: 1 addition & 1 deletion pyro/infer/trace_elbo.py
@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace):
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
log_r_term = model_site["log_prob"]
- if not model_site["is_observed"]:
+ if not model_site["is_observed"] and name in guide_trace.nodes:
log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
log_r.add((stacks[name], log_r_term.detach()))
return log_r
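
Note: the extra name in guide_trace.nodes check guards against model sites that have no counterpart in the guide. This is what obs_mask produces: the model auto-creates an "x_unobserved" sample site, and if the guide does not define it, the lookup guide_trace.nodes[name] fails with a KeyError. A rough sketch of the mismatch, assuming the same model shape as the new test further down:

import torch
import pyro
import pyro.distributions as dist

data = torch.tensor([-0.5, 2.0])
mask = torch.tensor([True, False])

def model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # obs_mask splits "x" into an observed part and an auto-generated
        # "x_unobserved" site for the masked-out entries.
        pyro.sample("x", dist.Normal(z, 1.0), obs=data, obs_mask=mask)

def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))

model_sites = set(pyro.poutine.trace(model).get_trace().nodes)
guide_sites = set(pyro.poutine.trace(guide).get_trace().nodes)
print(model_sites - guide_sites)  # includes "x_unobserved", which the guide never defines
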
2 changes: 1 addition & 1 deletion pyro/infer/trace_mean_field_elbo.py
@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace):
if model_site["type"] == "sample":
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_site["log_prob_sum"]
- else:
+ elif name in guide_trace.nodes:
guide_site = guide_trace.nodes[name]
if is_validation_enabled():
check_fully_reparametrized(guide_site)
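
Note: same pattern here. TraceMeanField_ELBO adds log_prob_sum for observed sites and, for latent sites, subtracts an analytic KL term computed from the guide and model distributions; with the elif guard, model sites the guide does not define (such as x_unobserved) are now skipped instead of raising. A simplified sketch of the per-site accumulation, not the actual Pyro implementation (it assumes log_prob_sum has already been computed on the traces and that plate subsample sites were pruned, and it ignores scaling and shape details):

from torch.distributions import kl_divergence

def per_site_elbo(model_trace, guide_trace):
    # Simplified mean-field accumulation: observed sites contribute their
    # log-likelihood, latent sites contribute -KL(guide || model prior),
    # and model sites with no guide counterpart are skipped.
    elbo = 0.0
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] != "sample":
            continue
        if model_site["is_observed"]:
            elbo = elbo + model_site["log_prob_sum"]
        elif name in guide_trace.nodes:
            guide_site = guide_trace.nodes[name]
            elbo = elbo - kl_divergence(guide_site["fn"], model_site["fn"]).sum()
    return elbo
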
2 changes: 2 additions & 0 deletions pyro/infer/tracegraph_elbo.py
@@ -218,6 +218,8 @@ def _compute_elbo(model_trace, guide_trace):
# we include only downstream costs to reduce variance
# optionally include baselines to further reduce variance
for node, downstream_cost in downstream_costs.items():
+ if node not in guide_trace.nodes:
+     continue
guide_site = guide_trace.nodes[node]
downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
score_function = guide_site["score_parts"].score_function
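
Note: this loop builds the score-function (REINFORCE) part of the surrogate ELBO. Each non-reparameterized guide site contributes its score function times the detached downstream cost, optionally minus a baseline, which leaves the expectation unchanged but reduces gradient variance; the added continue simply skips nodes, such as x_unobserved, that exist in the model trace but not in the guide. A hedged sketch of the per-node surrogate term (names are illustrative, not Pyro's internals):

import torch

def surrogate_term(score_function: torch.Tensor,
                   downstream_cost: torch.Tensor,
                   baseline: torch.Tensor) -> torch.Tensor:
    # score_function is the differentiable log-probability log q(z | theta) of
    # the site; the bracketed factor is detached, so differentiating this term
    # gives the REINFORCE estimate grad(log q) * (cost - baseline).
    # Subtracting a baseline does not change the expectation but can
    # substantially reduce the variance of the gradient estimate.
    return score_function * (downstream_cost - baseline).detach()
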
95 changes: 94 additions & 1 deletion tests/infer/test_gradient.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
+ from collections import defaultdict

import numpy as np
import pytest
@@ -30,7 +31,6 @@

logger = logging.getLogger(__name__)


def DiffTrace_ELBO(*args, **kwargs):
return Trace_ELBO(*args, **kwargs).differentiable_loss

@@ -214,6 +214,99 @@ def guide(subsample):
assert_equal(actual_grads, expected_grads, prec=precision)


# Not including the unobserved site in the guide triggers a warning
# that can make the test fail unless UserWarning is filtered out.
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"with_x_unobserved",
[True, False],
)
@pytest.mark.parametrize(
"mask",
[[True, True], [True, False], [False, True]],
)
@pytest.mark.parametrize(
"reparameterized,has_rsample",
[(True, None), (True, False), (True, True), (False, None)],
ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
)
@pytest.mark.parametrize(
"Elbo,local_samples",
[
(Trace_ELBO, False),
(DiffTrace_ELBO, False),
(TraceGraph_ELBO, False),
(TraceMeanField_ELBO, False),
(TraceEnum_ELBO, False),
(TraceEnum_ELBO, True),
],
)
def test_mask_gradient(
Elbo, reparameterized, has_rsample, local_samples, mask, with_x_unobserved,
):
pyro.clear_param_store()
data = torch.tensor([-0.5, 2.0])
precision = 0.08
Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

def model(data, mask):
z = pyro.sample("z", Normal(0, 1))
with pyro.plate("data", len(data)):
pyro.sample("x", Normal(z, 1), obs=data, obs_mask=mask)

def guide(data, mask):
scale = pyro.param("scale", lambda: torch.tensor([1.0]))
loc = pyro.param("loc", lambda: torch.tensor([1.0]))
z_dist = Normal(loc, scale)
if has_rsample is not None:
z_dist.has_rsample_(has_rsample)
z = pyro.sample("z", z_dist)
if with_x_unobserved:
with pyro.plate("data", len(data)):
with pyro.poutine.mask(mask=~mask):
pyro.sample("x_unobserved", Normal(z, 1))

num_particles = 50000
accumulation = 1
if local_samples:
# One has to limit the number of samples in this
# test because the memory footprint is large.
guide = config_enumerate(guide, num_samples=5000)
accumulation = num_particles // 5000
num_particles = 1

optim = Adam({"lr": 0.1})
elbo = Elbo(
max_plate_nesting=1, # set this to ensure rng agrees across runs
num_particles=num_particles,
vectorize_particles=True,
strict_enumeration_warning=False,
)
actual_grads = defaultdict(lambda: np.zeros(1))
for _ in range(accumulation):
inference = SVI(model, guide, optim, loss=elbo)
with xfail_if_not_implemented():
inference.loss_and_grads(
model, guide, data=data, mask=torch.tensor(mask)
)
params = dict(pyro.get_param_store().named_parameters())
actual_grads = {
name: param.grad.detach().cpu().numpy() / accumulation
for name, param in params.items()
}

# grad(loc) = (n+1) * loc - (x1 + ... + xn)
# grad(scale) = (n+1) * scale - 1 / scale
expected_grads = {
"loc": sum(mask) + 1. - data[mask].sum(0, keepdim=True).numpy(),
"scale": sum(mask) + 1 - np.ones(1)
}
for name in sorted(params):
logger.info("expected {} = {}".format(name, expected_grads[name]))
logger.info("actual {} = {}".format(name, actual_grads[name]))
assert_equal(actual_grads, expected_grads, prec=precision)
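
Note: the expected gradients in this test follow from the closed form of the loss (negative ELBO) for the model above. With prior z ~ Normal(0, 1), guide z ~ Normal(loc, scale), and n observed points x_i ~ Normal(z, 1), grad(loc) = (n+1) * loc - sum(x_i) and grad(scale) = (n+1) * scale - 1 / scale, matching the comments in the test; the expected_grads dictionary plugs in the initial values loc = scale = 1. A quick numeric check for one parametrization (mask = [True, False], so n = 1); the variable names are just for illustration:

import numpy as np

data = np.array([-0.5, 2.0])
mask = np.array([True, False])   # keeps only the first observation, so n = 1
loc, scale = 1.0, 1.0            # initial values of the "loc" and "scale" params

n = mask.sum()
grad_loc = (n + 1) * loc - data[mask].sum()   # 2 * 1.0 - (-0.5) = 2.5
grad_scale = (n + 1) * scale - 1.0 / scale    # 2 * 1.0 - 1.0    = 1.0
print(grad_loc, grad_scale)
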


@pytest.mark.parametrize(
"reparameterized", [True, False], ids=["reparam", "nonreparam"]
)
