From 7c6d79ebcfdecc1e8c06ed0f5a0d72e413a07042 Mon Sep 17 00:00:00 2001 From: Justine Date: Tue, 5 Mar 2024 14:58:33 +0100 Subject: [PATCH] make lpt model consistent with lognormal model --- sbi_lens/simulator/Lpt_field.py | 62 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/sbi_lens/simulator/Lpt_field.py b/sbi_lens/simulator/Lpt_field.py index 14a62c6..89bfbf2 100644 --- a/sbi_lens/simulator/Lpt_field.py +++ b/sbi_lens/simulator/Lpt_field.py @@ -173,8 +173,8 @@ def pk_fn(x): # Define the probabilistic model def lensingLpt( - field_size, - field_npix, + N, + map_size, box_size, box_shape, gal_per_arcmin2, @@ -187,25 +187,28 @@ def lensingLpt( """ This function defines the top-level forward model for our observations """ + + pix_area = (map_size * 60 / N) ** 2 + # Sampling initial conditions initial_conditions = numpyro.sample( - "initial_conditions", dist.Normal(jnp.zeros(box_shape), jnp.ones(box_shape)) + "z", dist.Normal(jnp.zeros(box_shape), jnp.ones(box_shape)) ) - Omega_c = numpyro.sample("Omega_c", dist.TruncatedNormal(0.2664, 0.2, low=0)) - Omega_b = numpyro.sample("Omega_b", dist.Normal(0.0492, 0.006)) - sigma8 = numpyro.sample("sigma8", dist.Normal(0.831, 0.14)) - h = numpyro.sample("h", dist.Normal(0.6727, 0.063)) + omega_c = numpyro.sample("omega_c", dist.TruncatedNormal(0.2664, 0.2, low=0)) + omega_b = numpyro.sample("omega_b", dist.Normal(0.0492, 0.006)) + sigma_8 = numpyro.sample("sigma_8", dist.Normal(0.831, 0.14)) + h_0 = numpyro.sample("h_0", dist.Normal(0.6727, 0.063)) n_s = numpyro.sample("n_s", dist.Normal(0.9645, 0.08)) - w0 = numpyro.sample("w0", dist.TruncatedNormal(-1.0, 0.9, low=-2.0, high=-0.3)) + w_0 = numpyro.sample("w_0", dist.TruncatedNormal(-1.0, 0.9, low=-2.0, high=-0.3)) cosmo = jc.Cosmology( - Omega_c=Omega_c, - Omega_b=Omega_b, - sigma8=sigma8, - h=h, + Omega_c=omega_c, + Omega_b=omega_b, + sigma8=sigma_8, + h=h_0, n_s=n_s, - w0=w0, + w0=w_0, wa=0.0, Omega_k=0.0, ) @@ -216,28 +219,25 @@ def lensingLpt( lensing_model = jax.jit( make_full_field_model( - field_size=field_size, - field_npix=field_npix, + field_size=map_size, + field_npix=N, box_size=box_size, box_shape=box_shape, ) ) - convergence_maps, _ = lensing_model(cosmo, nz_shear, initial_conditions) - - # Apply noise to the maps (this defines the likelihood) - observed_maps = [ - numpyro.sample( - "kappa_%d" % i, - dist.Normal( - k, - sigma_e - / jnp.sqrt( - nz_shear[i].gals_per_arcmin2 * (field_size * 60 / field_npix) ** 2 - ), + field, _ = lensing_model(cosmo, nz_shear, initial_conditions) + field = jnp.transpose(jnp.array(field), [1, 2, 0]) + + x = numpyro.sample( + "y", + dist.MultivariateNormal( + loc=field, + covariance_matrix=jnp.diag( + sigma_e**2 + / (jnp.array([b.gals_per_arcmin2 for b in nz_shear]) * pix_area) ), - ) - for i, k in enumerate(convergence_maps) - ] + ), + ) - return observed_maps + return x