
Can't vmap across input using Gauss Newton fwd #63

Open
packquickly opened this issue May 31, 2024 · 11 comments
Labels
bug Something isn't working

Comments

@packquickly
Collaborator

Vmapping across `y0` with any method using `AbstractGaussNewton` throws a `TypeError`. MWE:

import jax
import jax.numpy as jnp

import optimistix as optx


def rosenbrock(x, args):
    del args
    term1 = 10 * (x[1:] - x[:-1] ** 2)
    term2 = x - 1
    return term1, term2


inits = jnp.zeros((4, 10))
solve = lambda x: optx.least_squares(rosenbrock, optx.LevenbergMarquardt(1e-8, 1e-9), x)
out = jax.vmap(solve)(inits)  # throws error

The reason for this looks to be that the state includes an `f_info` with a `FunctionLinearOperator` whose linearised function is a jaxpr, which can't be batched over.

@packquickly added the bug label on May 31, 2024
@johannahaffner

This comment was marked as outdated.


@johannahaffner
Contributor

Related: I opened jax-ml/jax#21581

@johannahaffner
Contributor

Here is my latest iteration, still poking at the two lines in `_make_f_info` from Gauss-Newton.
I could show that the jaxpr is not causing the issue, at least not outside of `FunctionLinearOperator`: vmapping over a Jacobian that contains a jaxpr and is an output of `jax.linearize` raises no errors.

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import equinox as eqx
import lineax as lx

def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x

def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def fn(y):
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2.)  # True value
    fit = shifted_parabola(y)
    return true - fit

def aux_wrapper(y):
    return fn(y), None

y0 = 1.  # starting guess
y0s = jnp.arange(0., 4., 0.1)  # Many initial values

# Get jacobians the simple way
_, jac_of_fn = jax.linearize(fn, y0)
_, jac_of_aux_wrapper, _ = jax.linearize(aux_wrapper, y0, has_aux=True)
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_fn(y0)) 
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_aux_wrapper(y0))  

vmapped_jac_of_fn = jax.vmap(jac_of_fn)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper = jax.vmap(jac_of_aux_wrapper)(y0s) # Does not raise error

# Now include the lambda function in jax.linearize (status quo in optimistix)   
_, jac_of_fn_with_lambda = jax.linearize(lambda _y: fn(_y), y0)
_, jac_of_aux_wrapper_with_lambda, _ = jax.linearize(lambda _y: aux_wrapper(_y), y0, has_aux=True)

vmapped_jac_of_fn_with_lambda = jax.vmap(jac_of_fn_with_lambda)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper_with_lambda = jax.vmap(jac_of_aux_wrapper_with_lambda)(y0s) # Does not raise error

# Context: using a lambda produces a subtle difference in the pytrees, which is not visible when examining the pytreedef by eye
with pytest.raises(AssertionError):
    assert jtu.tree_structure(jac_of_fn) == jtu.tree_structure(jac_of_fn_with_lambda)
assert str(jtu.tree_structure(jac_of_fn)) == str(jtu.tree_structure(jac_of_fn_with_lambda))

# Create a lineax Linear Operator
def lin_fun(y):
    return 2 * y
lin_op = lx.FunctionLinearOperator(lin_fun, jax.eval_shape(lin_fun, y0))  # Confirm that it works in this case

with pytest.raises(ValueError):  # I don't understand why it does not work in this case
    lx.FunctionLinearOperator(jac_of_fn, jax.eval_shape(jac_of_fn, y0))
with pytest.raises(ValueError):  # ...nor in this one
    lx.FunctionLinearOperator(jac_of_fn_with_lambda, jax.eval_shape(jac_of_fn_with_lambda, y0))

@tjltjl

tjltjl commented Jun 5, 2024

A simple workaround: `dataclasses.replace()` the offending member of the output.
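A minimal sketch of this workaround, not using optimistix: here a hypothetical `solve` stands in for `optx.least_squares`, with a plain Python object playing the role of the non-vmappable state. Stripping that member before `jax.vmap` sees the output (for a frozen dataclass result, via `dataclasses.replace(out, state=None)`) lets the batching go through.

```python
import jax
import jax.numpy as jnp


def solve(x):
    # Stand-in for the real solver: the second output plays the role of
    # the non-vmappable state (a plain Python object, not a JAX type).
    return x * 2, object()


def solve_clean(x):
    # The workaround: swap the offending member for a vmappable placeholder.
    # With a frozen dataclass this would be dataclasses.replace(out, state=None).
    value, _ = solve(x)
    return value, None


out, _ = jax.vmap(solve_clean)(jnp.arange(3.0))
```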

@johannahaffner
Contributor

The tricky thing is that I can't figure out what the offending member is.

@johannahaffner
Contributor

And it has now been shown that this is a deeper issue in `jax.linearize`, which produces pytrees with non-identical structure even for identical input functions called with identical inputs.

@patrick-kidger
Owner

I think I understand what's going on here. The output of `optx.least_squares` includes a jaxpr inside of `out.state`. This isn't an array-like object, so JAX doesn't understand how to handle it as an output of the vmap. Morally speaking, what's going on here is the same as `jax.vmap(lambda x: object())(...)`, in which, again, a non-array-like object is returned.
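As a minimal sketch of the same failure without optimistix: a vmapped function returning a plain Python object alongside an array should raise, since the object is not a valid JAX type.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x * 2, object()  # second output is not a valid JAX type


try:
    jax.vmap(f)(jnp.arange(3.0))
    raised = False
except TypeError:
    # jax.vmap rejects the non-array-like output when flattening results
    raised = True
```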

The solution is pretty simple: use `eqx.filter_vmap` instead. This passes through all non-array-like objects unchanged. Indeed, the use case in this issue is the raison d'être of `eqx.filter_vmap`!

Does this solve the issues everyone is facing?

@johannahaffner
Contributor

Oh dear :D

It does solve my issue. I was actually in the process of replacing all vmaps with filter_vmaps, but there were still some around. Not anymore, though!
