Skip to content

Commit d81dc81

Browse files
committed
bug fix for MNLE on gpu.
1 parent 5736080 commit d81dc81

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

sbi/neural_nets/mnle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,18 @@ def log_prob_iid(self, x: Tensor, theta: Tensor) -> Tensor:
361361
x_cont_repeated, x_disc_repeated = _separate_x(x_repeated)
362362
x_cont, x_disc = _separate_x(x)
363363

364-
log_prob_per_cat = torch.zeros(self.discrete_net.num_categories, batch_size)
365364
# repeat categories for parameters
366365
repeated_categories = torch.repeat_interleave(
367366
torch.arange(self.discrete_net.num_categories - 1), batch_size, dim=0
368367
)
369368
# repeat parameters for categories
370369
repeated_theta = theta.repeat(self.discrete_net.num_categories - 1, 1)
370+
log_prob_per_cat = torch.zeros(self.discrete_net.num_categories, batch_size).to(
371+
net_device
372+
)
371373
log_prob_per_cat[:-1, :] = self.discrete_net.log_prob(
372-
repeated_categories,
373-
repeated_theta,
374+
repeated_categories.to(net_device),
375+
repeated_theta.to(net_device),
374376
).reshape(-1, batch_size)
375377
# infer the last category logprob from sum to one.
376378
log_prob_per_cat[-1, :] = torch.log(1 - log_prob_per_cat[:-1, :].exp().sum(0))

tests/mnle_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_mnle_on_device(device):
3333

3434
# Test sampling on device.
3535
posterior = trainer.build_posterior()
36-
posterior.sample((1,), x=x[0], show_progress_bars=False)
36+
posterior.sample((1,), x=x[0], show_progress_bars=False, mcmc_method="nuts")
3737

3838

3939
@pytest.mark.parametrize("sampler", ("mcmc", "rejection", "vi"))

0 commit comments

Comments
 (0)