diff --git a/jaxpm/pm.py b/jaxpm/pm.py index d9870f7..f1273f2 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -51,6 +51,41 @@ def linear_field(mesh_shape, box_size, pk, seed): field = jnp.fft.irfftn(field) return field +def box_muller_field(amplitude, phase, pkmesh): + """ + Obtain Gaussian random field given uniform random numbers and Pk amplitude. + """ + field = pkmesh**0.5 * jnp.sqrt(-jnp.log(amplitude)) * (jnp.cos(phase) + 1j * jnp.sin(phase)) + return jnp.fft.irfftn(field, (amplitude.shape[0],)*3, norm='ortho') + +def linear_field_box_muller(mesh_shape, box_size, pk, seed, fixamp = False, inv_phase = False): + """ + Generate initial conditions with fixed amplitude and/or inverted phase. + """ + + key, subkey1, subkey2 = jax.random.split(seed, 3) + kvec = fftk(mesh_shape) + kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 + pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) + + if fixamp: + amplitude = jnp.ones_like(kmesh) + else: + amplitude = jax.random.uniform(subkey1, kmesh.shape, minval=1e-8) + + + if inv_phase: + phase = jax.random.uniform(subkey2, kmesh.shape, minval=1e-8) * 2 * jnp.pi + ret = [] + ret.append(box_muller_field(amplitude, phase, pkmesh)) + phase = (phase + jnp.pi) + ret.append(box_muller_field(amplitude, phase, pkmesh)) + return ret + + return box_muller_field(amplitude, phase, pkmesh) + + + def make_ode_fn(mesh_shape): def nbody_ode(state, a, cosmo):