Migration integration: ClenshawCurtis #92
In addition, I have set up an incremental integration routine along the lines of:

```python
import jax
import jax.numpy as jnp

# ClenshawCurtisQuad and quadIntegral are the class/function proposed in this issue.
def incremental_int(fn, y0, t, order=3):
    """Cumulative integral of fn from t[0] to each t[i], one CC rule per sub-interval."""
    quad = ClenshawCurtisQuad(order)

    def integ(carry, t):
        y, t_prev = carry
        # add the integral over [t_prev, t] to the running total
        y = y + quadIntegral(fn, t_prev, t, quad)
        return (y, t), y

    (yf, _), y = jax.lax.scan(integ, (y0, jnp.array(t[0])), t)
    return y
```

Here is how it is used in `radial_comoving_distance`:

```python
import numpy as np
import jax.numpy as jnp
from jax_cosmo.background import dchioverda
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.ode import odeint


def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256):
    r"""
    \chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)}
    """
    # Check if distances have already been computed
    if "background.radial_comoving_distance" not in cosmo._workspace:
        # Compute tabulated array
        atab = jnp.logspace(log10_amin, 0.0, steps)

        # Integrand in x = ln(a): d(chi)/d(ln a) = a * d(chi)/da
        def dchioverdlna_ode(y, x):  # signature expected by odeint
            xa = jnp.exp(x)
            return dchioverda(cosmo, xa) * xa

        def dchioverdlna(x):  # signature expected by the quadratures
            xa = jnp.exp(x)
            return dchioverda(cosmo, xa) * xa

        # One CC rule over each [0, ln a]; minus sign since chi = int_{ln a}^{0}
        quad = ClenshawCurtisQuad(150)
        chitab_cc = -quadIntegral(dchioverdlna, 0.0, jnp.log(atab), quad)
        # Incremental CC from ln(a_min) upward, then chi(a) = total - partial
        chitab_cc_inc = incremental_int(dchioverdlna, 0.0, jnp.log(atab), order=3)
        chitab_cc_inc = chitab_cc_inc[-1] - chitab_cc_inc
        chitab_simp = -simps(dchioverdlna, 0.0, jnp.log(atab), N=2**15)  ## here just for fun ...
        # Reference: ODE integration
        chitab = odeint(dchioverdlna_ode, 0.0, jnp.log(atab))
        chitab = chitab[-1] - chitab

        cache = {"a": atab, "chi": chitab, "chi_cc": chitab_cc,
                 "chi_cc_inc": chitab_cc_inc,
                 "chi_simp": chitab_simp}
        cosmo._workspace["background.radial_comoving_distance"] = cache
    else:
        cache = cosmo._workspace["background.radial_comoving_distance"]

    a = np.atleast_1d(a)
    # Return the results as an interpolation of the table
    return jnp.clip(jnp.interp(a, cache["a"], cache["chi"]), 0.0)
```

Then, with a standard cosmology:

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Cosmology, z2a and the distance functions come from jax_cosmo,
# with radial_comoving_distance being the version above.
cosmo_jax = Cosmology(
    Omega_c=0.3,
    Omega_b=0.05,
    h=0.7,
    Omega_k=0.0,
    w0=-1.0,
    wa=0.0
)
z = jnp.logspace(-2, 3, 100)
chi_z = radial_comoving_distance(cosmo_jax, z2a(z)) / cosmo_jax.h
trans_com_z = transverse_comoving_distance(cosmo_jax, z2a(z)) / cosmo_jax.h
ang_diam_z = angular_diameter_distance(cosmo_jax, z2a(z)) / cosmo_jax.h
lum_z = luminosity_distance(cosmo_jax, z2a(z)) / cosmo_jax.h

plt.figure(figsize=(8, 8))
fact = 0.0032615  # Mpc -> Gly
plt.plot(z, fact * chi_z, label=r"$\chi(z)$")
plt.plot(z, fact * trans_com_z, ls="--", lw=3, label=rf"$f_k(z)$ (k={cosmo_jax.k})")
plt.plot(z, fact * ang_diam_z, label=r"$d_A(z)$")
plt.plot(z, fact * lum_z, label=r"$d_L(z)$")
plt.xlabel("redshift (z)")
plt.ylabel("Distance (Gly)")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.grid()
```

And so, analysing the different tabulated results:

```python
dico = cosmo_jax._workspace["background.radial_comoving_distance"]
print(len(dico["a"]))
plt.plot(dico["a"], dico["chi"], label="chi (ODE)")
plt.plot(dico["a"], dico["chi_cc"], ls="--", lw=3, label="chi (CC) $300$pts")
plt.plot(dico["a"], dico["chi_cc_inc"], ls="--", lw=3, label="chi (CC Incr) $5$pts")
plt.plot(dico["a"], dico["chi_simp"], ls="--", lw=3, label=r"chi (Simp) $2^{15}$pts")
plt.legend()
```

The differences with respect to the ODE solution:

```python
dico = cosmo_jax._workspace["background.radial_comoving_distance"]
plt.plot(dico["a"], dico["chi"] - dico["chi_cc"], label="chi (ODE) - chi (CC)")
plt.plot(dico["a"], dico["chi"] - dico["chi_simp"], ls="--", lw=3,
         label=r"chi (ODE) - chi (Simp)")
plt.plot(dico["a"], dico["chi"] - dico["chi_cc_inc"], ls="--", lw=3,
         label=r"chi (ODE) - chi (CC inc)")
plt.legend()
```

and with respect to the 300-point CC result:

```python
plt.plot(dico["a"], dico["chi_cc"] - dico["chi_simp"], label="chi (CC) - chi (Simp)")
plt.plot(dico["a"], dico["chi_cc"] - dico["chi_cc_inc"], label="chi (CC) - chi (CC inc)")
plt.legend()
```

So I think that this incremental integration, using a 5-point CC quadrature on each sub-interval, is a good candidate to replace the current ODE integration (`odeint`).

For the sake of completeness:
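To make the incremental scheme concrete without jax_cosmo, here is a minimal NumPy sketch (hypothetical `incremental_integral`, not the issue's `incremental_int`). It applies the 5-point Clenshaw-Curtis rule to every sub-interval at once and accumulates with `np.cumsum`, which plays the role of the `jax.lax.scan` carry; swapping `np` for `jax.numpy` should work the same way. The check assumes an Einstein-de Sitter background (Omega_m = 1), chosen only because chi(a) then has the closed form 2(1 - sqrt(a)) in units of R_H.

```python
import numpy as np

# 5-point Clenshaw-Curtis rule on [-1, 1]: nodes cos(k*pi/4), weights (1, 8, 12, 8, 1)/15
CC5_NODES = np.cos(np.pi * np.arange(5) / 4)
CC5_WEIGHTS = np.array([1.0, 8.0, 12.0, 8.0, 1.0]) / 15.0


def incremental_integral(f, t):
    """Return F with F[i] = integral of f from t[0] to t[i] (so F[0] = 0).

    Each sub-interval [t[i-1], t[i]] gets its own 5-point CC rule; the
    per-interval results are accumulated with a cumulative sum.
    """
    t = np.asarray(t, dtype=float)
    lo, hi = t[:-1, None], t[1:, None]
    half = 0.5 * (hi - lo)                          # shape (n - 1, 1)
    nodes = 0.5 * (lo + hi) + half * CC5_NODES      # shape (n - 1, 5)
    pieces = half[:, 0] * (f(nodes) @ CC5_WEIGHTS)  # one integral per sub-interval
    return np.concatenate([[0.0], np.cumsum(pieces)])


# Check on an Einstein-de Sitter background (assumption: Omega_m = 1, so
# E(a) = a**-1.5). In units of R_H, d(chi)/d(ln a) = 1 / (a E(a)) = sqrt(a)
# and chi(a) = 2 * (1 - sqrt(a)).
atab = np.logspace(-3, 0, 256)
F = incremental_integral(lambda x: np.exp(0.5 * x), np.log(atab))
chi = F[-1] - F  # chi(a_i) = integral from ln a_i to 0
```

As in the issue's code, chi is obtained as "total minus partial", so a single forward pass from ln(a_min) gives the whole table.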
Currently there are several integration methods:
- in `scipy/integrate.py`: Romberg and Simpson, used for instance in `angular_cl.py` and `power.py`, ...
- in `scipy/ode.py`: Runge-Kutta, used for instance in `background.py` (nb. `odeint` is included in `core.py` but not used)

I propose to revisit this by using a Clenshaw-Curtis quadrature, which can be used very much like the Simpson code; the purpose is to decrease the number of points needed for the same level of accuracy.
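For reference, the textbook Clenshaw-Curtis construction (standard formulas, not taken from the issue's code) shows where the saving in points comes from. With $N+1$ nodes at the Chebyshev extrema on $[-1,1]$:

$$
\int_{-1}^{1} f(x)\,dx \approx \sum_{k=0}^{N} w_k\, f(x_k), \qquad x_k = \cos\frac{k\pi}{N},
$$

where the weights come from interpolating $f$ at the $x_k$ by a Chebyshev series $\sum_{j=0}^{N} c_j T_j(x)$ and integrating term by term:

$$
\int_{-1}^{1} T_j(x)\,dx = \begin{cases} \dfrac{2}{1-j^2}, & j \text{ even},\\[4pt] 0, & j \text{ odd}. \end{cases}
$$

An interval $[a,b]$ is handled with $t = \tfrac{a+b}{2} + \tfrac{b-a}{2}\,x$, which multiplies the weights by $\tfrac{b-a}{2}$. The rule is exact for polynomials of degree $\le N$, and for integrands analytic in a neighbourhood of the interval the error decreases geometrically with $N$, whereas composite Simpson converges as $O(h^4)$ in the spacing $h$.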
Here is the class and a function as well as some tests.
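As an illustration of what such a class can look like, here is a minimal NumPy sketch (hypothetical `CCRule`, not the `ClenshawCurtisQuad` class of this issue). Nodes and weights are computed once in the constructor, using the classic cosine-sum formula (as in Trefethen's `clencurt`), and then reused for every integration, which is the main reason to wrap the rule in a class.

```python
import numpy as np


class CCRule:
    """Hypothetical Clenshaw-Curtis rule on [-1, 1] with n_points = N + 1 nodes.

    Nodes are the Chebyshev extrema x_k = cos(k*pi/N), k = 0..N.
    """

    def __init__(self, n_points=5):
        if n_points < 2:
            raise ValueError("need at least 2 points")
        N = n_points - 1
        theta = np.pi * np.arange(N + 1) / N
        self.nodes = np.cos(theta)
        w = np.zeros(N + 1)
        v = np.ones(N - 1)
        inner = theta[1:-1]
        if N % 2 == 0:
            w[0] = w[N] = 1.0 / (N**2 - 1)
            for k in range(1, N // 2):
                v -= 2.0 * np.cos(2 * k * inner) / (4 * k**2 - 1)
            v -= np.cos(N * inner) / (N**2 - 1)
        else:
            w[0] = w[N] = 1.0 / N**2
            for k in range(1, (N - 1) // 2 + 1):
                v -= 2.0 * np.cos(2 * k * inner) / (4 * k**2 - 1)
        w[1:-1] = 2.0 * v / N
        self.weights = w

    def integrate(self, f, a, b):
        """Approximate the integral of a vectorized f over [a, b]."""
        half = 0.5 * (b - a)
        t = 0.5 * (a + b) + half * self.nodes  # map nodes from [-1, 1] to [a, b]
        return half * np.sum(self.weights * f(t))


rule = CCRule(5)
approx = rule.integrate(lambda x: x**4, 0.0, 1.0)  # close to 0.2 (exact up to rounding)
```

With 3 points the weights reduce to Simpson's (1/3, 4/3, 1/3); the gain appears as the number of points grows.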
An integration routine with an API close to the `simps` code; in addition, it allows integrating functions with optional `args`, `kwargs`.

Some tests:
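A minimal sketch of such a `simps`-like call that forwards extra `args`/`kwargs` to the integrand (hypothetical `quad_integral`; the `(nodes, weights)` tuple stands in for the issue's quadrature object). The limits broadcast against each other, so an array of upper limits returns an array of integrals in one call, as in `quadIntegral(dchioverdlna, 0.0, jnp.log(atab), quad)` above. The 5-point CC rule is used here only to keep the example short.

```python
import numpy as np

# Hypothetical rule container: (nodes, weights) on [-1, 1]; here the 5-point CC rule.
CC5 = (np.cos(np.pi * np.arange(5) / 4),
       np.array([1.0, 8.0, 12.0, 8.0, 1.0]) / 15.0)


def quad_integral(f, a, b, rule, *args, **kwargs):
    """Integral of f(t, *args, **kwargs) from a to b with a simps-like call.

    a and b may be scalars or arrays; they broadcast against each other and
    the result has their broadcast shape.
    """
    x, w = rule
    a = np.asarray(a, dtype=float)[..., None]
    b = np.asarray(b, dtype=float)[..., None]
    half = 0.5 * (b - a)
    t = 0.5 * (a + b) + half * x  # nodes mapped to each [a, b]
    return half[..., 0] * np.sum(w * f(t, *args, **kwargs), axis=-1)


def power_law(t, amplitude, power=2):
    return amplitude * t**power


upper = np.array([0.5, 1.0, 2.0])
# positional arg (amplitude) and keyword arg (power) are passed through to f
res = quad_integral(power_law, 0.0, upper, CC5, 3.0, power=4)
# exact value: 3 * upper**5 / 5, since the 5-point CC rule is exact up to degree 5
```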
This demonstrates that CC with 200 points gives the same accuracy as Simpson with 2**15 = 32768 points.
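The issue's own tests are not reproduced here, but a comparison of this kind can be sketched as follows (hypothetical helpers `cc_integrate` and `simpson`; the integrand is an assumption: the Einstein-de Sitter d(chi)/d(ln a) = sqrt(a) over ln a in [ln 1e-3, 0], whose exact integral is 2(1 - sqrt(1e-3))). It prints the absolute error of both rules for the same number of points; no particular numbers are claimed here.

```python
import numpy as np


def cc_nodes_weights(n_points):
    """Clenshaw-Curtis nodes/weights on [-1, 1] (cosine-sum formula)."""
    N = n_points - 1
    theta = np.pi * np.arange(N + 1) / N
    w = np.zeros(N + 1)
    v = np.ones(N - 1)
    inner = theta[1:-1]
    if N % 2 == 0:
        w[0] = w[N] = 1.0 / (N**2 - 1)
        for k in range(1, N // 2):
            v -= 2.0 * np.cos(2 * k * inner) / (4 * k**2 - 1)
        v -= np.cos(N * inner) / (N**2 - 1)
    else:
        w[0] = w[N] = 1.0 / N**2
        for k in range(1, (N - 1) // 2 + 1):
            v -= 2.0 * np.cos(2 * k * inner) / (4 * k**2 - 1)
    w[1:-1] = 2.0 * v / N
    return np.cos(theta), w


def cc_integrate(f, a, b, n_points):
    x, w = cc_nodes_weights(n_points)
    half = 0.5 * (b - a)
    return half * np.sum(w * f(0.5 * (a + b) + half * x))


def simpson(f, a, b, n_points):
    """Composite Simpson rule on n_points equally spaced samples (n_points odd)."""
    t = np.linspace(a, b, n_points)
    h = (b - a) / (n_points - 1)
    y = f(t)
    return h / 3.0 * (y[0] + y[-1] + 4.0 * y[1:-1:2].sum() + 2.0 * y[2:-1:2].sum())


def integrand(x):
    # d(chi)/d(ln a) = sqrt(a) for Omega_m = 1, written in x = ln a
    return np.exp(0.5 * x)


lo, hi = np.log(1e-3), 0.0
exact = 2.0 * (1.0 - np.sqrt(1e-3))

for n in (5, 9, 17, 33, 65):
    err_cc = abs(cc_integrate(integrand, lo, hi, n) - exact)
    err_simp = abs(simpson(integrand, lo, hi, n) - exact)
    print(f"{n:3d} points  CC error {err_cc:.2e}  Simpson error {err_simp:.2e}")
```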
I hope that we can migrate progressively to speed up the code.