Notes for improvements #81

Open
2 of 7 tasks
EiffL opened this issue Jan 19, 2022 · 5 comments · Fixed by #82

Comments

@EiffL
Member

EiffL commented Jan 19, 2022

After discussing with @eelregit, here are a few ideas for things to improve:

  • Allow parameterisation in terms of As
  • Allow for flattening of the cosmology object
  • Switch to jax.numpy.interp!
  • Try to use jax.experimental.ode.odeint instead of jax_cosmo.ode
  • Configuration parameters stored in cosmo structure
  • include_logdet flag in gaussian_log_likelihood is reversed
  • Not sure if transverse_comoving_distance is actually jittable
@eelregit
Contributor

I was testing using custom pytree aux_data to store config parameters but saw some very strange behavior:

from jax import jit
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class C:
    def __init__(self, p, config={'a': 0}):
        self.p = p
        self._config = config

    def tree_flatten(self):
        children = (self.p,)
        aux_data = self._config
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        c = cls(*children)
        c._config.update(aux_data)
        return c
    
def f(c):
    c.p = 1. + c.p
    return c

g = jit(f)

(
    vars(C(0.)),
    vars(f(C(0.))),
    vars(f(C(0., {'a': 1}))),
    vars(g(C(0., {'b': 2}))),
)

returning

({'p': 0.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 1}},
 {'p': DeviceArray(1., dtype=float32, weak_type=True), '_config': {'a': 0, 'b': 2}})

Since the pytree docs are pretty incomplete, I am still unsure whether storing configuration parameters as aux_data is something the JAX devs intended.
I should probably open an issue there.

EiffL linked a pull request Jan 20, 2022 that will close this issue
EiffL closed this as completed in #82 Jan 20, 2022
EiffL reopened this Jan 20, 2022
@EiffL
Member Author

EiffL commented Jan 20, 2022

Hmm, not sure I see the problem? What result were you expecting for the jitted function?

@eelregit
Contributor

E.g., where are the 'b': 2 entries (except the last one) coming from?

@eelregit
Contributor

I am still wondering: are there some examples where optimizing sigma_8 instead of A_s is better?

I feel the sigma_8 coordinate system would make an optimizer focus too much on the 8 Mpc/h scale, so that all parameters get adjusted to prioritize agreement at that scale. So it doesn't feel natural and may not be a good default in general?

@eelregit
Contributor

Okay, that was just me not knowing that one should not use mutable default arguments in Python... 😅
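For reference, here is a minimal sketch of the snippet above with the mutable default replaced by None, so that each instance gets its own config dict. It is only an illustration, not the change made in jax_cosmo: the roundtrip helper is hypothetical and stands in for the flatten/unflatten step that jit performs, and the JAX registration decorator is left out so the sketch runs with plain Python.

```python
class C:
    def __init__(self, p, config=None):
        self.p = p
        # Build a fresh dict per instance. A `config={'a': 0}` default is
        # created once at function definition and shared by every call.
        self._config = {'a': 0} if config is None else dict(config)

    def tree_flatten(self):
        children = (self.p,)
        aux_data = self._config
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Pass the config straight through instead of updating a shared dict.
        return cls(*children, config=aux_data)


def f(c):
    c.p = 1. + c.p
    return c


def roundtrip(c):
    # Hypothetical stand-in for the jit boundary: flatten, then rebuild.
    children, aux_data = c.tree_flatten()
    return C.tree_unflatten(aux_data, children)


results = (
    vars(C(0.)),
    vars(f(C(0.))),
    vars(f(C(0., {'a': 1}))),
    vars(roundtrip(f(C(0., {'b': 2})))),
)
```

With this change the 'b': 2 entry only shows up on the object that was actually constructed with it, and the default {'a': 0} is no longer modified by other instances.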
