Migration integration: ClenshawCurtis #92

Open
jecampagne opened this issue Apr 10, 2022 · 1 comment
Comments

@jecampagne (Collaborator) commented Apr 10, 2022

Currently there are several integration methods:

- in scipy/integrate.py: Romberg and Simpson, used for instance in angular_cl.py and power.py...
- in scipy/ode.py: Runge-Kutta, used for instance in background.py (nb. odeint is included in core.py but not used)

I propose to revisit this using a Clenshaw-Curtis quadrature, which can be used very much like the Simpson code; the goal is to reduce the number of points needed to reach the same level of accuracy.

Here is the class and a function as well as some tests.

from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

class ClenshawCurtisQuad:
    """
    Clenshaw-Curtis quadrature of order (2n-1) whose abscissas and weights
    are computed by FFT (by default we also compute the error weights).
    The abscissas & weights are for a [0,1] interval and should be rescaled
    as needed.
    """

    def __init__(self, order=5):
        # (2n-1)-point quadrature
        self._order = jnp.int64(2 * order - 1)
        self._absc, self._absw, self._errw = self.ComputeAbsWeights()
        self.rescaleAbsWeights()
    
    def __str__(self):
        return f"xi={self._absc}\n wi={self._absw}\n errwi={self._errw}"
    
    @property
    def absc(self):
        return self._absc
    @property
    def absw(self):
        return self._absw
    @property
    def errw(self):
        return self._errw
        
    def ComputeAbsWeights(self):
        x,wx = self.absweights(self._order)
        nsub = (self._order+1)//2
        xSub, wSub = self.absweights(nsub)
        errw = jnp.array(wx, copy=True)                           # np.copy(wx)
        errw=errw.at[::2].add(-wSub)             # errw[::2] -= wSub
        return x,wx,errw
  
    
    def absweights(self,n):

        points = -jnp.cos((jnp.pi * jnp.arange(n)) / (n - 1))

        if n == 2:
            weights = jnp.array([1.0, 1.0])
            return points, weights
            
        n -= 1
        N = jnp.arange(1, n, 2)
        length = len(N)
        m = n - length
        v0 = jnp.concatenate([2.0 / N / (N - 2), jnp.array([1.0 / N[-1]]), jnp.zeros(m)])
        v2 = -v0[:-1] - v0[:0:-1]
        g0 = -jnp.ones(n)
        g0 = g0.at[length].add(n)     # g0[length] += n
        g0 = g0.at[m].add(n)          # g0[m] += n
        g = g0 / (n ** 2 - 1 + (n % 2))

        w = jnp.fft.ihfft(v2 + g)
        # sanity check: the imaginary parts returned by ihfft should be negligible
        # assert max(w.imag) < 1.0e-15
        w = w.real

        if n % 2 == 1:
            weights = jnp.concatenate([w, w[::-1]])
        else:
            weights = jnp.concatenate([w, w[len(w) - 2 :: -1]])

        return points, weights
    
    def rescaleAbsWeights(self, xInmin=-1.0, xInmax=1.0, xOutmin=0.0, xOutmax=1.0):
        """
        Translate nodes & weights from the [xInmin, xInmax] interval
        to the [xOutmin, xOutmax] interval.
        """
        deltaXIn = xInmax - xInmin
        deltaXOut = xOutmax - xOutmin
        scale = deltaXOut / deltaXIn
        self._absw = self._absw * scale
        self._absc = ((self._absc - xInmin) * xOutmax
                      - (self._absc - xInmax) * xOutmin) / deltaXIn
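As a quick sanity check (a minimal sketch, not part of the proposal): after rescaling, the nodes and weights live on [0,1] and should integrate a low-order polynomial essentially to machine precision.

quad = ClenshawCurtisQuad(order=3)          # 2*3-1 = 5 nodes on [0,1]
approx = jnp.sum(quad.absw * quad.absc**2)  # int_0^1 x^2 dx = 1/3
print(approx)                               # ~0.3333333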

An integration routine with an API close to the simps code; in addition, we allow the possibility to integrate functions with optional args and kwargs.

@partial(jit, static_argnums=(0, 3, 4, 5))
def quadIntegral(f, a, b, quad, f_args=(), f_kargs={}):
    # integrates f over [a,b]; a and b may be scalars or batches of bounds
    a = jnp.atleast_1d(a)
    b = jnp.atleast_1d(b)
    d = b - a
    # xi[i,k] = a[k] + d[k] * t[i]: all quadrature nodes for all intervals
    xi = a[jnp.newaxis, :] + jnp.einsum('i...,k...->ik...', quad.absc, d)
    fi = f(xi, *f_args, **f_kargs)
    S = d * jnp.einsum('i...,i...', quad.absw, fi)
    return S.squeeze()
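For reference, this is just the affine change of variables from the [0,1] reference nodes t_i (with weights w_i) to each interval [a,b]:

\int_a^b f(x)\,dx \approx (b-a) \sum_i w_i\, f\big(a + (b-a)\, t_i\big)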

Some tests:

import numpy as np
import jax.scipy as jsc  # for the gammaln/gammaincc reference value

quad = ClenshawCurtisQuad(100)  # 199 pts
def func(x):
    return x**(1/10) * jnp.exp(-x)
a = 0.
b = a + 0.5
res_simp = jax_simps(func, a, b, N=2**15)  # 32768 pts, existing Simpson code
res_cc = quadIntegral(func, a, b, quad)
res_true = jnp.exp(jsc.special.gammaln(1. + 1./10)) * (1. - jsc.special.gammaincc(1. + 1./10, 0.5))

print(f"Simp-True: {res_true-res_simp:.3e},CC-True: {res_true-res_cc:.3e}")
# Simp-True: 1.300e-06,CC-True: 1.610e-06

# family of functions
@jit
def jax_funcN(x):
    return jnp.stack([x**(i/10) * jnp.exp(-x) for i in range(50)], axis=1)

# set of integration intervals [a, a+1/2], a: 0, 0.5, 1.0, ...
ja = jnp.arange(0, 10, 0.5)
jb = ja + 0.5
res_cc = quadIntegral(jax_funcN, ja, jb, quad)
res_sim = jax_simps(jax_funcN, ja, jb, N=2**15)
np.allclose(res_cc, res_sim, rtol=0., atol=1e-6)
# True

So this demonstrates that with ~200 points CC reaches the same accuracy as Simpson with 2**15 = 32768 points.

I hope we can migrate progressively to speed up the code.

@jecampagne (Collaborator, Author) commented:

In addition, I have set up an incremental integration routine along the lines of odeint, which uses Runge-Kutta 4. Here is the implementation using CC with 5 pts per subinterval (order=3, npts=2*order-1):

def incremental_int(fn, y0, t, order=3):
    # cumulative integral of fn on the grid t, starting from y0 at t[0]:
    # each scan step adds the CC integral over [t_prev, t] to the carry
    quad = ClenshawCurtisQuad(order)
    def integ(carry, t):
        y, t_prev = carry
        y = y + quadIntegral(fn, t_prev, t, quad)
        return (y, t), y
    (yf, _), y = jax.lax.scan(integ, (y0, jnp.array(t[0])), t)
    return y
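A quick check of this routine (my example, not from the original post): the cumulative integral of cos on a grid starting at 0 should reproduce sin.

ts = jnp.linspace(0., 2. * jnp.pi, 100)
ys = incremental_int(jnp.cos, 0., ts, order=3)
print(np.allclose(ys, jnp.sin(ts), atol=1e-5))  # True (float32 tolerance)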

Here is a usage in radial_comoving_distance, where I have used the workspace to store the different approximations:

def radial_comoving_distance(cosmo, a, log10_amin=-3, steps=256):
    """
    \chi(a) =  R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)}
    """
    # Check if distances have already been computed
    if "background.radial_comoving_distance" not in cosmo._workspace:
        # Compute tabulated array
        atab = jnp.logspace(log10_amin, 0.0, steps)

        def dchioverdlna_ode(y, x):
            # odeint-style signature (y, x); the integrand does not depend on y
            xa = jnp.exp(x)
            return dchioverda(cosmo, xa) * xa

        def dchioverdlna(x):
            # same integrand with a plain quadrature signature
            xa = jnp.exp(x)
            return dchioverda(cosmo, xa) * xa

        
        quad = ClenshawCurtisQuad(150)  # 2*150-1 = 299 pts
        chitab_cc = -quadIntegral(dchioverdlna, 0.0, jnp.log(atab), quad)
        
        chitab_cc_inc = incremental_int(dchioverdlna, 0.0, jnp.log(atab), order=3)
        chitab_cc_inc = chitab_cc_inc[-1]-chitab_cc_inc
        
        chitab_simp = -simps(dchioverdlna, 0.0, jnp.log(atab),N=2**15)   ## here just for fun ...
        
        chitab = odeint(dchioverdlna_ode, 0.0, jnp.log(atab))
        chitab = chitab[-1] - chitab

        cache = {"a": atab, "chi": chitab,   "chi_cc":chitab_cc, 
                 "chi_cc_inc":chitab_cc_inc,
                 "chi_simp":chitab_simp}
        cosmo._workspace["background.radial_comoving_distance"] = cache
    else:
        cache = cosmo._workspace["background.radial_comoving_distance"]

    a = np.atleast_1d(a)
    # Return the results as an interpolation of the table
    return jnp.clip(jnp.interp(a, cache["a"], cache["chi"]), 0.0)
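For reference (my derivation, assuming dchioverda(cosmo, a) is the integrand 1/(a^2 E(a)) of the docstring), the log-scale integration comes from the substitution x = \ln a', da' = a'\,dx:

\chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)} = R_H \int_{\ln a}^{0} \frac{dx}{e^{x} E(e^{x})}

which is why dchioverdlna multiplies dchioverda by xa = exp(x), and why the quadrature running from 0 down to log(atab) is negated.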

Then, using a standard cosmology:

cosmo_jax = Cosmology(
    Omega_c=0.3,
    Omega_b=0.05,
    h=0.7,
    Omega_k=0.0,
    w0=-1.0,
    wa=0.0  
)
z = jnp.logspace(-2, 3,100)
chi_z = radial_comoving_distance(cosmo_jax, z2a(z))/ cosmo_jax.h
trans_com_z = transverse_comoving_distance(cosmo_jax, z2a(z))/ cosmo_jax.h
ang_diam_z= angular_diameter_distance(cosmo_jax, z2a(z))/ cosmo_jax.h
lum_z = luminosity_distance(cosmo_jax, z2a(z)) / cosmo_jax.h

plt.figure(figsize=(8,8))
fact = 0.0032615 #Mpc -> Gly
plt.plot(z,fact*chi_z, label=r"$\chi(z)$")
plt.plot(z,fact*trans_com_z,ls="--",lw=3, label=rf"$f_k(z)$ (k={cosmo_jax.k})")
plt.plot(z,fact*ang_diam_z, label=r"$d_A(z)$")
plt.plot(z,fact*lum_z, label=r"$d_L(z)$")
plt.xlabel("redshift (z)")
plt.ylabel("Distance (Gly)")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.grid()

[figure: chi(z), f_k(z), d_A(z) and d_L(z) in Gly versus redshift, log-log axes]

Now, analysing the different chi(a) approximations:

dico = cosmo_jax._workspace["background.radial_comoving_distance"]
print(len(dico["a"]))
plt.plot(dico["a"],dico["chi"],label="chi (ODE)")
plt.plot(dico["a"],dico["chi_cc"],ls="--",lw=3,label="chi (CC) $300$pts")
plt.plot(dico["a"],dico["chi_cc_inc"],ls="--",lw=3,label="chi (CC Incr) $5$pts")
plt.plot(dico["a"],dico["chi_simp"],ls="--",lw=3,label=r"chi (Simp) $2^{15}$pts")
plt.legend()

[figure: chi(a) for the ODE, CC (300 pts), CC incremental (5 pts) and Simpson (2^15 pts) approximations]

dico = cosmo_jax._workspace["background.radial_comoving_distance"]
plt.plot(dico["a"],dico["chi"]-dico["chi_cc"],label="chi (ODE) - chi (CC)")
plt.plot(dico["a"],dico["chi"]-dico["chi_simp"],ls="--",lw=3,
         label=r"chi (ODE) - chi (Simp)")
plt.plot(dico["a"],dico["chi"]-dico["chi_cc_inc"],ls="--",lw=3,
         label=r"chi (ODE) - chi (CC inc)")
plt.legend()

[figure: chi(ODE) - chi(CC), chi(ODE) - chi(Simp) and chi(ODE) - chi(CC inc) differences]

plt.plot(dico["a"],dico["chi_cc"]-dico["chi_simp"],label="chi (CC) - chi (Simp)")
plt.plot(dico["a"],dico["chi_cc"]-dico["chi_cc_inc"],label="chi (CC) - chi (CC inc)")
plt.legend()

[figure: chi(CC) - chi(Simp) and chi(CC) - chi(CC inc) differences]

Now, odeint uses 4x256 function calls (i.e. 4 points per subinterval), while incremental CC uses 5 points per subinterval, yet with a clear improvement in accuracy when compared to the 'true values', taken to be CC with ~300 points on each interval [0, a_i], i = 0, ..., 255 (or Simpson with 2**15 points on each [0, a_i]).

So, I think that this incremental integration using a 5-point CC quadrature on each subinterval is a good candidate to replace odeint-RK4.

For the sake of completeness: incremental CC with a 3-point quadrature is very close to odeint-RK4, with 1 point less per subinterval.
