Notes for improvements #81
I was testing the use of custom pytree `aux_data` to store config parameters, but saw some very strange behavior:

```python
from jax import jit
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class C:
    def __init__(self, p, config={'a': 0}):
        self.p = p
        self._config = config

    def tree_flatten(self):
        children = (self.p,)
        aux_data = self._config
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        c = cls(*children)
        c._config.update(aux_data)
        return c

def f(c):
    c.p = 1. + c.p
    return c

g = jit(f)

(
    vars(C(0.)),
    vars(f(C(0.))),
    vars(f(C(0., {'a': 1}))),
    vars(g(C(0., {'b': 2}))),
)
```

returning

```
({'p': 0.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 0, 'b': 2}},
 {'p': 1.0, '_config': {'a': 1}},
 {'p': DeviceArray(1., dtype=float32, weak_type=True), '_config': {'a': 0, 'b': 2}})
```

Since the pytree docs are pretty incomplete, I am still worried whether including configuration parameters as …
Hmm, not sure I see the problem? What result were you expecting for the jitted function?

E.g., where are the …

I am still wondering: are there some examples where optimizing … I feel the …

Okay, that was just me not knowing that one should not use mutable default arguments in Python... 😅
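For reference, a minimal sketch of the class above with the usual `None`-sentinel fix for the mutable default, so each instance gets its own config dict. Assumptions: the `register_pytree_node_class` decorator is omitted so the sketch runs without JAX (add it back in real code), and storing the config in `aux_data` as a sorted tuple of items is my own choice to keep that static data immutable and hashable; it is not something established in this thread.

```python
class C:
    """Pytree-style container. In real code, decorate with
    jax.tree_util.register_pytree_node_class (omitted so this runs without JAX)."""

    def __init__(self, p, config=None):
        self.p = p
        # A fresh dict per instance instead of one shared mutable default.
        self._config = {'a': 0} if config is None else dict(config)

    def tree_flatten(self):
        children = (self.p,)
        # Immutable, hashable snapshot of the config used as static aux_data.
        aux_data = tuple(sorted(self._config.items()))
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the config from aux_data instead of mutating a shared default.
        return cls(*children, config=dict(aux_data))


c1 = C(0., {'b': 2})
children, aux = c1.tree_flatten()
c2 = C.tree_unflatten(aux, children)
print(vars(C(0.)))  # {'p': 0.0, '_config': {'a': 0}}
print(vars(c2))     # {'p': 0.0, '_config': {'b': 2}}
```

Because the default dict is now created inside `__init__`, mutating one instance's config no longer leaks into later `C(...)` calls, which is what produced the surprising `{'a': 0, 'b': 2}` entries in the output above.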
Discussing with @eelregit, here are a few ideas for things to improve:

- `include_logdet` flag in `gaussian_log_likelihood` is reversed
- `transverse_comoving_distance` is actually jittable
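To pin down what "reversed" means for the first item, here is a hypothetical reference function (not jax_cosmo's actual `gaussian_log_likelihood`, whose signature is not shown in this thread), written with NumPy for illustration. The intended semantics assumed here: `include_logdet=True` adds the `-0.5 * log det(C)` term and `include_logdet=False` drops it; the constant `-n/2 * log(2*pi)` is omitted in both cases.

```python
import numpy as np


def gaussian_log_likelihood_ref(data, mu, cov, include_logdet=True):
    """Hypothetical reference: log N(data | mu, cov) without the 2*pi constant."""
    r = data - mu
    # chi^2 = r^T C^{-1} r, computed with a linear solve instead of an explicit inverse.
    chi2 = r @ np.linalg.solve(cov, r)
    loglike = -0.5 * chi2
    if include_logdet:
        # The log-determinant term is added only when the flag is True.
        _, logdet = np.linalg.slogdet(cov)
        loglike = loglike - 0.5 * logdet
    return loglike


data = np.array([1.0, 0.0])
mu = np.zeros(2)
cov = np.diag([2.0, 2.0])
print(gaussian_log_likelihood_ref(data, mu, cov, include_logdet=False))  # -0.25
print(gaussian_log_likelihood_ref(data, mu, cov, include_logdet=True))   # -0.25 - log(2)
```

Evaluating both flag settings on the same inputs and checking which result differs by `-0.5 * log det(C)` gives a quick regression test for the flag.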