From 4edca62b39781482ca4ccc7156c18887e2e45cfd Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 19 Jan 2022 11:49:16 -0500 Subject: [PATCH 01/13] Remove unecessary gamma_growth flag --- jax_cosmo/background.py | 14 +++++++------- jax_cosmo/core.py | 38 +++++--------------------------------- 2 files changed, 12 insertions(+), 40 deletions(-) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 1a2a182..71103e6 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -109,7 +109,7 @@ def Esqr(cosmo, a): ----- The Hubble parameter at scale factor `a` is given by - :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through + :math:`H^2(a) = E^2(a) H_0^2` where :math:`E^2` is obtained through Friedman's Equation (see :cite:`2005:Percival`) : .. math:: @@ -374,7 +374,7 @@ def growth_factor(cosmo, a): Parameters ---------- - cosmo: `Cosmology` + cosmo: Cosmology Cosmology object a: array_like @@ -392,7 +392,7 @@ def growth_factor(cosmo, a): assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for growth will be solved. """ - if cosmo._flags["gamma_growth"]: + if cosmo.gamma is not None: return _growth_factor_gamma(cosmo, a) else: return _growth_factor_ODE(cosmo, a) @@ -403,7 +403,7 @@ def growth_rate(cosmo, a): Parameters ---------- - cosmo: `Cosmology` + cosmo: Cosmology Cosmology object a: array_like @@ -434,7 +434,7 @@ def growth_rate(cosmo, a): see :cite:`2019:Euclid Preparation VII, eqn.32` """ - if cosmo._flags["gamma_growth"]: + if cosmo.gamma is not None: return _growth_rate_gamma(cosmo, a) else: return _growth_rate_ODE(cosmo, a) @@ -494,7 +494,7 @@ def _growth_rate_ODE(cosmo, a): Parameters ---------- - cosmo: `Cosmology` + cosmo: Cosmology Cosmology object a: array_like @@ -553,7 +553,7 @@ def _growth_rate_gamma(cosmo, a): Parameters ---------- - cosmo: `Cosmology` + cosmo: Cosmology Cosmology object a : array_like diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 4ff49d6..fbf8178 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -54,11 +54,8 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None self._w0 = w0 self._wa = wa - self._flags = {} - # Secondary optional parameters self._gamma = gamma - self._flags["gamma_growth"] = gamma is not None # Create a workspace where functions can store some precomputed # results @@ -97,7 +94,7 @@ def __repr__(self): # Operations for flattening/unflattening representation def tree_flatten(self): - params = ( + children = ( self._Omega_c, self._Omega_b, self._h, @@ -106,40 +103,15 @@ def tree_flatten(self): self._Omega_k, self._w0, self._wa, + self._gamma, ) + aux_data = None - if self._flags["gamma_growth"]: - params += (self._gamma,) - - return ( - params, - self._flags, - ) + return children, aux_data @classmethod def tree_unflatten(cls, aux_data, children): - # Retrieve base parameters - Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa = children[:8] - children = list(children[8:]).reverse() - - # We extract the remaining parameters in reverse order from how they - # were inserted - if aux_data["gamma_growth"]: - gamma = children.pop() - else: - gamma = None - - return cls( - Omega_c=Omega_c, - Omega_b=Omega_b, - h=h, - n_s=n_s, - sigma8=sigma8, - Omega_k=Omega_k, - w0=w0, - wa=wa, - gamma=gamma, - ) + return cls(*children) # Cosmological parameters, base and derived @property From 333476ef2130bdcf338ee95794f7e2d8f8539125 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 21 Jan 2022 12:10:40 -0500 Subject: [PATCH 02/13] Remove seemingly unecessary imports --- jax_cosmo/core.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index fbf8178..b8eb129 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,11 +1,6 @@ import jax.numpy as np -from jax.experimental.ode import odeint from jax.tree_util import register_pytree_node_class -import jax_cosmo.constants as const -from jax_cosmo.utils import a2z -from jax_cosmo.utils import z2a - __all__ = ["Cosmology"] From ae9ab992dd7f9c763258f996fdb2d433bc8beaf7 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Tue, 25 Jan 2022 20:23:18 -0500 Subject: [PATCH 03/13] Add pytree dataclass decorator --- jax_cosmo/dataclasses.py | 55 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 jax_cosmo/dataclasses.py diff --git a/jax_cosmo/dataclasses.py b/jax_cosmo/dataclasses.py new file mode 100644 index 0000000..a78d460 --- /dev/null +++ b/jax_cosmo/dataclasses.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import is_dataclass + +from jax.tree_util import register_pytree_node + + +def pytree_dataclass(cls, aux_fields=None, **kwargs): + """Register python dataclasses as custom pytree nodes. + + Parameters + ---------- + cls : type + Class to be registered, not a python dataclass yet. + aux_fields : str or Sequence[str], optional + Fields to be treated as pytree aux_data; unrecognized ones are ignored. + kwargs + Keyword arguments to be passed to python dataclass decorator. + + Returns + ------- + cls : type + Registered dataclass. + + Raises + ------ + TypeError + If cls is already a python dataclass. + + .. _Augmented dataclass for JAX pytree: + https://gist.github.com/odashi/813810a5bc06724ea3643456f8d3942d + + """ + if is_dataclass(cls): + raise TypeError("cls cannot already be a dataclass") + cls = dataclass(cls, **kwargs) + + if aux_fields is None: + aux_fields = [] + elif isinstance(aux_fields, str): + aux_fields = [aux_fields] + akeys = [field.name for field in fields(cls) if field.name in aux_fields] + ckeys = [field.name for field in fields(cls) if field.name not in aux_fields] + + def tree_flatten(obj): + children = [getattr(obj, key) for key in ckeys] + aux_data = [getattr(obj, key) for key in akeys] + return children, aux_data + + def tree_unflatten(aux_data, children): + return cls(**dict(zip(ckeys, children)), **dict(zip(akeys, aux_data))) + + register_pytree_node(cls, tree_flatten, tree_unflatten) + + return cls From d7cef13ca6f09ec66825f761e2d5918431d17ffc Mon Sep 17 00:00:00 2001 From: Yin Li Date: Tue, 25 Jan 2022 20:26:12 -0500 Subject: [PATCH 04/13] Refactor Cosmology as pytree dataclass This makes Cosmology semi-immutable, and allows cached results to survive through unflattening of JAX transformations. --- jax_cosmo/background.py | 26 ++--- jax_cosmo/core.py | 207 +++++++++++++--------------------------- tests/test_core.py | 21 ++++ 3 files changed, 98 insertions(+), 156 deletions(-) create mode 100644 tests/test_core.py diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 71103e6..d0b28df 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -220,7 +220,7 @@ def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): \chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)} """ # Check if distances have already been computed - if not "background.radial_comoving_distance" in cosmo._workspace.keys(): + if not "background.radial_comoving_distance" in cosmo._cache.keys(): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -233,9 +233,9 @@ def dchioverdlna(y, x): chitab = chitab[-1] - chitab cache = {"a": atab, "chi": chitab} - cosmo._workspace["background.radial_comoving_distance"] = cache + cosmo._cache["background.radial_comoving_distance"] = cache else: - cache = cosmo._workspace["background.radial_comoving_distance"] + cache = cosmo._cache["background.radial_comoving_distance"] a = np.atleast_1d(a) # Return the results as an interpolation of the table @@ -260,9 +260,9 @@ def a_of_chi(cosmo, chi): Scale factors corresponding to requested distances """ # Check if distances have already been computed, force computation otherwise - if not "background.radial_comoving_distance" in cosmo._workspace.keys(): + if not "background.radial_comoving_distance" in cosmo._cache.keys(): radial_comoving_distance(cosmo, 1.0) - cache = cosmo._workspace["background.radial_comoving_distance"] + cache = cosmo._cache["background.radial_comoving_distance"] chi = np.atleast_1d(chi) return interp(chi, cache["chi"], cache["a"]) @@ -458,7 +458,7 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): Growth factor computed at requested scale factor """ # Check if growth has already been computed - if not "background.growth_factor" in cosmo._workspace.keys(): + if not "background.growth_factor" in cosmo._cache.keys(): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -482,9 +482,9 @@ def D_derivs(y, x): ftab = y[:, 1] / y1[-1] * atab / gtab cache = {"a": atab, "g": gtab, "f": ftab} - cosmo._workspace["background.growth_factor"] = cache + cosmo._cache["background.growth_factor"] = cache else: - cache = cosmo._workspace["background.growth_factor"] + cache = cosmo._cache["background.growth_factor"] return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) @@ -506,9 +506,9 @@ def _growth_rate_ODE(cosmo, a): Growth rate computed at requested scale factor """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): + if not "background.growth_factor" in cosmo._cache.keys(): _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._workspace["background.growth_factor"] + cache = cosmo._cache["background.growth_factor"] return interp(a, cache["a"], cache["f"]) @@ -531,7 +531,7 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._workspace.keys(): + if not "background.growth_factor" in cosmo._cache.keys(): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -542,9 +542,9 @@ def integrand(y, loga): gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = gtab / gtab[-1] # Normalize to a=1. cache = {"a": atab, "g": gtab} - cosmo._workspace["background.growth_factor"] = cache + cosmo._cache["background.growth_factor"] = cache else: - cache = cosmo._workspace["background.growth_factor"] + cache = cosmo._cache["background.growth_factor"] return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index b8eb129..0fb13d9 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,166 +1,87 @@ +from dataclasses import field +from functools import partial +from pprint import pformat +from typing import Any +from typing import Optional + import jax.numpy as np -from jax.tree_util import register_pytree_node_class + +from jax_cosmo.dataclasses import pytree_dataclass __all__ = ["Cosmology"] -@register_pytree_node_class +@partial(pytree_dataclass, frozen=True) class Cosmology: - def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None): - """ - Cosmology object, stores primary and derived cosmological parameters. - - Parameters: - ----------- - Omega_c, float - Cold dark matter density fraction. - Omega_b, float - Baryonic matter density fraction. - h, float - Hubble constant divided by 100 km/s/Mpc; unitless. - n_s, float - Primordial scalar perturbation spectral index. - sigma8, float - Variance of matter density perturbations at an 8 Mpc/h scale - Omega_k, float - Curvature density fraction. - w0, float - First order term of dark energy equation - wa, float - Second order term of dark energy equation of state - gamma: float - Index of the growth rate (optional) - - Notes: - ------ - - If `gamma` is specified, the emprical characterisation of growth in - terms of dlnD/dlna = \omega^\gamma will be used to define growth throughout. - Otherwise the linear growth factor and growth rate will be solved by ODE. - - """ - # Store primary parameters - self._Omega_c = Omega_c - self._Omega_b = Omega_b - self._h = h - self._n_s = n_s - self._sigma8 = sigma8 - self._Omega_k = Omega_k - self._w0 = w0 - self._wa = wa - - # Secondary optional parameters - self._gamma = gamma - - # Create a workspace where functions can store some precomputed - # results - self._workspace = {} + """ + Cosmology parameter class, including primary, secondary, and derived parameters; immutable. + + Parameters: + ----------- + Omega_c : float + Cold dark matter density fraction. + Omega_b : float + Baryonic matter density fraction. + h : float + Hubble constant divided by 100 km/s/Mpc; unitless. + n_s : float + Primordial scalar perturbation spectral index. + sigma8 : float + RMS of matter density perturbations in an 8 Mpc/h spherical tophat. + Omega_k : float + Curvature density fraction. + w0 : float + First order term of dark energy equation. + wa : float + Second order term of dark energy equation of state. + gamma : float, optional + Exponent of growth rate fitting formula. + + Notes: + ------ + + If `gamma` is specified, the growth rate fitting formula + :math:`dlnD/dlna = \Omega_m(a)^\gamma` will be used to model the growth history. + Otherwise the linear growth factor and growth rate will be solved by ODE. + + """ + + # Primary parameters + Omega_c: float + Omega_b: float + h: float + n_s: float + sigma8: float + Omega_k: float + w0: float + wa: float + + # Secondary optional parameters + gamma: Optional[float] = None + + # cached intermediate computations + _cache: dict[str, Any] = field(default_factory=dict, repr=False, compare=False) def __str__(self): - return ( - "Cosmological parameters: \n" - + " h: " - + str(self.h) - + " \n" - + " Omega_b: " - + str(self.Omega_b) - + " \n" - + " Omega_c: " - + str(self.Omega_c) - + " \n" - + " Omega_k: " - + str(self.Omega_k) - + " \n" - + " w0: " - + str(self.w0) - + " \n" - + " wa: " - + str(self.wa) - + " \n" - + " n: " - + str(self.n_s) - + " \n" - + " sigma8: " - + str(self.sigma8) - ) - - def __repr__(self): - return self.__str__() - - # Operations for flattening/unflattening representation - def tree_flatten(self): - children = ( - self._Omega_c, - self._Omega_b, - self._h, - self._n_s, - self._sigma8, - self._Omega_k, - self._w0, - self._wa, - self._gamma, - ) - aux_data = None - - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - - # Cosmological parameters, base and derived - @property - def Omega(self): - return 1.0 - self._Omega_k - - @property - def Omega_b(self): - return self._Omega_b + return pformat(self, indent=4, width=1) # for python >= 3.10 + # Derived parameters @property - def Omega_c(self): - return self._Omega_c + def Omega(self): + return 1.0 - self.Omega_k @property def Omega_m(self): - return self._Omega_b + self._Omega_c + return self.Omega_b + self.Omega_c @property def Omega_de(self): return self.Omega - self.Omega_m - @property - def Omega_k(self): - return self._Omega_k - @property def k(self): - return -np.sign(self._Omega_k).astype(np.int8) + return -np.sign(self.Omega_k).astype(np.int8) @property def sqrtk(self): - return np.sqrt(np.abs(self._Omega_k)) - - @property - def h(self): - return self._h - - @property - def w0(self): - return self._w0 - - @property - def wa(self): - return self._wa - - @property - def n_s(self): - return self._n_s - - @property - def sigma8(self): - return self._sigma8 - - @property - def gamma(self): - return self._gamma + return np.sqrt(np.abs(self.Omega_k)) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..1b28566 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,21 @@ +from dataclasses import FrozenInstanceError + +from numpy.testing import assert_raises + +from jax_cosmo import Cosmology + + +def test_Cosmology_immutability(): + cosmo = Cosmology( + Omega_c=0.3, + Omega_b=0.05, + h=0.67, + sigma8=0.8, + n_s=0.96, + Omega_k=0.0, + w0=-1.0, + wa=0.0, + ) + + with assert_raises(FrozenInstanceError): + cosmo.h = 0.74 # Hubble doesn't budge on the tension From 7c848b29d36a477f2468558ea6ed9b64418d53da Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 26 Jan 2022 13:53:01 -0500 Subject: [PATCH 05/13] Fix side effects due to Cosmology cache --- jax_cosmo/background.py | 365 +++++++++++++++++++++++---------------- jax_cosmo/core.py | 38 ++-- tests/test_background.py | 17 +- 3 files changed, 251 insertions(+), 169 deletions(-) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index d0b28df..dbc1c5d 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -1,4 +1,7 @@ -# This module implements various functions for the background COSMOLOGY +"""This module implements various functions for the cosmological background +and linear perturbations. + +""" import jax.numpy as np from jax import lax @@ -14,6 +17,7 @@ "Omega_m_a", "Omega_de_a", "radial_comoving_distance", + "a_of_chi", "dchioverda", "transverse_comoving_distance", "angular_diameter_distance", @@ -23,55 +27,53 @@ def w(cosmo, a): - r"""Dark Energy equation of state parameter using the Linder + r"""Dark energy equation of state parameter using the Linder parametrisation. Parameters ---------- - cosmo: Cosmology - Cosmological parameters structure - + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- w : ndarray, or float if input scalar - The Dark Energy equation of state parameter at the specified - scale factor + Dark energy equation of state parameters at specified scale factors. Notes ----- - - The Linder parametrization :cite:`2003:Linder` for the Dark Energy + The Linder parametrization :cite:`2003:Linder` for the dark energy equation of state :math:`p = w \rho` is given by: .. math:: w(a) = w_0 + w_a (1 - a) + """ return cosmo.w0 + (1.0 - a) * cosmo.wa # Equation (6) in Linder (2003) def f_de(cosmo, a): - r"""Evolution parameter for the Dark Energy density. + r"""Evolution parameter for the dark energy density. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- f : ndarray, or float if input scalar - The evolution parameter of the Dark Energy density as a function - of scale factor + Dark energy density evolution parameters at specified scale factors. Notes ----- - - For a given parametrisation of the Dark Energy equation of state, - the scaling of the Dark Energy density with time can be written as: + For a given parametrisation of the dark energy equation of state, + the scaling of the dark energy density with time can be written as: .. math:: @@ -86,28 +88,28 @@ def f_de(cosmo, a): .. math:: f(a) = -3 (1 + w_0 + w_a) \ln(a) + 3 w_a (a - 1) + """ return -3.0 * (1.0 + cosmo.w0 + cosmo.wa) * np.log(a) + 3.0 * cosmo.wa * (a - 1.0) def Esqr(cosmo, a): - r"""Square of the scale factor dependent factor E(a) in the Hubble - parameter. + r"""Squared time scaling factors E(a) of the Hubble expansion. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- E^2 : ndarray, or float if input scalar - Square of the scaling of the Hubble constant as a function of - scale factor + Squared scaling of the Hubble expansion at specified scale factors. Notes ----- - The Hubble parameter at scale factor `a` is given by :math:`H^2(a) = E^2(a) H_0^2` where :math:`E^2` is obtained through Friedman's Equation (see :cite:`2005:Percival`) : @@ -116,8 +118,9 @@ def Esqr(cosmo, a): E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} e^{f(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed + where :math:`f(a)` is the dark energy evolution parameter computed by :py:meth:`.f_de`. + """ return ( cosmo.Omega_m * np.power(a, -3) @@ -127,33 +130,38 @@ def Esqr(cosmo, a): def H(cosmo, a): - r"""Hubble parameter [km/s/(Mpc/h)] at scale factor `a` + r"""Hubble expansion rate [km/s/(Mpc/h)] at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- H : ndarray, or float if input scalar - Hubble parameter at the requested scale factor. + Hubble parameters at specified scale factors. + """ return const.H0 * np.sqrt(Esqr(cosmo, a)) def Omega_m_a(cosmo, a): - r"""Matter density at scale factor `a`. + r"""Matter density at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- Omega_m : ndarray, or float if input scalar - Non-relativistic matter density at the requested scale factor + Non-relativistic matter density at specified scale factors. Notes ----- @@ -164,63 +172,72 @@ def Omega_m_a(cosmo, a): \Omega_m(a) = \frac{\Omega_m a^{-3}}{E^2(a)} see :cite:`2005:Percival` Eq. (6) + """ return cosmo.Omega_m * np.power(a, -3) / Esqr(cosmo, a) def Omega_de_a(cosmo, a): - r"""Dark Energy density at scale factor `a`. + r"""Dark energy density at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- Omega_de : ndarray, or float if input scalar - Dark Energy density at the requested scale factor + Dark energy density at specified scale factors. Notes ----- - The evolution of Dark Energy density :math:`\Omega_{de}(a)` is given + The evolution of dark energy density :math:`\Omega_{de}(a)` is given by: .. math:: \Omega_{de}(a) = \frac{\Omega_{de} e^{f(a)}}{E^2(a)} - where :math:`f(a)` is the Dark Energy evolution parameter computed by + where :math:`f(a)` is the dark energy evolution parameter computed by :py:meth:`.f_de` (see :cite:`2005:Percival` Eq. (6)). + """ return cosmo.Omega_de * np.exp(f_de(cosmo, a)) / Esqr(cosmo, a) def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): - r"""Radial comoving distance in [Mpc/h] for a given scale factor. + r"""Radial comoving distances in [Mpc/h] at given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. chi : ndarray, or float if input scalar - Radial comoving distance corresponding to the specified scale - factor. + Radial comoving distances at specified scale factors. Notes ----- - The radial comoving distance is computed by performing the following + The radial comoving distances is computed by performing the following integration: .. math:: \chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)} + """ # Check if distances have already been computed - if not "background.radial_comoving_distance" in cosmo._cache.keys(): + key = "background.radial_comoving_distance" + if not cosmo.is_cached(key): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -232,81 +249,92 @@ def dchioverdlna(y, x): # np.clip(- 3000*np.log(atab), 0, 10000)#odeint(dchioverdlna, 0., np.log(atab), cosmo) chitab = chitab[-1] - chitab - cache = {"a": atab, "chi": chitab} - cosmo._cache["background.radial_comoving_distance"] = cache + value = {"a": atab, "chi": chitab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._cache["background.radial_comoving_distance"] + value = cosmo.cache_get(key) a = np.atleast_1d(a) # Return the results as an interpolation of the table - return np.clip(interp(a, cache["a"], cache["chi"]), 0.0) + chi = np.clip(interp(a, value["a"], value["chi"]), 0.0) + return cosmo, chi def a_of_chi(cosmo, chi): - r"""Computes the scale factor for corresponding (array) of radial comoving - distance by reverse linear interpolation. + r"""Computes the scale factors at given radial comoving distances by + reverse linear interpolation. Parameters: ----------- - cosmo: Cosmology - Cosmological parameters - - chi: array-like - radial comoving distance to query. + cosmo : Cosmology + Cosmological parameters. + chi : array-like + Radial comoving distances to query. Returns: -------- + cosmo : Cosmology + Cosmological parameters with cached computations. a : array-like - Scale factors corresponding to requested distances + Scale factors at specified distances. + """ # Check if distances have already been computed, force computation otherwise - if not "background.radial_comoving_distance" in cosmo._cache.keys(): - radial_comoving_distance(cosmo, 1.0) - cache = cosmo._cache["background.radial_comoving_distance"] + key = "background.radial_comoving_distance" + if not cosmo.is_cached(key): + cosmo, _ = radial_comoving_distance(cosmo, 1.0) + value = cosmo.cache_get(key) + chi = np.atleast_1d(chi) - return interp(chi, cache["chi"], cache["a"]) + a = interp(chi, value["chi"], value["a"]) + return cosmo, a def dchioverda(cosmo, a): - r"""Derivative of the radial comoving distance with respect to the - scale factor. + r"""Derivative of the radial comoving distances with respect to the + scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- - dchi/da : ndarray, or float if input scalar - Derivative of the radial comoving distance with respect to the - scale factor at the specified scale factor. + dchi/da : ndarray, or float if input scalar + Derivative of the radial comoving distances with respect to the + scale factors at specified scale factors. Notes ----- - The expression for :math:`\frac{d \chi}{da}` is: .. math:: \frac{d \chi}{da}(a) = \frac{R_H}{a^2 E(a)} + """ return const.rh / (a ** 2 * np.sqrt(Esqr(cosmo, a))) def transverse_comoving_distance(cosmo, a): - r"""Transverse comoving distance in [Mpc/h] for a given scale factor. + r"""Transverse comoving distances in [Mpc/h] for given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. f_k : ndarray, or float if input scalar - Transverse comoving distance corresponding to the specified - scale factor. + Transverse comoving distances at specified scale factors. Notes ----- @@ -325,8 +353,8 @@ def transverse_comoving_distance(cosmo, a): \mbox{for } \Omega_k < 0 \end{matrix} \right. + """ - index = cosmo.k + 1 def open_universe(chi): return const.rh / cosmo.sqrtk * np.sinh(cosmo.sqrtk * chi / const.rh) @@ -339,22 +367,28 @@ def close_universe(chi): branches = (open_universe, flat_universe, close_universe) - chi = radial_comoving_distance(cosmo, a) + cosmo, chi = radial_comoving_distance(cosmo, a) - return lax.switch(cosmo.k + 1, branches, chi) + f_k = lax.switch(cosmo.k + 1, branches, chi) + return cosmo, f_k def angular_diameter_distance(cosmo, a): - r"""Angular diameter distance in [Mpc/h] for a given scale factor. + r"""Angular diameter distances in [Mpc/h] for given scale factors. Parameters ---------- + cosmo : Cosmology + Cosmological parameters. a : array_like - Scale factor + Scale factors. Returns ------- + cosmo : Cosmology + Cosmological parameters with cached computations. d_A : ndarray, or float if input scalar + Angular diameter distances at specified scale factors. Notes ----- @@ -364,62 +398,69 @@ def angular_diameter_distance(cosmo, a): .. math:: d_A(a) = a f_k(a) + """ - return a * transverse_comoving_distance(cosmo, a) + cosmo, f_k = transverse_comoving_distance(cosmo, a) + d_A = a * f_k + return cosmo, d_A def growth_factor(cosmo, a): - """Compute linear growth factor D(a) at a given scale factor, - normalized such that D(a=1) = 1. + r"""Compute linear growth factors :math:`D(a)` at given scale factors, + normalized such that :math:`D(a=1) = 1`. Parameters ---------- - cosmo: Cosmology - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. Notes ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. + The growth computation depends on the cosmology parametrization: + if the :math:`\gamma` parameter is defined, the growth history is computed + assuming the growth rate :math:`f = \Omega_m(a)^\gamma`, otherwise the + usual ODE for growth will be solved. + """ if cosmo.gamma is not None: - return _growth_factor_gamma(cosmo, a) + cosmo, D = _growth_factor_gamma(cosmo, a) else: - return _growth_factor_ODE(cosmo, a) + cosmo, D = _growth_factor_ODE(cosmo, a) + return cosmo, D def growth_rate(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor. + r"""Compute growth rates :math:`dD/d\ln\ a` at given scale factors. Parameters ---------- - cosmo: Cosmology - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + f : ndarray, or float if input scalar + Growth rate at specified scale factors. Notes ----- - The growth computation will depend on the cosmology parametrization, for - instance if the $\gamma$ parameter is defined, the growth will be computed - assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for - growth will be solved. + The growth computation depends on the cosmology parametrization: + if the :math:`\gamma` parameter is defined, the growth history is computed + assuming the growth rate :math:`f = \Omega_m(a)^\gamma`, otherwise the + usual ODE for growth will be solved. The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: @@ -427,38 +468,45 @@ def growth_rate(cosmo, a): f_{\gamma}(a) = \Omega_m^{\gamma} (a) - with :math: `\gamma` in LCDM, given approximately by: + with :math:`\gamma` in LCDM, given approximately by: .. math:: \gamma = 0.55 see :cite:`2019:Euclid Preparation VII, eqn.32` + """ if cosmo.gamma is not None: - return _growth_rate_gamma(cosmo, a) + f = _growth_rate_gamma(cosmo, a) else: - return _growth_rate_ODE(cosmo, a) + cosmo, f = _growth_rate_ODE(cosmo, a) + return cosmo, f def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): - """Compute linear growth factor D(a) at a given scale factor, - normalised such that D(a=1) = 1. + r"""Compute linear growth factors :math:`D(a)` at given scale factors, + normalised such that :math:`D(a=1) = 1`. Parameters ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 + cosmo : Cosmology + Cosmological parameters. + a : array_like + Scale factors. + amin : float + Mininum scale factor, default 1e-3. Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. + """ # Check if growth has already been computed - if not "background.growth_factor" in cosmo._cache.keys(): + key = "background.growth_factor" + if not cosmo.is_cached(key): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -481,57 +529,68 @@ def D_derivs(y, x): # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da ftab = y[:, 1] / y1[-1] * atab / gtab - cache = {"a": atab, "g": gtab, "f": ftab} - cosmo._cache["background.growth_factor"] = cache + value = {"a": atab, "g": gtab, "f": ftab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._cache["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + value = cosmo.cache_get(key) + + D = np.clip(interp(a, value["a"], value["g"]), 0.0, 1.0) + return cosmo, D def _growth_rate_ODE(cosmo, a): - """Compute growth rate dD/dlna at a given scale factor by solving the linear - growth ODE. + r"""Compute growth rates :math:`dD/d\ln\ a` at given scale factors by + solving the linear growth ODE. Parameters ---------- - cosmo: Cosmology - Cosmology object - - a: array_like - Scale factor + cosmo : Cosmology + Cosmology parameters. + a : array_like + Scale factors. Returns ------- - f: ndarray, or float if input scalar - Growth rate computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + f : ndarray, or float if input scalar + Growth rates at specified scale factors. + """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._cache.keys(): - _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) - cache = cosmo._cache["background.growth_factor"] - return interp(a, cache["a"], cache["f"]) + key = "background.growth_factor" + if not cosmo.is_cached(key): + cosmo, _ = _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + value = cosmo.cache_get(key) + + f = interp(a, value["a"], value["f"]) + return cosmo, f def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): - r"""Computes growth factor by integrating the growth rate provided by the - \gamma parametrization. Normalized such that D( a=1) =1 + r"""Growth factors by integrating the :math:`\gamma`-parametrized growth + rates, normalized such that :math:`D(a=1) = 1`. Parameters ---------- - a: array_like - Scale factor - - amin: float - Mininum scale factor, default 1e-3 + cosmo : Cosmology + Cosmological parameters. + a : array_like + Scale factors. + amin : float + Mininum scale factor, default 1e-3 Returns ------- - D: ndarray, or float if input scalar - Growth factor computed at requested scale factor + cosmo : Cosmology + Cosmological parameters with cached computations. + D : ndarray, or float if input scalar + Growth factors at specified scale factors. """ # Check if growth has already been computed, if not, compute it - if not "background.growth_factor" in cosmo._cache.keys(): + key = "background.growth_factor" + if not cosmo.is_cached(key): # Compute tabulated array atab = np.logspace(log10_amin, 0.0, steps) @@ -541,28 +600,29 @@ def integrand(y, loga): gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = gtab / gtab[-1] # Normalize to a=1. - cache = {"a": atab, "g": gtab} - cosmo._cache["background.growth_factor"] = cache + value = {"a": atab, "g": gtab} + cosmo = cosmo.cache_set(key, value) else: - cache = cosmo._cache["background.growth_factor"] - return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + value = cosmo.cache_get(key) + + D = np.clip(interp(a, value["a"], value["g"]), 0.0, 1.0) + return cosmo, D def _growth_rate_gamma(cosmo, a): - r"""Growth rate approximation at scale factor `a`. + r"""Growth rate approximation at given scale factors. Parameters ---------- - cosmo: Cosmology - Cosmology object - + cosmo : Cosmology + Cosmology parameters. a : array_like - Scale factor + Scale factors. Returns ------- f_gamma : ndarray, or float if input scalar - Growth rate approximation at the requested scale factor + Growth rate approximation at specified scale factors. Notes ----- @@ -572,11 +632,12 @@ def _growth_rate_gamma(cosmo, a): f_{\gamma}(a) = \Omega_m^{\gamma} (a) - with :math: `\gamma` in LCDM, given approximately by: + with :math:`\gamma` in LCDM, given approximately by: .. math:: \gamma = 0.55 see :cite:`2019:Euclid Preparation VII, eqn.32` + """ return Omega_m_a(cosmo, a) ** cosmo.gamma diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 0fb13d9..26e08f1 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,4 +1,5 @@ from dataclasses import field +from dataclasses import replace from functools import partial from pprint import pformat from typing import Any @@ -19,23 +20,23 @@ class Cosmology: Parameters: ----------- Omega_c : float - Cold dark matter density fraction. + Cold dark matter density fraction. Omega_b : float - Baryonic matter density fraction. + Baryonic matter density fraction. h : float - Hubble constant divided by 100 km/s/Mpc; unitless. + Hubble constant divided by 100 km/s/Mpc; unitless. n_s : float - Primordial scalar perturbation spectral index. + Primordial scalar perturbation spectral index. sigma8 : float - RMS of matter density perturbations in an 8 Mpc/h spherical tophat. + RMS of matter density perturbations in an 8 Mpc/h spherical tophat. Omega_k : float - Curvature density fraction. + Curvature density fraction. w0 : float - First order term of dark energy equation. + First order term of dark energy equation. wa : float - Second order term of dark energy equation of state. + Second order term of dark energy equation of state. gamma : float, optional - Exponent of growth rate fitting formula. + Exponent of growth rate fitting formula. Notes: ------ @@ -59,12 +60,29 @@ class Cosmology: # Secondary optional parameters gamma: Optional[float] = None - # cached intermediate computations + # cache for intermediate computations; + # users should not access it directly but use the class methods instead _cache: dict[str, Any] = field(default_factory=dict, repr=False, compare=False) def __str__(self): return pformat(self, indent=4, width=1) # for python >= 3.10 + def is_cached(self, key): + return key in self._cache + + def cache_get(self, key): + return self._cache[key] + + def cache_set(self, key, value): + cache = self._cache + cache[key] = value + return replace(self, _cache=cache) + + def cache_clear(self): + cache = self._cache + cache.clear() + return replace(self, _cache=cache) + # Derived parameters @property def Omega(self): diff --git a/tests/test_background.py b/tests/test_background.py index 4f86a91..5f045e2 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -68,15 +68,18 @@ def test_distances_flat(): a = np.linspace(0.01, 1.0) chi_ccl = ccl.comoving_radial_distance(cosmo_ccl, a) - chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.radial_comoving_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) chi_ccl = ccl.comoving_angular_distance(cosmo_ccl, a) - chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.transverse_comoving_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) chi_ccl = ccl.angular_diameter_distance(cosmo_ccl, a) - chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) / cosmo_jax.h + cosmo_jax, chi_jax = bkgrd.angular_diameter_distance(cosmo_jax, a) + chi_jax = chi_jax / cosmo_jax.h assert_allclose(chi_ccl, chi_jax, rtol=0.5e-2) @@ -108,7 +111,7 @@ def test_growth(): a = np.linspace(0.01, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) - gjax = bkgrd.growth_factor(cosmo_jax, a) + cosmo_jax, gjax = bkgrd.growth_factor(cosmo_jax, a) assert_allclose(gccl, gjax, rtol=1e-2) @@ -141,7 +144,7 @@ def test_growth_rate(): a = np.linspace(0.01, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) - fjax = bkgrd.growth_rate(cosmo_jax, a) + cosmo_jax, fjax = bkgrd.growth_rate(cosmo_jax, a) assert_allclose(fccl, fjax, rtol=1e-2) @@ -176,7 +179,7 @@ def test_growth_rate_gamma(): a = np.linspace(0.01, 1.0) fccl = ccl.growth_rate(cosmo_ccl, a) - fjax = bkgrd.growth_rate(cosmo_jax, a) + cosmo_jax, fjax = bkgrd.growth_rate(cosmo_jax, a) assert_allclose(fccl, fjax, rtol=1e-2) @@ -210,6 +213,6 @@ def test_growth_gamma(): a = np.linspace(0.01, 1.0) gccl = ccl.growth_factor(cosmo_ccl, a) - gjax = bkgrd.growth_factor(cosmo_jax, a) + cosmo_jax, gjax = bkgrd.growth_factor(cosmo_jax, a) assert_allclose(gccl, gjax, rtol=1e-2) From 12fb76fbbfec4b0c4c185755b4c225233e6ccb6a Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 11:19:26 -0500 Subject: [PATCH 06/13] Fix cache dict side effects --- jax_cosmo/core.py | 14 ++++++++++---- tests/test_core.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 26e08f1..b59433a 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -74,15 +74,21 @@ def cache_get(self, key): return self._cache[key] def cache_set(self, key, value): - cache = self._cache + """Add key-value pair to cache and return a new ``Cosmology`` instance.""" + cache = self._cache.copy() cache[key] = value return replace(self, _cache=cache) - def cache_clear(self): - cache = self._cache - cache.clear() + def cache_del(self, key): + """Remove key from cache and return a new ``Cosmology`` instance.""" + cache = self._cache.copy() + del cache[key] return replace(self, _cache=cache) + def cache_clear(self): + """Return a new ``Cosmology`` instance with empty cache.""" + return replace(self, _cache={}) + # Derived parameters @property def Omega(self): diff --git a/tests/test_core.py b/tests/test_core.py index 1b28566..ddf08ea 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -7,7 +7,7 @@ def test_Cosmology_immutability(): cosmo = Cosmology( - Omega_c=0.3, + Omega_c=0.25, Omega_b=0.05, h=0.67, sigma8=0.8, @@ -19,3 +19,35 @@ def test_Cosmology_immutability(): with assert_raises(FrozenInstanceError): cosmo.h = 0.74 # Hubble doesn't budge on the tension + + +def test_cache(): + cosmo = Cosmology( + Omega_c=0.25, + Omega_b=0.05, + h=0.7, + sigma8=0.8, + n_s=0.96, + Omega_k=0.0, + w0=-1.0, + wa=0.0, + ) + + cosmo = cosmo.cache_set("a", 1) + cosmo = cosmo.cache_set("c", 3) + assert cosmo.is_cached("a") + assert not cosmo.is_cached("b") + assert cosmo.cache_get("c") == 3 + + cosmo = cosmo.cache_set("b", 2) + cosmo_no_c = cosmo.cache_del("c") + assert cosmo.is_cached("b") + assert not cosmo_no_c.is_cached("c") + assert ( + cosmo_no_c is not cosmo + and cosmo_no_c == cosmo + and cosmo_no_c._cache != cosmo._cache + ) + + cosmo = cosmo.cache_clear() + assert len(cosmo._cache) == 0 From efd6fcb71f253901c8ec8a87ff07d706364d9ecc Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 12:40:12 -0500 Subject: [PATCH 07/13] Add Configuration type, as part of Cosmology --- jax_cosmo/background.py | 28 ++++++++++++++-------- jax_cosmo/core.py | 51 ++++++++++++++++++++++++++++++++++++++--- jax_cosmo/power.py | 10 ++++---- 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index dbc1c5d..fc9cee2 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -208,7 +208,7 @@ def Omega_de_a(cosmo, a): return cosmo.Omega_de * np.exp(f_de(cosmo, a)) / Esqr(cosmo, a) -def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): +def radial_comoving_distance(cosmo, a): r"""Radial comoving distances in [Mpc/h] at given scale factors. Parameters @@ -239,7 +239,11 @@ def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256): key = "background.radial_comoving_distance" if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def dchioverdlna(y, x): xa = np.exp(x) @@ -483,7 +487,7 @@ def growth_rate(cosmo, a): return cosmo, f -def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): +def _growth_factor_ODE(cosmo, a): r"""Compute linear growth factors :math:`D(a)` at given scale factors, normalised such that :math:`D(a=1) = 1`. @@ -493,8 +497,6 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): Cosmological parameters. a : array_like Scale factors. - amin : float - Mininum scale factor, default 1e-3. Returns ------- @@ -508,7 +510,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): key = "background.growth_factor" if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def D_derivs(y, x): q = ( @@ -567,7 +573,7 @@ def _growth_rate_ODE(cosmo, a): return cosmo, f -def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): +def _growth_factor_gamma(cosmo, a): r"""Growth factors by integrating the :math:`\gamma`-parametrized growth rates, normalized such that :math:`D(a=1) = 1`. @@ -577,8 +583,6 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): Cosmological parameters. a : array_like Scale factors. - amin : float - Mininum scale factor, default 1e-3 Returns ------- @@ -592,7 +596,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): key = "background.growth_factor" if not cosmo.is_cached(key): # Compute tabulated array - atab = np.logspace(log10_amin, 0.0, steps) + atab = np.logspace( + cosmo.config.log10_a_min, + cosmo.config.log10_a_max, + cosmo.config.log10_a_steps, + ) def integrand(y, loga): xa = np.exp(loga) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index b59433a..2e9713b 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -1,3 +1,8 @@ +"""This module implements the Cosmology type containing cosmological parameters +and cached computations, and the Configuration type containing configuration parameters. + +""" +from dataclasses import dataclass from dataclasses import field from dataclasses import replace from functools import partial @@ -9,13 +14,48 @@ from jax_cosmo.dataclasses import pytree_dataclass -__all__ = ["Cosmology"] +__all__ = ["Cosmology", "Configuration"] + + +@dataclass(frozen=True) +class Configuration: + """Configuration parameters, that are not to be traced by JAX. + Parameters: + ----------- + log10_a_min : float, optional + Minimum for scale factor logspace range + log10_a_max : float, optional + Maximum for scale factor logspace range + log10_a_num : int, optional + Number of samples in scale factor logspace range + growth_rtol : float, optional + Relative error tolerance for solving growth ODEs + growth_atol : float, optional + Absolute error tolerance for solving growth ODEs + + log10_k_min : float, optional + Minimum for wavenumber logspace range + log10_k_max : float, optional + Maximum for wavenumber logspace range -@partial(pytree_dataclass, frozen=True) + """ + + log10_a_min: float = -3.0 + log10_a_max: float = 0.0 + log10_a_steps: int = 256 # TODO revisit after improving odeint and interpolation + growth_atol: float = 0.0 + growth_rtol: float = 1e-4 + + log10_k_min: float = -4.0 + log10_k_max: float = 3.0 + + +@partial(pytree_dataclass, aux_fields="config", frozen=True) class Cosmology: """ - Cosmology parameter class, including primary, secondary, and derived parameters; immutable. + Cosmology parameter type, containing primary, secondary, derived parameters, + cached computations, and configurations; immutable as a frozen dataclass. Parameters: ----------- @@ -37,6 +77,8 @@ class Cosmology: Second order term of dark energy equation of state. gamma : float, optional Exponent of growth rate fitting formula. + config : Configuration, optional + Configuration parameters. Notes: ------ @@ -64,6 +106,9 @@ class Cosmology: # users should not access it directly but use the class methods instead _cache: dict[str, Any] = field(default_factory=dict, repr=False, compare=False) + # configuration parameters, immutable (frozen dataclass) + config: Configuration = field(default_factory=Configuration) + def __str__(self): return pformat(self, indent=4, width=1) # for python >= 3.10 diff --git a/jax_cosmo/power.py b/jax_cosmo/power.py index 4d28cb1..0923fe2 100644 --- a/jax_cosmo/power.py +++ b/jax_cosmo/power.py @@ -54,18 +54,18 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar return pk.squeeze() -def sigmasqr(cosmo, R, transfer_fn, kmin=0.0001, kmax=1000.0, ksteps=5, **kwargs): - """Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc +def sigmasqr(cosmo, R, transfer_fn, **kwargs): + r"""Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc .. math:: - \\sigma^2(R)= \\frac{1}{2 \\pi^2} \\int_0^\\infty \\frac{dk}{k} k^3 P(k,z) W^2(kR) + \sigma^2(R)= \frac{1}{2 \pi^2} \int_0^\infty \frac{dk}{k} k^3 P(k,z) W^2(kR) where .. math:: - W(kR) = \\frac{3j_1(kR)}{kR} + W(kR) = \frac{3j_1(kR)}{kR} """ def int_sigma(logk): @@ -75,7 +75,7 @@ def int_sigma(logk): pk = transfer_fn(cosmo, k, **kwargs) ** 2 * primordial_matter_power(cosmo, k) return k * (k * w) ** 2 * pk - y = romb(int_sigma, np.log10(kmin), np.log10(kmax), divmax=7) + y = romb(int_sigma, cosmo.config.log10_k_min, cosmo.config.log10_k_max, divmax=7) return 1.0 / (2.0 * np.pi ** 2.0) * y From 00565355acd2b85171741fec646a612ea4c68374 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 13:43:06 -0500 Subject: [PATCH 08/13] Add cache and configuration tests --- tests/test_core.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index ddf08ea..95aeab4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,6 +2,7 @@ from numpy.testing import assert_raises +from jax_cosmo import Configuration from jax_cosmo import Cosmology @@ -21,6 +22,13 @@ def test_Cosmology_immutability(): cosmo.h = 0.74 # Hubble doesn't budge on the tension +def test_Conguration_immutability(): + config = Configuration() + + with assert_raises(FrozenInstanceError): + config.log10_a_max = 0.0 + + def test_cache(): cosmo = Cosmology( Omega_c=0.25, @@ -33,21 +41,22 @@ def test_cache(): wa=0.0, ) + def assert_pure(c1, c2): + assert c1 is not c2 and c1 == c2 and c1._cache != c2._cache + cosmo = cosmo.cache_set("a", 1) cosmo = cosmo.cache_set("c", 3) + assert cosmo.is_cached("a") assert not cosmo.is_cached("b") assert cosmo.cache_get("c") == 3 - cosmo = cosmo.cache_set("b", 2) - cosmo_no_c = cosmo.cache_del("c") - assert cosmo.is_cached("b") - assert not cosmo_no_c.is_cached("c") - assert ( - cosmo_no_c is not cosmo - and cosmo_no_c == cosmo - and cosmo_no_c._cache != cosmo._cache - ) + cosmo_add_b = cosmo.cache_set("b", 2) + cosmo_del_c = cosmo.cache_del("c") + cosmo_clean = cosmo.cache_clear() - cosmo = cosmo.cache_clear() - assert len(cosmo._cache) == 0 + assert not cosmo_del_c.is_cached("c") + assert_pure(cosmo_add_b, cosmo) # test set purity + assert_pure(cosmo_del_c, cosmo) # test del purity + assert_pure(cosmo_clean, cosmo) # test clear purity + assert len(cosmo_clean._cache) == 0 From b2f3c042f59be28835db2b7051f9049fa521a83c Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 14:56:52 -0500 Subject: [PATCH 09/13] Fix func calls superficially to pass tests --- jax_cosmo/angular_cl.py | 2 +- jax_cosmo/background.py | 8 ++++---- jax_cosmo/bias.py | 15 ++++++++------- jax_cosmo/core.py | 20 ++++++++++---------- jax_cosmo/power.py | 7 ++++--- jax_cosmo/probes.py | 38 +++++++++++++++++++------------------- jax_cosmo/redshift.py | 16 ++++++++-------- jax_cosmo/transfer.py | 1 - 8 files changed, 54 insertions(+), 53 deletions(-) diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index 6db6c9f..3cfe77c 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -70,7 +70,7 @@ def angular_cl( def cl(ell): def integrand(a): # Step 1: retrieve the associated comoving distance - chi = bkgrd.radial_comoving_distance(cosmo, a) + _, chi = bkgrd.radial_comoving_distance(cosmo, a) # Step 2: get the power spectrum for this combination of chi and a k = (ell + 0.5) / np.clip(chi, 1.0) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index fc9cee2..83ae8bd 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -268,15 +268,15 @@ def a_of_chi(cosmo, chi): r"""Computes the scale factors at given radial comoving distances by reverse linear interpolation. - Parameters: - ----------- + Parameters + ---------- cosmo : Cosmology Cosmological parameters. chi : array-like Radial comoving distances to query. - Returns: - -------- + Returns + ------- cosmo : Cosmology Cosmological parameters with cached computations. a : array-like diff --git a/jax_cosmo/bias.py b/jax_cosmo/bias.py index 6887ce8..fadd7b2 100644 --- a/jax_cosmo/bias.py +++ b/jax_cosmo/bias.py @@ -13,8 +13,8 @@ class constant_linear_bias(container): """ Class representing a linear bias - Parameters: - ----------- + Parameters + ---------- b: redshift independent bias value """ @@ -29,15 +29,16 @@ class inverse_growth_linear_bias(container): TODO: what's a better name for this? Class representing an inverse bias in 1/growth(a) - Parameters: - ----------- + Parameters + ---------- cosmo: cosmology b: redshift independent bias value at z=0 """ def __call__(self, cosmo, z): b = self.params[0] - return b / bkgrd.growth_factor(cosmo, z2a(z)) + _, D = bkgrd.growth_factor(cosmo, z2a(z)) + return b / D @register_pytree_node_class @@ -45,8 +46,8 @@ class des_y1_ia_bias(container): """ https://arxiv.org/pdf/1708.01538.pdf Sec. VII.B - Parameters: - ----------- + Parameters + ---------- cosmo: cosmology A: amplitude eta: redshift dependent slope diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 2e9713b..22ea304 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -21,12 +21,12 @@ class Configuration: """Configuration parameters, that are not to be traced by JAX. - Parameters: - ----------- + Parameters + ---------- log10_a_min : float, optional - Minimum for scale factor logspace range + Minimum in scale factor logspace range log10_a_max : float, optional - Maximum for scale factor logspace range + Maximum in scale factor logspace range log10_a_num : int, optional Number of samples in scale factor logspace range growth_rtol : float, optional @@ -35,9 +35,9 @@ class Configuration: Absolute error tolerance for solving growth ODEs log10_k_min : float, optional - Minimum for wavenumber logspace range + Minimum in wavenumber logspace range log10_k_max : float, optional - Maximum for wavenumber logspace range + Maximum in wavenumber logspace range """ @@ -57,8 +57,8 @@ class Cosmology: Cosmology parameter type, containing primary, secondary, derived parameters, cached computations, and configurations; immutable as a frozen dataclass. - Parameters: - ----------- + Parameters + ---------- Omega_c : float Cold dark matter density fraction. Omega_b : float @@ -80,8 +80,8 @@ class Cosmology: config : Configuration, optional Configuration parameters. - Notes: - ------ + Notes + ----- If `gamma` is specified, the growth rate fitting formula :math:`dlnD/dlna = \Omega_m(a)^\gamma` will be used to model the growth history. diff --git a/jax_cosmo/power.py b/jax_cosmo/power.py index 0923fe2..16b9aac 100644 --- a/jax_cosmo/power.py +++ b/jax_cosmo/power.py @@ -42,7 +42,7 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar """ k = np.atleast_1d(k) a = np.atleast_1d(a) - g = bkgrd.growth_factor(cosmo, a) + cosmo, g = bkgrd.growth_factor(cosmo, a) t = transfer_fn(cosmo, k, **kwargs) pknorm = cosmo.sigma8 ** 2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs) @@ -101,7 +101,7 @@ def int_sigma(logk): r = np.exp(logr) y = np.outer(k, r) pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) - g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) + _, g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) return ( np.expand_dims(pk * k ** 3, axis=1) * np.exp(-(y ** 2)) @@ -123,7 +123,8 @@ def integrand(logk): k = np.exp(logk) y = np.outer(k, 1.0 / k_nl) pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) - g = np.expand_dims(bkgrd.growth_factor(cosmo, np.atleast_1d(a)), 0) + _, g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) + g = np.expand_dims(g, 0) res = ( np.expand_dims(pk * k ** 3, axis=1) * np.exp(-(y ** 2)) diff --git a/jax_cosmo/probes.py b/jax_cosmo/probes.py index fe47895..825f4a0 100644 --- a/jax_cosmo/probes.py +++ b/jax_cosmo/probes.py @@ -26,7 +26,7 @@ def weak_lensing_kernel(cosmo, pzs, z, ell): z = np.atleast_1d(z) zmax = max([pz.zmax for pz in pzs]) # Retrieve comoving distance corresponding to z - chi = bkgrd.radial_comoving_distance(cosmo, z2a(z)) + cosmo, chi = bkgrd.radial_comoving_distance(cosmo, z2a(z)) # Extract the indices of pzs that can be treated as extended distributions, # and the ones that need to be treated as delta functions. @@ -45,7 +45,7 @@ def weak_lensing_kernel(cosmo, pzs, z, ell): @vmap def integrand(z_prime): - chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) + _, chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) # Stack the dndz of all redshift bins dndz = np.stack([pzs[i](z_prime) for i in pzs_extended_idx], axis=0) return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) @@ -56,7 +56,7 @@ def integrand(z_prime): @vmap def integrand_single(z_prime): - chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) + _, chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) return np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) radial_kernels.append( @@ -73,7 +73,8 @@ def integrand_single(z_prime): constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c # Ell dependent factor ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2 - return constant_factor * ell_factor * radial_kernel + kernel = constant_factor * ell_factor * radial_kernel + return kernel @jit @@ -121,9 +122,8 @@ def nla_kernel(cosmo, pzs, bias, z, ell): radial_kernel = dndz * b * bkgrd.H(cosmo, z2a(z)) # Apply common A_IA normalization to the kernel # Joachimi et al. (2011), arXiv: 1008.3491, Eq. 6. - radial_kernel *= ( - -(5e-14 * const.rhocrit) * cosmo.Omega_m / bkgrd.growth_factor(cosmo, z2a(z)) - ) + _, D = bkgrd.growth_factor(cosmo, z2a(z)) + radial_kernel *= -(5e-14 * const.rhocrit) * cosmo.Omega_m / D # Constant factor constant_factor = 1.0 # Ell dependent factor @@ -136,16 +136,16 @@ class WeakLensing(container): """ Class representing a weak lensing probe, with a bunch of bins - Parameters: - ----------- + Parameters + ---------- redshift_bins: list of nzredshift distributions ia_bias: (optional) if provided, IA will be added with the NLA model, either a single bias object or a list of same size as nzs multiplicative_bias: (optional) adds an (1+m) multiplicative bias, either single value or list of same length as redshift bins - Configuration: - -------------- + Configuration + ------------- sigma_e: intrinsic galaxy ellipticity """ @@ -191,8 +191,8 @@ def kernel(self, cosmo, z, ell): """ Compute the radial kernel for all nz bins in this probe. - Returns: - -------- + Returns + ------- radial_kernel: shape (nbins, nz) """ z = np.atleast_1d(z) @@ -229,12 +229,12 @@ def noise(self): class NumberCounts(container): """Class representing a galaxy clustering probe, with a bunch of bins - Parameters: - ----------- + Parameters + ---------- redshift_bins: nzredshift distributions - Configuration: - -------------- + Configuration + ------------- has_rsd.... """ @@ -262,8 +262,8 @@ def n_tracers(self): def kernel(self, cosmo, z, ell): """Compute the radial kernel for all nz bins in this probe. - Returns: - -------- + Returns + ------- radial_kernel: shape (nbins, nz) """ z = np.atleast_1d(z) diff --git a/jax_cosmo/redshift.py b/jax_cosmo/redshift.py index e02e9c7..6938167 100644 --- a/jax_cosmo/redshift.py +++ b/jax_cosmo/redshift.py @@ -62,8 +62,8 @@ def tree_unflatten(cls, aux_data, children): @register_pytree_node_class class smail_nz(redshift_distribution): """Defines a smail distribution with these arguments - Parameters: - ----------- + Parameters + ---------- a: b: @@ -81,8 +81,8 @@ def pz_fn(self, z): @register_pytree_node_class class delta_nz(redshift_distribution): """Defines a single plane redshift distribution with these arguments - Parameters: - ----------- + Parameters + ---------- z0: """ @@ -102,13 +102,13 @@ class kde_nz(redshift_distribution): given catalog currently uses a Gaussian kernel. TODO: add more if necessary - Parameters: - ----------- + Parameters + ---------- zcat: redshift catalog weights: weight for each galaxy between 0 and 1 - Configuration: - -------------- + Configuration + ------------- bw: Bandwidth for the KDE Example: diff --git a/jax_cosmo/transfer.py b/jax_cosmo/transfer.py index 23def59..2728b85 100644 --- a/jax_cosmo/transfer.py +++ b/jax_cosmo/transfer.py @@ -1,7 +1,6 @@ # This module contains various transfer functions from the literatu import jax.numpy as np -import jax_cosmo.background as bkgrd import jax_cosmo.constants as const __all__ = ["Eisenstein_Hu"] From b9566474c759a4e50a83998e5b009108a20f27ce Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 15:02:59 -0500 Subject: [PATCH 10/13] Add python 3.9 & 3.10 supports and Drop 3.6 --- .github/workflows/pythonpackage.yml | 6 +++--- setup.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 7eabd1c..d3342f8 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: ['3.7', '3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v2 @@ -32,10 +32,10 @@ jobs: shell: bash -l {0} run: | conda config --set always_yes yes - conda install pytest pip + conda install pytest pip conda install -c conda-forge pyccl pip install . - + - name: Test with pytest shell: bash -l {0} run: | diff --git a/setup.py b/setup.py index 5696aa6..8bd6a77 100644 --- a/setup.py +++ b/setup.py @@ -17,14 +17,16 @@ author="jax-cosmo developers", packages=find_packages(), url="https://github.com/DifferentiableUniverseInitiative/jax_cosmo", + python_requires=">=3.7", install_requires=["jax", "jaxlib"], tests_require=["pyccl"], use_scm_version=True, setup_requires=["setuptools_scm"], classifiers=[ - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "License :: OSI Approved :: MIT License", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", From defcb57768eaaa3cf47895819c1e13d225c97e06 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 17:39:28 -0500 Subject: [PATCH 11/13] Use Dict instead of dict for python < 3.10 --- jax_cosmo/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 22ea304..9439af8 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -8,6 +8,7 @@ from functools import partial from pprint import pformat from typing import Any +from typing import Dict # use dict instead for python >= 3.10 from typing import Optional import jax.numpy as np @@ -104,7 +105,7 @@ class Cosmology: # cache for intermediate computations; # users should not access it directly but use the class methods instead - _cache: dict[str, Any] = field(default_factory=dict, repr=False, compare=False) + _cache: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # configuration parameters, immutable (frozen dataclass) config: Configuration = field(default_factory=Configuration) From 9758f322ac402316b2817d167b94b83f726ba55b Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 21:04:29 -0500 Subject: [PATCH 12/13] Speed up github actions with mamba --- .github/workflows/pythonpackage.yml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index d3342f8..7511db9 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -15,25 +15,24 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 - uses: conda-incubator/setup-miniconda@v2 with: + miniforge-variant: Mambaforge + use-mamba: true python-version: ${{ matrix.python-version }} - channels: conda-forge,defaults - channel-priority: strict - show-channel-urls: true auto-update-conda: true + show-channel-urls: true - - name: Install dependencies + - name: Install package and dependencies shell: bash -l {0} run: | conda config --set always_yes yes - conda install pytest pip - conda install -c conda-forge pyccl + mamba install pytest pyccl pip install . - name: Test with pytest From c1821b61a1e0bb73165fa955eb4e726fc68c701c Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 27 Jan 2022 21:09:36 -0500 Subject: [PATCH 13/13] Speed up github actions with pytest-xdist --- .github/workflows/pythonpackage.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 7511db9..49dd9b5 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -32,10 +32,10 @@ jobs: shell: bash -l {0} run: | conda config --set always_yes yes - mamba install pytest pyccl + mamba install pytest pytest-xdist pyccl pip install . - name: Test with pytest shell: bash -l {0} run: | - pytest + pytest -n auto