diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 7eabd1c..49dd9b5 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -15,28 +15,27 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 - uses: conda-incubator/setup-miniconda@v2 with: + miniforge-variant: Mambaforge + use-mamba: true python-version: ${{ matrix.python-version }} - channels: conda-forge,defaults - channel-priority: strict - show-channel-urls: true auto-update-conda: true + show-channel-urls: true - - name: Install dependencies + - name: Install package and dependencies shell: bash -l {0} run: | conda config --set always_yes yes - conda install pytest pip - conda install -c conda-forge pyccl + mamba install pytest pytest-xdist pyccl pip install . - + - name: Test with pytest shell: bash -l {0} run: | - pytest + pytest -n auto diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index 6db6c9f..3cfe77c 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -70,7 +70,7 @@ def angular_cl( def cl(ell): def integrand(a): # Step 1: retrieve the associated comoving distance - chi = bkgrd.radial_comoving_distance(cosmo, a) + _, chi = bkgrd.radial_comoving_distance(cosmo, a) # Step 2: get the power spectrum for this combination of chi and a k = (ell + 0.5) / np.clip(chi, 1.0) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 1a2a182..83ae8bd 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -1,4 +1,7 @@ -# This module implements various functions for the background COSMOLOGY +"""This module implements various functions for the cosmological background +and linear perturbations. + +""" import jax.numpy as np from jax import lax @@ -14,6 +17,7 @@ "Omega_m_a", "Omega_de_a", "radial_comoving_distance", + "a_of_chi", "dchioverda", "transverse_comoving_distance", "angular_diameter_distance", @@ -23,55 +27,53 @@ def w(cosmo, a): - r"""Dark Energy equation of state parameter using the Linder + r"""Dark energy equation of state parameter using the Linder parametrisation. Parameters ---------- - cosmo: Cosmology - Cosmological parameters structure - + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- w : ndarray, or float if input scalar - The Dark Energy equation of state parameter at the specified - scale factor + Dark energy equation of state parameters at specified scale factors. Notes ----- - - The Linder parametrization :cite:`2003:Linder` for the Dark Energy + The Linder parametrization :cite:`2003:Linder` for the dark energy equation of state :math:`p = w \rho` is given by: .. math:: w(a) = w_0 + w_a (1 - a) + """ return cosmo.w0 + (1.0 - a) * cosmo.wa # Equation (6) in Linder (2003) def f_de(cosmo, a): - r"""Evolution parameter for the Dark Energy density. + r"""Evolution parameter for the dark energy density. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- f : ndarray, or float if input scalar - The evolution parameter of the Dark Energy density as a function - of scale factor + Dark energy density evolution parameters at specified scale factors. Notes ----- - - For a given parametrisation of the Dark Energy equation of state, - the scaling of the Dark Energy density with time can be written as: + For a given parametrisation of the dark energy equation of state, + the scaling of the dark energy density with time can be written as: .. math:: @@ -86,38 +88,39 @@ def f_de(cosmo, a): .. math:: f(a) = -3 (1 + w_0 + w_a) \ln(a) + 3 w_a (a - 1) + """ return -3.0 * (1.0 + cosmo.w0 + cosmo.wa) * np.log(a) + 3.0 * cosmo.wa * (a - 1.0) def Esqr(cosmo, a): - r"""Square of the scale factor dependent factor E(a) in the Hubble - parameter. + r"""Squared time scaling factors E(a) of the Hubble expansion. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- E^2 : ndarray, or float if input scalar - Square of the scaling of the Hubble constant as a function of - scale factor + Squared scaling of the Hubble expansion at specified scale factors. Notes ----- - The Hubble parameter at scale factor `a` is given by - :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through + :math:`H^2(a) = E^2(a) H_0^2` where :math:`E^2` is obtained through Friedman's Equation (see :cite:`2005:Percival`) : .. math:: E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} e^{f(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed + where :math:`f(a)` is the dark energy evolution parameter computed by :py:meth:`.f_de`. + """ return ( cosmo.Omega_m * np.power(a, -3) @@ -127,33 +130,38 @@ def Esqr(cosmo, a): def H(cosmo, a): - r"""Hubble parameter [km/s/(Mpc/h)] at scale factor `a` + r"""Hubble expansion rate [km/s/(Mpc/h)] at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- H : ndarray, or float if input scalar - Hubble parameter at the requested scale factor. + Hubble parameters at specified scale factors. + """ return const.H0 * np.sqrt(Esqr(cosmo, a)) def Omega_m_a(cosmo, a): - r"""Matter density at scale factor `a`. + r"""Matter density at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- Omega_m : ndarray, or float if input scalar - Non-relativistic matter density at the requested scale factor + Non-relativistic matter density at specified scale factors. Notes ----- @@ -164,65 +172,78 @@ def Omega_m_a(cosmo, a): \Omega_m(a) = \frac{\Omega_m a^{-3}}{E^2(a)} see :cite:`2005:Percival` Eq. (6) + """ return cosmo.Omega_m * np.power(a, -3) / Esqr(cosmo, a) def Omega_de_a(cosmo, a): - r"""Dark Energy density at scale factor `a`. + r"""Dark energy density at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- Omega_de : ndarray, or float if input scalar - Dark Energy density at the requested scale factor + Dark energy density at specified scale factors. Notes ----- - The evolution of Dark Energy density :math:`\Omega_{de}(a)` is given + The evolution of dark energy density :math:`\Omega_{de}(a)` is given by: .. math:: \Omega_{de}(a) = \frac{\Omega_{de} e^{f(a)}}{E^2(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed by + where :math:`f(a)` is the dark energy evolution parameter computed by :py:meth:`.f_de` (see :cite:`2005:Percival` Eq. (6)). + """ return cosmo.Omega_de * np.exp(f_de(cosmo, a)) / Esqr(cosmo, a) -def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): - r"""Radial comoving distance in [Mpc/h] for a given scale factor. +def radial_comoving_distance(cosmo, a): + r"""Radial comoving distances in [Mpc/h] at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. chi : ndarray, or float if input scalar - Radial comoving distance corresponding to the specified scale - factor. + Radial comoving distances at specified scale factors. Notes ----- - The radial comoving distance is computed by performing the following + The radial comoving distances is computed by performing the following integration: .. math:: \chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)} + """ # Check if distances have already been computed - if not "background.radial_comoving_distance" in cosmo._workspace.keys(): + key = "background.radial_comoving_distance" + if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def dchioverdlna(y, x): xa = np.exp(x) @@ -232,81 +253,92 @@ def dchioverdlna(y, x): # np.clip(- 3000*np.log(atab), 0, 10000)#odeint(dchioverdlna, 0., np.log(atab), cosmo) chitab = chitab[-1] - chitab - cache = {"a": atab, "chi": chitab} - cosmo._workspace["background.radial_comoving_distance"] = cache + value = {"a": atab, "chi": chitab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._workspace["background.radial_comoving_distance"] + value = cosmo.cache_get(key) a = np.atleast_1d(a) # Return the results as an interpolation of the table - return np.clip(interp(a, cache["a"], cache["chi"]), 0.0) + chi = np.clip(interp(a, value["a"], value["chi"]), 0.0) + return cosmo, chi def a_of_chi(cosmo, chi): - r"""Computes the scale factor for corresponding (array) of radial comoving - distance by reverse linear interpolation. - - Parameters: - ----------- - cosmo: Cosmology - Cosmological parameters + r"""Computes the scale factors at given radial comoving distances by + reverse linear interpolation. - chi: array-like - radial comoving distance to query. + Parameters + ---------- + cosmo : Cosmology + Cosmological parameters. + chi : array-like + Radial comoving distances to query. - Returns: - -------- + Returns + ------- + cosmo : Cosmology + Cosmological parameters with cached computations. a : array-like - Scale factors corresponding to requested distances + Scale factors at specified distances. + """ # Check if distances have already been computed, force computation otherwise - if not "background.radial_comoving_distance" in cosmo._workspace.keys(): - radial_comoving_distance(cosmo, 1.0) - cache = cosmo._workspace["background.radial_comoving_distance"] + key = "background.radial_comoving_distance" + if not cosmo.is_cached(key): + cosmo, _ = radial_comoving_distance(cosmo, 1.0) + value = cosmo.cache_get(key) + chi = np.atleast_1d(chi) - return interp(chi, cache["chi"], cache["a"]) + a = interp(chi, value["chi"], value["a"]) + return cosmo, a def dchioverda(cosmo, a): - r"""Derivative of the radial comoving distance with respect to the - scale factor. + r"""Derivative of the radial comoving distances with respect to the + scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- - dchi/da : ndarray, or float if input scalar - Derivative of the radial comoving distance with respect to the - scale factor at the specified scale factor. + dchi/da : ndarray, or float if input scalar + Derivative of the radial comoving distances with respect to the + scale factors at specified scale factors. Notes ----- - The expression for :math:`\frac{d \chi}{da}` is: .. math:: \frac{d \chi}{da}(a) = \frac{R_H}{a^2 E(a)} + """ return const.rh / (a ** 2 * np.sqrt(Esqr(cosmo, a))) def transverse_comoving_distance(cosmo, a): - r"""Transverse comoving distance in [Mpc/h] for a given scale factor. + r"""Transverse comoving distances in [Mpc/h] for given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. f_k : ndarray, or float if input scalar - Transverse comoving distance corresponding to the specified - scale factor. + Transverse comoving distances at specified scale factors. Notes ----- @@ -325,8 +357,8 @@ def transverse_comoving_distance(cosmo, a): \mbox{for } \Omega_k < 0 \end{matrix} \right. + """ - index = cosmo.k + 1 def open_universe(chi): return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh) @@ -339,22 +371,28 @@ def close_universe(chi): branches = (open_universe, flat_universe, close_universe) - chi = radial_comoving_distance(cosmo, a) + cosmo, chi = radial_comoving_distance(cosmo, a) - return lax.switch(cosmo.k + 1, branches, chi) + f_k = lax.switch(cosmo.k + 1, branches, chi) + return cosmo, f_k def angular_diameter_distance(cosmo, a): - r"""Angular diameter distance in [Mpc/h] for a given scale factor. + r"""Angular diameter distances in [Mpc/h] for given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. d_A : ndarray, or float if input scalar + Angular diameter distances at specified scale factors. Notes ----- @@ -364,62 +402,69 @@ def angular_diameter_distance(cosmo, a): .. math:: d_A(a) = a f_k(a) + """ - return a * transverse_comoving_distance(cosmo, a) + cosmo, f_k = transverse_comoving_distance(cosmo, a) + d_A = a * f_k + return cosmo, d_A def growth_factor(cosmo, a): - """Compute linear growth factor D(a) at a given scale factor, - normalized such that D(a=1) = 1. + r"""Compute linear growth factors :math:`D(a)` at given scale factors, + normalized such that :math:`D(a=1) = 1`. Parameters ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. Notes ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. + The growth computation depends on the cosmology parametrization: + if the :math:`\gamma` parameter is defined, the growth history is computed + assuming the growth rate :math:`f = \Omega_m(a)^\gamma`, otherwise the + usual ODE for growth will be solved. + """ - if cosmo._flags["gamma_growth"]: - return _growth_factor_gamma(cosmo, a) + if cosmo.gamma is not None: + cosmo, D = _growth_factor_gamma(cosmo, a) else: - return _growth_factor_ODE(cosmo, a) + cosmo, D = _growth_factor_ODE(cosmo, a) + return cosmo, D def growth_rate(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor. + r"""Compute growth rates :math:`dD/d\ln\ a` at given scale factors. Parameters ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + f : ndarray, or float if input scalar + Growth rate at specified scale factors. Notes ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. + The growth computation depends on the cosmology parametrization: + if the :math:`\gamma` parameter is defined, the growth history is computed + assuming the growth rate :math:`f = \Omega_m(a)^\gamma`, otherwise the + usual ODE for growth will be solved. The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: @@ -427,40 +472,49 @@ def growth_rate(cosmo, a): f_{\gamma}(a) = \Omega_m^{\gamma} (a) - with :math: `\gamma` in LCDM, given approximately by: + with :math:`\gamma` in LCDM, given approximately by: .. math:: \gamma = 0.55 see :cite:`2019:Euclid Preparation VII, eqn.32` + """ - if cosmo._flags["gamma_growth"]: - return _growth_rate_gamma(cosmo, a) + if cosmo.gamma is not None: + f = _growth_rate_gamma(cosmo, a) else: - return _growth_rate_ODE(cosmo, a) + cosmo, f = _growth_rate_ODE(cosmo, a) + return cosmo, f -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): - """Compute linear growth factor D(a) at a given scale factor, - normalised such that D(a=1) = 1. +def _growth_factor_ODE(cosmo, a): + r"""Compute linear growth factors :math:`D(a)` at given scale factors, + normalised such that :math:`D(a=1) = 1`. Parameters ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 + cosmo : Cosmology + Cosmological parameters. + a : array_like + Scale factors. Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. + """ # Check if growth has already been computed - if not "background.growth_factor" in cosmo._workspace.keys(): + key = "background.growth_factor" + if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def D_derivs(y, x): q = ( @@ -481,59 +535,72 @@ def D_derivs(y, x): # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da ftab = y[:, 1] / y1[-1] * atab / gtab - cache = {"a": atab, "g": gtab, "f": ftab} - cosmo._workspace["background.growth_factor"] = cache + value = {"a": atab, "g": gtab, "f": ftab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + value = cosmo.cache_get(key) + + D = np.clip(interp(a, value["a"], value["g"]), 0.0, 1.0) + return cosmo, D def _growth_rate_ODE(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor by solving the linear - growth ODE. + r"""Compute growth rates :math:`dD/d\ln\ a` at given scale factors by + solving the linear growth ODE. Parameters ---------- - cosmo: `Cosmology` - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + f : ndarray, or float if input scalar + Growth rates at specified scale factors. + """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) + key = "background.growth_factor" + if not cosmo.is_cached(key): + cosmo, _ = _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + value = cosmo.cache_get(key) + + f = interp(a, value["a"], value["f"]) + return cosmo, f -def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): - r"""Computes growth factor by integrating the growth rate provided by the - \gamma parametrization. Normalized such that D( a=1) =1 +def _growth_factor_gamma(cosmo, a): + r"""Growth factors by integrating the :math:`\gamma`-parametrized growth + rates, normalized such that :math:`D(a=1) = 1`. Parameters ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 + cosmo : Cosmology + Cosmological parameters. + a : array_like + Scale factors. Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): + key = "background.growth_factor" + if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def integrand(y, loga): xa = np.exp(loga) @@ -541,28 +608,29 @@ def integrand(y, loga): gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = gtab / gtab[-1] # Normalize to a=1. - cache = {"a": atab, "g": gtab} - cosmo._workspace["background.growth_factor"] = cache + value = {"a": atab, "g": gtab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._workspace["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + value = cosmo.cache_get(key) + + D = np.clip(interp(a, value["a"], value["g"]), 0.0, 1.0) + return cosmo, D def _growth_rate_gamma(cosmo, a): - r"""Growth rate approximation at scale factor `a`. + r"""Growth rate approximation at given scale factors. Parameters ---------- - cosmo: `Cosmology` - Cosmology object - + cosmo : Cosmology + Cosmology parameters. a : array_like - Scale factor + Scale factors. Returns ------- f_gamma : ndarray, or float if input scalar - Growth rate approximation at the requested scale factor + Growth rate approximation at specified scale factors. Notes ----- @@ -572,11 +640,12 @@ def _growth_rate_gamma(cosmo, a): f_{\gamma}(a) = \Omega_m^{\gamma} (a) - with :math: `\gamma` in LCDM, given approximately by: + with :math:`\gamma` in LCDM, given approximately by: .. math:: \gamma = 0.55 see :cite:`2019:Euclid Preparation VII, eqn.32` + """ return Omega_m_a(cosmo, a) ** cosmo.gamma diff --git a/jax_cosmo/bias.py b/jax_cosmo/bias.py index 6887ce8..fadd7b2 100644 --- a/jax_cosmo/bias.py +++ b/jax_cosmo/bias.py @@ -13,8 +13,8 @@ class constant_linear_bias(container): """ Class representing a linear bias - Parameters: - ----------- + Parameters + ---------- b: redshift independent bias value """ @@ -29,15 +29,16 @@ class inverse_growth_linear_bias(container): TODO: what's a better name for this? Class representing an inverse bias in 1/growth(a) - Parameters: - ----------- + Parameters + ---------- cosmo: cosmology b: redshift independent bias value at z=0 """ def __call__(self, cosmo, z): b = self.params[0] - return b / bkgrd.growth_factor(cosmo, z2a(z)) + _, D = bkgrd.growth_factor(cosmo, z2a(z)) + return b / D @register_pytree_node_class @@ -45,8 +46,8 @@ class des_y1_ia_bias(container): """ https://arxiv.org/pdf/1708.01538.pdf Sec. VII.B - Parameters: - ----------- + Parameters + ---------- cosmo: cosmology A: amplitude eta: redshift dependent slope diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 4ff49d6..9439af8 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,199 +1,157 @@ +"""This module implements the Cosmology type containing cosmological parameters +and cached computations, and the Configuration type containing configuration parameters. + +""" +from dataclasses import dataclass +from dataclasses import field +from dataclasses import replace +from functools import partial +from pprint import pformat +from typing import Any +from typing import Dict # use dict instead for python >= 3.10 +from typing import Optional + import jax.numpy as np -from jax.experimental.ode import odeint -from jax.tree_util import register_pytree_node_class -import jax_cosmo.constants as const -from jax_cosmo.utils import a2z -from jax_cosmo.utils import z2a +from jax_cosmo.dataclasses import pytree_dataclass -__all__ = ["Cosmology"] +__all__ = ["Cosmology", "Configuration"] -@register_pytree_node_class -class Cosmology: - def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None): - """ - Cosmology object, stores primary and derived cosmological parameters. - - Parameters: - ----------- - Omega_c, float - Cold dark matter density fraction. - Omega_b, float - Baryonic matter density fraction. - h, float - Hubble constant divided by 100 km/s/Mpc; unitless. - n_s, float - Primordial scalar perturbation spectral index. - sigma8, float - Variance of matter density perturbations at an 8 Mpc/h scale - Omega_k, float - Curvature density fraction. - w0, float - First order term of dark energy equation - wa, float - Second order term of dark energy equation of state - gamma: float - Index of the growth rate (optional) - - Notes: - ------ - - If `gamma` is specified, the emprical characterisation of growth in - terms of dlnD/dlna = \omega^\gamma will be used to define growth throughout. - Otherwise the linear growth factor and growth rate will be solved by ODE. - - """ - # Store primary parameters - self._Omega_c = Omega_c - self._Omega_b = Omega_b - self._h = h - self._n_s = n_s - self._sigma8 = sigma8 - self._Omega_k = Omega_k - self._w0 = w0 - self._wa = wa - - self._flags = {} - - # Secondary optional parameters - self._gamma = gamma - self._flags["gamma_growth"] = gamma is not None - - # Create a workspace where functions can store some precomputed - # results - self._workspace = {} +@dataclass(frozen=True) +class Configuration: + """Configuration parameters, that are not to be traced by JAX. - def __str__(self): - return ( - "Cosmological parameters: \n" - + " h: " - + str(self.h) - + " \n" - + " Omega_b: " - + str(self.Omega_b) - + " \n" - + " Omega_c: " - + str(self.Omega_c) - + " \n" - + " Omega_k: " - + str(self.Omega_k) - + " \n" - + " w0: " - + str(self.w0) - + " \n" - + " wa: " - + str(self.wa) - + " \n" - + " n: " - + str(self.n_s) - + " \n" - + " sigma8: " - + str(self.sigma8) - ) - - def __repr__(self): - return self.__str__() - - # Operations for flattening/unflattening representation - def tree_flatten(self): - params = ( - self._Omega_c, - self._Omega_b, - self._h, - self._n_s, - self._sigma8, - self._Omega_k, - self._w0, - self._wa, - ) - - if self._flags["gamma_growth"]: - params += (self._gamma,) - - return ( - params, - self._flags, - ) - - @classmethod - def tree_unflatten(cls, aux_data, children): - # Retrieve base parameters - Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa = children[:8] - children = list(children[8:]).reverse() - - # We extract the remaining parameters in reverse order from how they - # were inserted - if aux_data["gamma_growth"]: - gamma = children.pop() - else: - gamma = None - - return cls( - Omega_c=Omega_c, - Omega_b=Omega_b, - h=h, - n_s=n_s, - sigma8=sigma8, - Omega_k=Omega_k, - w0=w0, - wa=wa, - gamma=gamma, - ) - - # Cosmological parameters, base and derived - @property - def Omega(self): - return 1.0 - self._Omega_k + Parameters + ---------- + log10_a_min : float, optional + Minimum in scale factor logspace range + log10_a_max : float, optional + Maximum in scale factor logspace range + log10_a_num : int, optional + Number of samples in scale factor logspace range + growth_rtol : float, optional + Relative error tolerance for solving growth ODEs + growth_atol : float, optional + Absolute error tolerance for solving growth ODEs - @property - def Omega_b(self): - return self._Omega_b + log10_k_min : float, optional + Minimum in wavenumber logspace range + log10_k_max : float, optional + Maximum in wavenumber logspace range - @property - def Omega_c(self): - return self._Omega_c + """ - @property - def Omega_m(self): - return self._Omega_b + self._Omega_c + log10_a_min: float = -3.0 + log10_a_max: float = 0.0 + log10_a_steps: int = 256 # TODO revisit after improving odeint and interpolation + growth_atol: float = 0.0 + growth_rtol: float = 1e-4 - @property - def Omega_de(self): - return self.Omega - self.Omega_m + log10_k_min: float = -4.0 + log10_k_max: float = 3.0 - @property - def Omega_k(self): - return self._Omega_k - @property - def k(self): - return -np.sign(self._Omega_k).astype(np.int8) +@partial(pytree_dataclass, aux_fields="config", frozen=True) +class Cosmology: + """ + Cosmology parameter type, containing primary, secondary, derived parameters, + cached computations, and configurations; immutable as a frozen dataclass. + + Parameters + ---------- + Omega_c : float + Cold dark matter density fraction. + Omega_b : float + Baryonic matter density fraction. + h : float + Hubble constant divided by 100 km/s/Mpc; unitless. + n_s : float + Primordial scalar perturbation spectral index. + sigma8 : float + RMS of matter density perturbations in an 8 Mpc/h spherical tophat. + Omega_k : float + Curvature density fraction. + w0 : float + First order term of dark energy equation. + wa : float + Second order term of dark energy equation of state. + gamma : float, optional + Exponent of growth rate fitting formula. + config : Configuration, optional + Configuration parameters. + + Notes + ----- + + If `gamma` is specified, the growth rate fitting formula + :math:`dlnD/dlna = \Omega_m(a)^\gamma` will be used to model the growth history. + Otherwise the linear growth factor and growth rate will be solved by ODE. + + """ + + # Primary parameters + Omega_c: float + Omega_b: float + h: float + n_s: float + sigma8: float + Omega_k: float + w0: float + wa: float + + # Secondary optional parameters + gamma: Optional[float] = None + + # cache for intermediate computations; + # users should not access it directly but use the class methods instead + _cache: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) + + # configuration parameters, immutable (frozen dataclass) + config: Configuration = field(default_factory=Configuration) - @property - def sqrtk(self): - return np.sqrt(np.abs(self._Omega_k)) + def __str__(self): + return pformat(self, indent=4, width=1) # for python >= 3.10 - @property - def h(self): - return self._h + def is_cached(self, key): + return key in self._cache + + def cache_get(self, key): + return self._cache[key] + + def cache_set(self, key, value): + """Add key-value pair to cache and return a new ``Cosmology`` instance.""" + cache = self._cache.copy() + cache[key] = value + return replace(self, _cache=cache) + + def cache_del(self, key): + """Remove key from cache and return a new ``Cosmology`` instance.""" + cache = self._cache.copy() + del cache[key] + return replace(self, _cache=cache) + def cache_clear(self): + """Return a new ``Cosmology`` instance with empty cache.""" + return replace(self, _cache={}) + + # Derived parameters @property - def w0(self): - return self._w0 + def Omega(self): + return 1.0 - self.Omega_k @property - def wa(self): - return self._wa + def Omega_m(self): + return self.Omega_b + self.Omega_c @property - def n_s(self): - return self._n_s + def Omega_de(self): + return self.Omega - self.Omega_m @property - def sigma8(self): - return self._sigma8 + def k(self): + return -np.sign(self.Omega_k).astype(np.int8) @property - def gamma(self): - return self._gamma + def sqrtk(self): + return np.sqrt(np.abs(self.Omega_k)) diff --git a/jax_cosmo/dataclasses.py b/jax_cosmo/dataclasses.py new file mode 100644 index 0000000..a78d460 --- /dev/null +++ b/jax_cosmo/dataclasses.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import is_dataclass + +from jax.tree_util import register_pytree_node + + +def pytree_dataclass(cls, aux_fields=None, **kwargs): + """Register python dataclasses as custom pytree nodes. + + Parameters + ---------- + cls : type + Class to be registered, not a python dataclass yet. + aux_fields : str or Sequence[str], optional + Fields to be treated as pytree aux_data; unrecognized ones are ignored. + kwargs + Keyword arguments to be passed to python dataclass decorator. + + Returns + ------- + cls : type + Registered dataclass. + + Raises + ------ + TypeError + If cls is already a python dataclass. + + .. _Augmented dataclass for JAX pytree: + https://gist.github.com/odashi/813810a5bc06724ea3643456f8d3942d + + """ + if is_dataclass(cls): + raise TypeError("cls cannot already be a dataclass") + cls = dataclass(cls, **kwargs) + + if aux_fields is None: + aux_fields = [] + elif isinstance(aux_fields, str): + aux_fields = [aux_fields] + akeys = [field.name for field in fields(cls) if field.name in aux_fields] + ckeys = [field.name for field in fields(cls) if field.name not in aux_fields] + + def tree_flatten(obj): + children = [getattr(obj, key) for key in ckeys] + aux_data = [getattr(obj, key) for key in akeys] + return children, aux_data + + def tree_unflatten(aux_data, children): + return cls(**dict(zip(ckeys, children)), **dict(zip(akeys, aux_data))) + + register_pytree_node(cls, tree_flatten, tree_unflatten) + + return cls diff --git a/jax_cosmo/power.py b/jax_cosmo/power.py index 4d28cb1..16b9aac 100644 --- a/jax_cosmo/power.py +++ b/jax_cosmo/power.py @@ -42,7 +42,7 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar """ k = np.atleast_1d(k) a = np.atleast_1d(a) - g = bkgrd.growth_factor(cosmo, a) + cosmo, g = bkgrd.growth_factor(cosmo, a) t = transfer_fn(cosmo, k, **kwargs) pknorm = cosmo.sigma8 ** 2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs) @@ -54,18 +54,18 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar return pk.squeeze() -def sigmasqr(cosmo, R, transfer_fn, kmin=0.0001, kmax=1000.0, ksteps=5, **kwargs): - """Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc +def sigmasqr(cosmo, R, transfer_fn, **kwargs): + r"""Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc .. math:: - \\sigma^2(R)= \\frac{1}{2 \\pi^2} \\int_0^\\infty \\frac{dk}{k} k^3 P(k,z) W^2(kR) + \sigma^2(R)= \frac{1}{2 \pi^2} \int_0^\infty \frac{dk}{k} k^3 P(k,z) W^2(kR) where .. math:: - W(kR) = \\frac{3j_1(kR)}{kR} + W(kR) = \frac{3j_1(kR)}{kR} """ def int_sigma(logk): @@ -75,7 +75,7 @@ def int_sigma(logk): pk = transfer_fn(cosmo, k, **kwargs) ** 2 * primordial_matter_power(cosmo, k) return k * (k * w) ** 2 * pk - y = romb(int_sigma, np.log10(kmin), np.log10(kmax), divmax=7) + y = romb(int_sigma, cosmo.config.log10_k_min, cosmo.config.log10_k_max, divmax=7) return 1.0 / (2.0 * np.pi ** 2.0) * y @@ -101,7 +101,7 @@ def int_sigma(logk): r = np.exp(logr) y = np.outer(k, r) pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) - g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) + _, g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) return ( np.expand_dims(pk * k ** 3, axis=1) * np.exp(-(y ** 2)) @@ -123,7 +123,8 @@ def integrand(logk): k = np.exp(logk) y = np.outer(k, 1.0 / k_nl) pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) - g = np.expand_dims(bkgrd.growth_factor(cosmo, np.atleast_1d(a)), 0) + _, g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) + g = np.expand_dims(g, 0) res = ( np.expand_dims(pk * k ** 3, axis=1) * np.exp(-(y ** 2)) diff --git a/jax_cosmo/probes.py b/jax_cosmo/probes.py index fe47895..825f4a0 100644 --- a/jax_cosmo/probes.py +++ b/jax_cosmo/probes.py @@ -26,7 +26,7 @@ def weak_lensing_kernel(cosmo, pzs, z, ell): z = np.atleast_1d(z) zmax = max([pz.zmax for pz in pzs]) # Retrieve comoving distance corresponding to z - chi = bkgrd.radial_comoving_distance(cosmo, z2a(z)) + cosmo, chi = bkgrd.radial_comoving_distance(cosmo, z2a(z)) # Extract the indices of pzs that can be treated as extended distributions, # and the ones that need to be treated as delta functions. @@ -45,7 +45,7 @@ def weak_lensing_kernel(cosmo, pzs, z, ell): @vmap def integrand(z_prime): - chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) + _, chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) # Stack the dndz of all redshift bins dndz = np.stack([pzs[i](z_prime) for i in pzs_extended_idx], axis=0) return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) @@ -56,7 +56,7 @@ def integrand(z_prime): @vmap def integrand_single(z_prime): - chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) + _, chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) return np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) radial_kernels.append( @@ -73,7 +73,8 @@ def integrand_single(z_prime): constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c # Ell dependent factor ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2 - return constant_factor * ell_factor * radial_kernel + kernel = constant_factor * ell_factor * radial_kernel + return kernel @jit @@ -121,9 +122,8 @@ def nla_kernel(cosmo, pzs, bias, z, ell): radial_kernel = dndz * b * bkgrd.H(cosmo, z2a(z)) # Apply common A_IA normalization to the kernel # Joachimi et al. (2011), arXiv: 1008.3491, Eq. 6. - radial_kernel *= ( - -(5e-14 * const.rhocrit) * cosmo.Omega_m / bkgrd.growth_factor(cosmo, z2a(z)) - ) + _, D = bkgrd.growth_factor(cosmo, z2a(z)) + radial_kernel *= -(5e-14 * const.rhocrit) * cosmo.Omega_m / D # Constant factor constant_factor = 1.0 # Ell dependent factor @@ -136,16 +136,16 @@ class WeakLensing(container): """ Class representing a weak lensing probe, with a bunch of bins - Parameters: - ----------- + Parameters + ---------- redshift_bins: list of nzredshift distributions ia_bias: (optional) if provided, IA will be added with the NLA model, either a single bias object or a list of same size as nzs multiplicative_bias: (optional) adds an (1+m) multiplicative bias, either single value or list of same length as redshift bins - Configuration: - -------------- + Configuration + ------------- sigma_e: intrinsic galaxy ellipticity """ @@ -191,8 +191,8 @@ def kernel(self, cosmo, z, ell): """ Compute the radial kernel for all nz bins in this probe. - Returns: - -------- + Returns + ------- radial_kernel: shape (nbins, nz) """ z = np.atleast_1d(z) @@ -229,12 +229,12 @@ def noise(self): class NumberCounts(container): """Class representing a galaxy clustering probe, with a bunch of bins - Parameters: - ----------- + Parameters + ---------- redshift_bins: nzredshift distributions - Configuration: - -------------- + Configuration + ------------- has_rsd.... """ @@ -262,8 +262,8 @@ def n_tracers(self): def kernel(self, cosmo, z, ell): """Compute the radial kernel for all nz bins in this probe. - Returns: - -------- + Returns + ------- radial_kernel: shape (nbins, nz) """ z = np.atleast_1d(z) diff --git a/jax_cosmo/redshift.py b/jax_cosmo/redshift.py index e02e9c7..6938167 100644 --- a/jax_cosmo/redshift.py +++ b/jax_cosmo/redshift.py @@ -62,8 +62,8 @@ def tree_unflatten(cls, aux_data, children): @register_pytree_node_class class smail_nz(redshift_distribution): """Defines a smail distribution with these arguments - Parameters: - ----------- + Parameters + ---------- a: b: @@ -81,8 +81,8 @@ def pz_fn(self, z): @register_pytree_node_class class delta_nz(redshift_distribution): """Defines a single plane redshift distribution with these arguments - Parameters: - ----------- + Parameters + ---------- z0: """ @@ -102,13 +102,13 @@ class kde_nz(redshift_distribution): given catalog currently uses a Gaussian kernel. TODO: add more if necessary - Parameters: - ----------- + Parameters + ---------- zcat: redshift catalog weights: weight for each galaxy between 0 and 1 - Configuration: - -------------- + Configuration + ------------- bw: Bandwidth for the KDE Example: diff --git a/jax_cosmo/transfer.py b/jax_cosmo/transfer.py index 23def59..2728b85 100644 --- a/jax_cosmo/transfer.py +++ b/jax_cosmo/transfer.py @@ -1,7 +1,6 @@ # This module contains various transfer functions from the literatu import jax.numpy as np -import jax_cosmo.background as bkgrd import jax_cosmo.constants as const __all__ = ["Eisenstein_Hu"] diff --git a/setup.py b/setup.py index 5696aa6..8bd6a77 100644 --- a/setup.py +++ b/setup.py @@ -17,14 +17,16 @@ author="jax-cosmo developers", packages=find_packages(), url="https://github.com/DifferentiableUniverseInitiative/jax_cosmo", + python_requires=">=3.7", install_requires=["jax", "jaxlib"], tests_require=["pyccl"], use_scm_version=True, setup_requires=["setuptools_scm"], classifiers=[ - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "License :: OSI Approved :: MIT License", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", diff --git a/tests/test_background.py b/tests/test_background.py index 4f86a91..5f045e2 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -68,15 +68,18 @@ def test_distances_flat(): a = np.linspace(0.01, 1.0) chi_ccl = ccl.comoving_radial_distance(cosmo_ccl, a) - chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) chi_ccl = ccl.comoving_angular_distance(cosmo_ccl, a) - chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) chi_ccl = ccl.angular_diameter_distance(cosmo_ccl, a) - chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) @@ -108,7 +111,7 @@ def test_growth(): a = np.linspace(0.01, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) - gjax = bkgrd.growth_factor(cosmo_jax, a) + cosmo_jax, gjax = bkgrd.growth_factor(cosmo_jax, a) assert_allclose(gccl, gjax, rtol=1e-2) @@ -141,7 +144,7 @@ def test_growth_rate(): a = np.linspace(0.01, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) - fjax = bkgrd.growth_rate(cosmo_jax, a) + cosmo_jax, fjax = bkgrd.growth_rate(cosmo_jax, a) assert_allclose(fccl, fjax, rtol=1e-2) @@ -176,7 +179,7 @@ def test_growth_rate_gamma(): a = np.linspace(0.01, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) - fjax = bkgrd.growth_rate(cosmo_jax, a) + cosmo_jax, fjax = bkgrd.growth_rate(cosmo_jax, a) assert_allclose(fccl, fjax, rtol=1e-2) @@ -210,6 +213,6 @@ def test_growth_gamma(): a = np.linspace(0.01, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) - gjax = bkgrd.growth_factor(cosmo_jax, a) + cosmo_jax, gjax = bkgrd.growth_factor(cosmo_jax, a) assert_allclose(gccl, gjax, rtol=1e-2) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..95aeab4 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,62 @@ +from dataclasses import FrozenInstanceError + +from numpy.testing import assert_raises + +from jax_cosmo import Configuration +from jax_cosmo import Cosmology + + +def test_Cosmology_immutability(): + cosmo = Cosmology( + Omega_c=0.25, + Omega_b=0.05, + h=0.67, + sigma8=0.8, + n_s=0.96, + Omega_k=0.0, + w0=-1.0, + wa=0.0, + ) + + with assert_raises(FrozenInstanceError): + cosmo.h = 0.74 # Hubble doesn't budge on the tension + + +def test_Conguration_immutability(): + config = Configuration() + + with assert_raises(FrozenInstanceError): + config.log10_a_max = 0.0 + + +def test_cache(): + cosmo = Cosmology( + Omega_c=0.25, + Omega_b=0.05, + h=0.7, + sigma8=0.8, + n_s=0.96, + Omega_k=0.0, + w0=-1.0, + wa=0.0, + ) + + def assert_pure(c1, c2): + assert c1 is not c2 and c1 == c2 and c1._cache != c2._cache + + cosmo = cosmo.cache_set("a", 1) + cosmo = cosmo.cache_set("c", 3) + + assert cosmo.is_cached("a") + assert not cosmo.is_cached("b") + assert cosmo.cache_get("c") == 3 + + cosmo_add_b = cosmo.cache_set("b", 2) + cosmo_del_c = cosmo.cache_del("c") + cosmo_clean = cosmo.cache_clear() + + assert not cosmo_del_c.is_cached("c") + assert_pure(cosmo_add_b, cosmo) # test set purity + assert_pure(cosmo_del_c, cosmo) # test del purity + assert_pure(cosmo_clean, cosmo) # test clear purity + assert len(cosmo_clean._cache) == 0