Skip to content

Commit

Permalink
make lpt model consistent with lognormal model
Browse files Browse the repository at this point in the history
  • Loading branch information
Justinezgh committed Mar 5, 2024
1 parent 08ebddd commit 7c6d79e
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions sbi_lens/simulator/Lpt_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def pk_fn(x):

# Define the probabilistic model
def lensingLpt(
field_size,
field_npix,
N,
map_size,
box_size,
box_shape,
gal_per_arcmin2,
Expand All @@ -187,25 +187,28 @@ def lensingLpt(
"""
This function defines the top-level forward model for our observations
"""

pix_area = (map_size * 60 / N) ** 2

# Sampling initial conditions
initial_conditions = numpyro.sample(
"initial_conditions", dist.Normal(jnp.zeros(box_shape), jnp.ones(box_shape))
"z", dist.Normal(jnp.zeros(box_shape), jnp.ones(box_shape))
)

Omega_c = numpyro.sample("Omega_c", dist.TruncatedNormal(0.2664, 0.2, low=0))
Omega_b = numpyro.sample("Omega_b", dist.Normal(0.0492, 0.006))
sigma8 = numpyro.sample("sigma8", dist.Normal(0.831, 0.14))
h = numpyro.sample("h", dist.Normal(0.6727, 0.063))
omega_c = numpyro.sample("omega_c", dist.TruncatedNormal(0.2664, 0.2, low=0))
omega_b = numpyro.sample("omega_b", dist.Normal(0.0492, 0.006))
sigma_8 = numpyro.sample("sigma_8", dist.Normal(0.831, 0.14))
h_0 = numpyro.sample("h_0", dist.Normal(0.6727, 0.063))
n_s = numpyro.sample("n_s", dist.Normal(0.9645, 0.08))
w0 = numpyro.sample("w0", dist.TruncatedNormal(-1.0, 0.9, low=-2.0, high=-0.3))
w_0 = numpyro.sample("w_0", dist.TruncatedNormal(-1.0, 0.9, low=-2.0, high=-0.3))

cosmo = jc.Cosmology(
Omega_c=Omega_c,
Omega_b=Omega_b,
sigma8=sigma8,
h=h,
Omega_c=omega_c,
Omega_b=omega_b,
sigma8=sigma_8,
h=h_0,
n_s=n_s,
w0=w0,
w0=w_0,
wa=0.0,
Omega_k=0.0,
)
Expand All @@ -216,28 +219,25 @@ def lensingLpt(

lensing_model = jax.jit(
make_full_field_model(
field_size=field_size,
field_npix=field_npix,
field_size=map_size,
field_npix=N,
box_size=box_size,
box_shape=box_shape,
)
)

convergence_maps, _ = lensing_model(cosmo, nz_shear, initial_conditions)

# Apply noise to the maps (this defines the likelihood)
observed_maps = [
numpyro.sample(
"kappa_%d" % i,
dist.Normal(
k,
sigma_e
/ jnp.sqrt(
nz_shear[i].gals_per_arcmin2 * (field_size * 60 / field_npix) ** 2
),
field, _ = lensing_model(cosmo, nz_shear, initial_conditions)
field = jnp.transpose(jnp.array(field), [1, 2, 0])

x = numpyro.sample(
"y",
dist.MultivariateNormal(
loc=field,
covariance_matrix=jnp.diag(
sigma_e**2
/ (jnp.array([b.gals_per_arcmin2 for b in nz_shear]) * pix_area)
),
)
for i, k in enumerate(convergence_maps)
]
),
)

return observed_maps
return x

0 comments on commit 7c6d79e

Please sign in to comment.