Can't vmap across input using Gauss Newton fwd #63
Related: I opened jax-ml/jax#21581
Here is my latest iteration, still poking at the two lines in

```python
import pytest
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import lineax as lx


def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x


def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def fn(y):
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0) ** 2

    true = shifted_parabola(2.0)  # True value
    fit = shifted_parabola(y)
    return true - fit


def aux_wrapper(y):
    return fn(y), None


y0 = 1.0  # starting guess
y0s = jnp.arange(0.0, 4.0, 0.1)  # Many initial values

# Get Jacobians the simple way. The linearised function maps a tangent t to J @ t,
# so a unit tangent recovers the Jacobian column.
_, jac_of_fn = jax.linearize(fn, y0)
_, jac_of_aux_wrapper, _ = jax.linearize(aux_wrapper, y0, has_aux=True)
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_fn(1.0))
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_aux_wrapper(1.0))
# Note: these vmaps batch over tangents; the linearisation point stays fixed at y0.
vmapped_jac_of_fn = jax.vmap(jac_of_fn)(y0s)  # Does not raise error
vmapped_jac_of_aux_wrapper = jax.vmap(jac_of_aux_wrapper)(y0s)  # Does not raise error

# Now include the lambda function in jax.linearize (status quo in optimistix)
_, jac_of_fn_with_lambda = jax.linearize(lambda _y: fn(_y), y0)
_, jac_of_aux_wrapper_with_lambda, _ = jax.linearize(
    lambda _y: aux_wrapper(_y), y0, has_aux=True
)
vmapped_jac_of_fn_with_lambda = jax.vmap(jac_of_fn_with_lambda)(y0s)  # Does not raise error
vmapped_jac_of_aux_wrapper_with_lambda = jax.vmap(jac_of_aux_wrapper_with_lambda)(y0s)  # Does not raise error

# Context: using lambda functions produces a subtle difference in the pytrees,
# which is not legible when examining the PyTreeDef (as a human)
with pytest.raises(AssertionError):
    assert jtu.tree_structure(jac_of_fn) == jtu.tree_structure(jac_of_fn_with_lambda)
assert str(jtu.tree_structure(jac_of_fn)) == str(jtu.tree_structure(jac_of_fn_with_lambda))


# Create a lineax linear operator
def lin_fun(y):
    return 2 * y


lin_op = lx.FunctionLinearOperator(lin_fun, jax.eval_shape(lin_fun, y0))  # Confirm that it works in this case

# I don't understand why it does not work in these cases.
# (Likely cause: FunctionLinearOperator's second argument is the *input* structure, but
# jax.eval_shape(f, y0) returns the *output* structure, f32[50] here; the two only
# coincide for lin_fun, which maps a scalar to a scalar.)
with pytest.raises(ValueError):
    lx.FunctionLinearOperator(jac_of_fn, jax.eval_shape(jac_of_fn, y0))
with pytest.raises(ValueError):
    lx.FunctionLinearOperator(jac_of_fn_with_lambda, jax.eval_shape(jac_of_fn_with_lambda, y0))
```
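In the MWE, the values passed to `jax.vmap(jac_of_fn)` are tangents, not initial guesses: `jax.linearize` fixes the linearisation point at `y0`. The following is a minimal JAX-only sketch of that distinction. It redefines the MWE's residual with `x` hoisted to module level, and `jac_at` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0, 10)  # 50 sample points, as in the MWE


def fn(y):
    # Same residual as the MWE: reference parabola centred at 2.0 minus a fit centred at y.
    return (x - 2.0) ** 2 - (x - y) ** 2


y0 = 1.0
_, jvp_fn = jax.linearize(fn, y0)

# The linearised function maps a tangent t to J @ t; a unit tangent gives J itself.
jac_column = jvp_fn(1.0)
assert jnp.allclose(jac_column, jax.jacfwd(fn)(y0))

# vmap over the linearised function batches *tangents*: J stays fixed at y0.
ys = jnp.arange(0.0, 4.0, 0.1)
batched = jax.vmap(jvp_fn)(ys)
assert jnp.allclose(batched, ys[:, None] * jac_column[None, :])


def jac_at(y):
    # Hypothetical helper: Jacobian column of fn at y, via one forward-mode JVP.
    _, tangent_out = jax.jvp(fn, (y,), (jnp.ones_like(y),))
    return tangent_out


# Batching over the *primal* input re-differentiates at every y instead.
jacs = jax.vmap(jac_at)(ys)
assert jnp.allclose(jacs, jax.vmap(jax.jacfwd(fn))(ys))
```

Only the second pattern batches over initial values, which is what a Gauss-Newton solve vmapped across `y0` needs. The vmaps marked "Does not raise error" in the MWE exercise the first pattern.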
A simple workaround: `dataclasses.replace()` the offending member of the output.
The tricky thing is that I can't figure out what the offending member is.
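The `dataclasses.replace()` workaround can be sketched in pure JAX. `SolveOutput`, `solve` and `solve_for_vmap` are hypothetical stand-ins, not optimistix types. The sketch assumes the offending member is a field holding something that is not an array, here a plain Python callable. Replacing that field with `None` before the output leaves `jax.vmap` lets the remaining array fields batch normally.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SolveOutput:
    """Hypothetical stand-in for a solver's output (not optimistix's real type)."""

    value: jax.Array
    f_info: object  # member vmap cannot batch (here: a plain Python callable)


# Register as a pytree so jax.vmap can traverse it; both fields are children.
jax.tree_util.register_pytree_node(
    SolveOutput,
    lambda out: ((out.value, out.f_info), None),
    lambda _, children: SolveOutput(*children),
)


def solve(y0):
    value = y0**2  # placeholder "solution"
    # jax.vmap(solve) would fail: the callable in f_info is not a valid JAX output.
    return SolveOutput(value=value, f_info=lambda t: 2 * y0 * t)


def solve_for_vmap(y0):
    # Workaround: drop the offending member before the output leaves vmap.
    return dataclasses.replace(solve(y0), f_info=None)


y0s = jnp.arange(0.0, 4.0, 0.1)
batched = jax.vmap(solve_for_vmap)(y0s)
```

This only helps when nothing downstream needs the dropped member, and it still requires knowing which member to drop.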
And it has now been shown that it is a deeper issue in
I think I understand what's going on here. The output of The solution is pretty simple: use Does this solve the issues everyone is facing?
Oh dear :D It does solve my issue. I was actually in the process of replacing all
Vmapping across `y0` with any method using `AbstractGaussNewton` throws a `TypeError`.

MWE

The reason for this looks to be that the state includes an `f_info` with a `FunctionLinearOperator` whose linearised function is a Jaxpr, which can't be batched over.
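For contrast, a Gauss-Newton loop whose state is nothing but arrays batches over initial values without trouble. This is a minimal hand-rolled sketch, not optimistix's `AbstractGaussNewton`. `gauss_newton` is hypothetical and uses a fixed step count with no convergence check. It reuses the shifted-parabola residual `fn` from the thread's MWE (renamed `residual` here), whose root is at `y = 2`.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0, 10)  # 50 sample points, as in the MWE


def residual(y):
    # Same residual as `fn` in the MWE: reference parabola (centre 2.0) minus fit.
    return (x - 2.0) ** 2 - (x - y) ** 2


def gauss_newton(y0, num_steps=20):
    """Hypothetical minimal Gauss-Newton loop for a scalar parameter.

    The scan carry is a single array, so the whole solve batches under jax.vmap.
    """

    def step(y, _):
        r = residual(y)
        J = jax.jacfwd(residual)(y)  # shape (50,): d r / d y
        # Normal equations for one unknown: (J^T J) delta = -J^T r
        delta = -jnp.dot(J, r) / jnp.dot(J, J)
        return y + delta, None

    y_final, _ = jax.lax.scan(step, y0, None, length=num_steps)
    return y_final


y0s = jnp.arange(0.0, 4.0, 0.1)  # Many initial values
solutions = jax.vmap(gauss_newton)(y0s)
```

Each batch element re-evaluates its own Jacobian inside the loop, so no linearised function or Jaxpr has to be carried in the batched state.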