
Cosmology, Cache, and Configuration data model #86

Draft
eelregit wants to merge 13 commits into master from eelregit_data_patch
Conversation

eelregit
Contributor

  • frozen dataclass is semi-immutable
  • aux_fields can be specified to be the pytree aux_data
    • I will add a Cosmology.config using this
    • this might be more flexible than the Container, worth switching?
  • cached intermediate results survive through JAX transformation unflattening.
    Are there cases where this is not desirable?

This makes Cosmology semi-immutable, and allows cached results to
survive through unflattening of JAX transformations.
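For reference, a minimal sketch of the idea (with hypothetical field names, not the actual code in this PR): a frozen dataclass registered as a pytree, where the fields listed in aux_fields go into the static aux_data, while everything else, including the cache, is flattened as children and therefore survives unflattening.

from dataclasses import dataclass, field, fields
from jax import tree_util

@dataclass(frozen=True)
class Cosmology:
    Omega_c: float
    Omega_b: float
    h: float
    config: str = "default"   # auxiliary field, carried as static aux_data
    cache: dict = field(default_factory=dict, compare=False)  # child: survives unflattening

aux_fields = ("config",)

def _flatten(cosmo):
    children = tuple(getattr(cosmo, f.name) for f in fields(cosmo) if f.name not in aux_fields)
    aux_data = tuple(getattr(cosmo, name) for name in aux_fields)
    return children, aux_data

def _unflatten(aux_data, children):
    child_names = [f.name for f in fields(Cosmology) if f.name not in aux_fields]
    kwargs = dict(zip(child_names, children))
    kwargs.update(zip(aux_fields, aux_data))
    return Cosmology(**kwargs)

tree_util.register_pytree_node(Cosmology, _flatten, _unflatten)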
@eelregit eelregit changed the title from "(WIP) Pytree dataclass as data containers" to "Pytree dataclass as data containers" on Jan 26, 2022
@eelregit
Contributor Author

eelregit commented Jan 26, 2022

Dataclasses were introduced in Python 3.7 though; maybe most people have moved on from 3.6?

Edit: 3.6 dropped, and 3.9 & 3.10 added. The CI has also been made faster.
Edit 2: JAX requires Python >= 3.7 now.

@eelregit
Contributor Author

eelregit commented Jan 26, 2022

With caches, many functions in background.py were not pure functions.
7c848b2 attempts to fix this, but it breaks the current API and requires
cosmo, out = func(cosmo, in) style signatures on many (or, for consistency, probably all)
functions.
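For concreteness, a rough sketch of that style (hypothetical names and stand-in maths, not the actual code in 7c848b2), assuming Cosmology gains is_cached/cache_set/cache_get helpers:

import jax.numpy as jnp

def radial_comoving_distance(cosmo, a):
    key = "radial_comoving_distance"
    if not cosmo.is_cached(key):
        a_table = jnp.logspace(-3., 0., 256)
        chi_table = jnp.cumsum(1. / a_table)  # stand-in for the real background integral
        cosmo = cosmo.cache_set(key, (a_table, chi_table))
    a_table, chi_table = cosmo.cache_get(key)
    chi = jnp.interp(a, a_table, chi_table)
    return cosmo, chi  # the possibly-updated cosmology is threaded back to the caller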

@EiffL What do you think about this?

Right now the cache is a dictionary, so there can still be side effects.
Do you think it's fine to use some immutable container for cache/workspace too?

Relevant discussion: jax-ml/jax#5344 (comment).

@eelregit eelregit changed the title from "Pytree dataclass as data containers" to "Pytree dataclass and Configuration data model" on Jan 26, 2022
@eelregit eelregit changed the title from "Pytree dataclass and Configuration data model" to "Cosmology, Cache, and Configuration data model" on Jan 26, 2022
@eelregit eelregit force-pushed the eelregit_data_patch branch from 6e099e3 to 7c848b2 on January 27, 2022 00:47
@eelregit eelregit force-pushed the eelregit_data_patch branch from 0dd2f76 to 0056535 on January 27, 2022 18:48
@eelregit eelregit force-pushed the eelregit_data_patch branch from 756422f to b956647 on January 27, 2022 22:24
@EiffL
Member

EiffL commented Jan 28, 2022

Thanks @eelregit, there are a lot of great things in there ^^! The cache and dataclass look nice.

So, yeah, the way I see it there is a tradeoff between having pure functions and keeping a simple API...

The only drawback of the current implementation is in the following case:

cosmo = jc.Planck15()
x = jitted_function1(cosmo, ...)
y = jitted_function2(cosmo, ...)

In that case the cache computed by the first function is not communicated to the second one, so you do some of the cosmology computation twice, but it doesn't lead to any wrong results.

To avoid this and be able to reuse the cache, I would recommend writing that same code this way:

cosmo = jc.Planck15()

@jax.jit
def my_fun(cosmo):
  x = function1(cosmo, ...)
  y = function2(cosmo, ...)
  return x, y

my_fun(cosmo)

In practice in many cases you would just jit the likelihood or the simulation code itself and then you have no problem.

So the question is whether being able to reuse the cache across jitted functions is worth changing the API to have functions return the cosmology object...

I'm leaning towards keeping a simple interface:

chi = bkgrd.radial_comoving_distance(cosmo, a)

instead of

cosmo, chi = bkgrd.radial_comoving_distance(cosmo, a)

just because it would appear very surprising to a typical user.

@EiffL
Member

EiffL commented Jan 28, 2022

Unless you have a compelling use case that would really benefit from the more optimized implementation.

I'm also thinking it could be an option/config to have the non-pure API by default, but if an advanced user wants it, they could retrieve the cosmology and the associated cache. What do you think?

@eelregit
Contributor Author

eelregit commented Jan 28, 2022

Thanks @EiffL !

The previous non-pure API does not allow a functional cache in jitted inner functions like out = func(cosmo, in), which unfortunately is the case in pmwd. pmwd needs that for both functional reasons (e.g. if I/O is needed between time steps) and performance reasons (for-looping over time steps is faster than scanning them). (Besides, inner jitted functions may be quite common, e.g. many jax.numpy and lax functions are.)

What do you think about the second pattern in jax-ml/jax#5344 (comment) ?
That separates init and eval, which is also a bit cumbersome but a fairly common interface.
I am sure there should be some way to make the APIs compatible, right?

@EiffL
Member

EiffL commented Jan 28, 2022

Hummmm we could precompute everything at the instantiation of the cosmology object... We could imagine a mechanism that "registers" all functions that use cached values and computes the cache before anything else happens...

Then the user API would stay the same, the functions would be pure.
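A purely illustrative sketch of that kind of registration (not an existing jax_cosmo mechanism):

_CACHE_INITS = []

def register_cached(init_fn):
    # functions that rely on cached values register an init hook
    _CACHE_INITS.append(init_fn)
    return init_fn

def precompute_all(cosmo):
    # run every registered hook up front, e.g. right after instantiation
    for init_fn in _CACHE_INITS:
        cosmo = init_fn(cosmo)
    return cosmo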

But... it would mean that creating a cosmology would be slow for interactive users...

Hummmm

@EiffL
Member

EiffL commented Jan 28, 2022

And we could have an option to decide which type of execution you want, one that plays nicely with jitted functions, and one that sticks to the current behavior for easy interactive use.

@eelregit
Contributor Author

eelregit commented Jan 28, 2022

Something like the following?

def compute_y(cosmo, x):
    # initialize cache and output cosmo with cache if input is None
    if x is None:
        if cosmo.is_cached(key):
            return cosmo
        value = ...
        return cosmo.cache_set(key, value)

    if not cosmo.is_cached(key):
        cosmo = compute_y.init(cosmo)    # or, more strictly, just raise a RuntimeError?

    value = cosmo.cache_get(key)
    y = ...
    return y

# and/or something more explicit like
compute_y.init = partial(compute_y, x=None)

with some global Cosmology cache initialization like

class Cosmology:
    ...
    def cache_init(self, *args):
        cosmo = self
        for compute_y in args:
            cosmo = compute_y.init(cosmo)
        return cosmo

Contributors should call compute_y.init first in their probes.
And interactive users can call the global init to speed things up:

cosmo = Planck15()
cosmo = cosmo.cache_init(compute_y, compute_z)

and are encouraged to think functionally.

Maybe we can iterate on this to find convergence ^^

@eelregit
Contributor Author

eelregit commented Jan 28, 2022

If functools.lru_cache works with JAX transformations, that would be nice and simple:

@lru_cache
def precompute_y(cosmo):
    table = ...
    return table

def compute_y(cosmo, x):
    table = precompute_y(cosmo)
    y = ...
    return y

With this it seems like everything can be pure and one doesn't need to touch Cosmology once instantiated?

@eelregit
Contributor Author

Unfortunately, lru_cache doesn't work with tracing:

from functools import lru_cache
from typing import NamedTuple

import jax.numpy as jnp
from jax import jit

class C(NamedTuple):
    min: float = 0.
    max: float = 1.

@lru_cache()
def f(c):
    return jnp.linspace(c.min, c.max, 6)

@jit
def g(c, w, b):
    return b + w * f(c)

g(C(), 1., 0.)

results in

TypeError: unhashable type: 'DynamicJaxprTracer'

A similar issue in numba: numba/numba#4062
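One partial workaround for this toy example is to mark the container as a static argument, at the cost of making it a hashed compile-time constant that cannot be traced or differentiated through:

from functools import lru_cache, partial
from typing import NamedTuple

import jax.numpy as jnp
from jax import jit

class C(NamedTuple):
    min: float = 0.
    max: float = 1.

@lru_cache()
def f(c):
    return jnp.linspace(c.min, c.max, 6)

@partial(jit, static_argnums=0)   # c stays a concrete, hashable NamedTuple inside g
def g(c, w, b):
    return b + w * f(c)

g(C(), 1., 0.)   # works, but each new C triggers recompilation

So lru_cache alone doesn't solve the general case where the cosmological parameters themselves are traced.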
