From 0a21fe420466de2ca4bd2418b5c1cafddd5326a0 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Mon, 16 Dec 2024 21:22:06 +0100 Subject: [PATCH] 2d sampling --- src/cryo_sbi/utils/estimator_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cryo_sbi/utils/estimator_utils.py b/src/cryo_sbi/utils/estimator_utils.py index ba9adbb..ff09470 100644 --- a/src/cryo_sbi/utils/estimator_utils.py +++ b/src/cryo_sbi/utils/estimator_utils.py @@ -87,7 +87,7 @@ def sample_posterior( samples = estimator.sample( image_batch.to(device, non_blocking=True), shape=(num_samples,) ).cpu() - theta_samples.append(samples.reshape(-1, image_batch.shape[0])) + theta_samples.append(samples) return torch.cat(theta_samples, dim=1)