[Question] Interaction with optax LBFGS optimizer #796

Open

bnesposito opened this issue Aug 8, 2024 · 3 comments
Labels: question (User queries)

Comments

@bnesposito

Hi,
I am having trouble using the optax LBFGS optimizer with Equinox modules. I am trying to run a linear regression model based on this notebook https://github.com/ubcecon/ECON622/blob/master/lectures/lectures/examples/linear_regression_jax_equinox.py, changing the optimizer to optax.lbfgs(). Here is the core of the code:

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals(model, X, Y):
    batched_residuals = vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(model, X, Y))

@eqx.filter_jit
def make_step(model, opt_state, X, Y):     
  loss_value, grads = eqx.filter_value_and_grad(residuals)(model, X, Y)
  resids_closure = lambda x: residual(model, x, Y)
  updates, opt_state = optimizer.update(grads, opt_state, model, value=loss_value, grad=grads, value_fn=resids_closure)
  model = eqx.apply_updates(model, updates)
  return model, opt_state, loss_value

I get the following error:

TypeError: unsupported operand type(s) for @: 'DynamicJaxprTracer' and 'Linear'

From the traceback, the error shows up when evaluating the resids_closure function, specifically when model(x) tries to multiply the weights by x:

class MyLinear(eqx.Module):
    weight: jax.Array

    def __init__(self, in_size, out_size, key):
        self.weight = jax.random.normal(key, (out_size, in_size))

    # Equivalent to PyTorch's forward
    def __call__(self, x):
        return self.weight @ x

Does anybody have an idea on how to tackle this? Thanks!

@patrick-kidger (Owner)

From your description it seems possible that you've gotten two arguments the wrong way around, so that you are passing a Linear layer in as the input x.

I suspect optimizer.update(..., value_fn=...) is meant to accept the parameters as input, not data.
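
For example, a minimal sketch of that change against the make_step above (the closure now takes the model, i.e. the parameters, and closes over the data instead):

@eqx.filter_jit
def make_step(model, opt_state, X, Y):
    # value_fn must accept the parameters (the model), not a data point;
    # the batch (X, Y) is closed over instead.
    def loss_fn(m):
        return residuals(m, X, Y)

    loss_value, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(
        grads, opt_state, model,
        value=loss_value, grad=grads, value_fn=loss_fn,
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value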

patrick-kidger added the question (User queries) label on Aug 8, 2024
@jlperla (Contributor) commented Aug 27, 2024

Hmm... still having some trouble here (and also, this is not using the cached optax.value_and_grad_from_state, which is part of the performance benefit; see the sketch below).
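
As I understand the optax docs, the cached L-BFGS pattern looks roughly like this with plain arrays (no Equinox filtering, so how it composes with eqx.Module pytrees is exactly what I'm unsure about):

import jax.numpy as jnp
import optax

def fn(params):
    # simple quadratic objective for illustration
    return jnp.sum((params - 1.0) ** 2)

optimizer = optax.lbfgs()
params = jnp.zeros(3)
opt_state = optimizer.init(params)

# Reuses the value/grad already computed inside the linesearch instead of
# recomputing them from scratch at every step.
value_and_grad = optax.value_and_grad_from_state(fn)

for _ in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = optimizer.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=fn
    )
    params = optax.apply_updates(params, updates)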

Here is something that shows the current issues I am seeing. Consider three models: (1) eqx.nn.Linear; (2) a custom eqx.Module; and (3) an MLP, each run with both LBFGS and SGD. Everything works except the MLP + LBFGS combination.

To replicate: run the following code, then comment/uncomment to swap out the optimizer used (near the top of the file) and you will see the error:

# Takes the baseline version and uses vmap, adds in a learning rate scheduler
import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX


# LLS loss function with vmap
def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals(model, X, Y):
    batched_residuals = jax.vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(model, X, Y))


# SWITCH OPTIMIZERS HERE!!!!
# reinitialize
optimizer = optax.sgd(0.001)
#optimizer = optax.lbfgs()


N = 500  # samples
M = 2
sigma = 0.0001
key = random.PRNGKey(42)
key, *subkey = random.split(key, num=4)
theta = random.normal(subkey[0], (M,))
X = random.normal(subkey[1], (N, M))
Y = X @ theta + sigma * random.normal(subkey[2], (N,))  # Adding noise


# Hypothesis Class: will start with a linear function, which is randomly initialized
# model is a variable holding all parameters, and supports model(X) calls
key, subkey = random.split(key)
model = eqx.nn.Linear(M, 1, use_bias = False, key = subkey)

# Needs to remove the non-differentiable parts of the "model" object
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, X, Y):     
  def step_residuals(model):
    return residuals(model, X, Y)
  loss_value, grads = eqx.filter_value_and_grad(step_residuals)(model)
  updates, opt_state = optimizer.update(grads, opt_state, model, value = loss_value, grad = grads, value_fn = step_residuals)
  model = eqx.apply_updates(model, updates)
  return model, opt_state, loss_value

num_epochs = 20
batch_size = 1024#64
dataset = jdl.ArrayDataset(X,Y)
train_loader = DataLoaderJAX(dataset, batch_size = batch_size, shuffle = True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    
    if epoch % 2 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")


# ## Custom equinox type, like in pytorch
class MyLinear(eqx.Module):
    weight: jax.Array

    def __init__(self, in_size, out_size, key):
        self.weight = jax.random.normal(key, (out_size, in_size))

    # Equivalent to PyTorch's forward
    def __call__(self, x):
        return self.weight @ x

model = MyLinear(M, 1, key = subkey)
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))

for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    
    if epoch % 2 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")


model = eqx.nn.MLP(M, 1, width_size=128, depth=3, key = subkey)
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))

for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    
    if epoch % 2 == 0:
         print(f"Epoch {epoch},train_loss={train_loss}")

This particular error says

Traceback (most recent call last):
  File "c:\Users\jesse\Documents\GitHub\ECON622_instructor\lectures\examples\linear_regression_jax_equinox.py", line 97, in <module>
    model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\equinox\_jit.py", line 242, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\equinox\_module.py", line 1078, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\equinox\_jit.py", line 215, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\jesse\Documents\GitHub\ECON622_instructor\lectures\examples\linear_regression_jax_equinox.py", line 50, in make_step
    updates, opt_state = optimizer.update(grads, opt_state, model, value = loss_value, grad = grads, value_fn = step_residuals)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\optax\transforms\_combining.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\optax\_src\base.py", line 330, in update
    return tx.update(updates, state, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\optax\_src\transform.py", line 1438, in update_fn
    diff_params = otu.tree_sub(params, state.params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jesse\anaconda3\envs\econ622\Lib\site-packages\optax\tree_utils\_tree_math.py", line 57, in tree_sub
    return jtu.tree_map(operator.sub, tree_x, tree_y)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: unsupported operand type(s) for -: 'custom_jvp' and 'NoneType'

@HarshBabla99 commented Oct 9, 2024

I ran into the same issue. After some investigation, I may have some hints as to what is causing it.

The problematic line in the error trace above (which I was able to replicate) is:

diff_params = otu.tree_sub(params, state.params)

in optax/_src/transform.py.

The issue is that state.params is "filtered" (using eqx.filter(..., eqx.is_array)), while params is not. This is evident if you print out both variables:

params = MLP(
  layers=[
    Linear(....),
    <wrapped function relu>,
    ...
  ]
)

state.params = MLP(
  layers=[
    Linear(....),
    None,
    ...
  ]
)

Therefore, one of the PyTrees has a custom_jvp (i.e. relu) at the same leaf location where the other has None. This results in the eventual TypeError.
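
A minimal sketch of how to see this, using the MLP model from the reproduction above:

import equinox as eqx

filtered_model = eqx.filter(model, eqx.is_array)  # non-array leaves (e.g. relu) become None
print(model)           # layers list contains <wrapped function relu>
print(filtered_model)  # the same position holds None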

So I tried to fix it by filtering the model right before the update. I realize this is probably quite inefficient, but before I could even address that, I received the following fairly recondite traceback:

.... (my code)    
    model_filtered = eqx.filter(model, eqx.is_array) #!
--> updates, opt_state = optimizer.update(grads, opt_state, model_filtered,
                                          value = loss_value, grad = grads, value_fn = loss)
    # Apply the updates to the model
    model = eqx.apply_updates(model, updates)

File python3.12/site-packages/optax/transforms/_combining.py:73, in chain.<locals>.update_fn(updates, state, params, **extra_args)
     71 new_state = []
     72 for s, fn in zip(state, update_fns):
---> 73   updates, new_s = fn(updates, s, params, **extra_args)
     74   new_state.append(new_s)
     75 return updates, tuple(new_state)

File python3.12/site-packages/optax/_src/linesearch.py:1493, in scale_by_zoom_linesearch.<locals>.update_fn(updates, state, params, value, grad, value_fn, **extra_args)
   1484 stepsize_guess = state.learning_rate
   1485 init_state = init_ls(
   1486     updates,
   1487     params,
   (...)
   1490     stepsize_guess=stepsize_guess,
   1491 )
-> 1493 final_state = jax.lax.while_loop(
   1494     cond_step_ls,
   1495     functools.partial(
   1496         step_ls, value_and_grad_fn=value_and_grad_fn, fn_kwargs=fn_kwargs
   1497     ),
   1498     init_state,
   1499 )
   1500 learning_rate = final_state.stepsize
   1501 scaled_updates = otu.tree_scalar_mul(learning_rate, updates)

    [... skipping hidden 9 frame]

File python3.12/site-packages/optax/_src/linesearch.py:1166, in zoom_linesearch.<locals>.step_fn(state, value_and_grad_fn, fn_kwargs)
   1159 def step_fn(
   1160     state: ZoomLinesearchState,
   1161     *,
   1162     value_and_grad_fn: Callable[..., tuple[chex.Numeric, base.Updates]],
   1163     fn_kwargs: dict[str, Any],
   1164 ) -> ZoomLinesearchState:
   1165   """Makes a step of the linesearch."""
-> 1166   new_state = jax.lax.cond(
   1167       state.interval_found,
   1168       functools.partial(
   1169           _zoom_into_interval,
   1170           value_and_grad_fn=value_and_grad_fn,
   1171           fn_kwargs=fn_kwargs,
   1172       ),
   1173       functools.partial(
   1174           _search_interval,
   1175           value_and_grad_fn=value_and_grad_fn,
   1176           fn_kwargs=fn_kwargs,
   1177       ),
   1178       state,
   1179   )
   1180   new_state = jax.lax.cond(
   1181       new_state.failed,
   1182       _try_safe_step,
   1183       lambda x: x,
   1184       new_state
   1185   )
   1186   return new_state

    [... skipping hidden 9 frame]

File python3.12/site-packages/optax/_src/linesearch.py:933, in zoom_linesearch.<locals>._zoom_into_interval(state, value_and_grad_fn, fn_kwargs)
    929 middle = jnp.where(use_bisection, middle_bisection, middle)
    931 # Check if new point is good
    932 _, value_middle, grad_middle, slope_middle = (
--> 933     _value_and_slope_on_line(
    934         value_and_grad_fn, params, middle, updates, fn_kwargs
    935     )
    936 )
    938 decrease_error = _decrease_error(
    939     middle, value_middle, slope_middle, value_init, slope_init
    940 )
    941 curvature_error = _curvature_error(slope_middle, slope_init)

File python3.12/site-packages/optax/_src/linesearch.py:637, in zoom_linesearch.<locals>._value_and_slope_on_line(value_and_grad_fn, params, stepsize, updates, fn_kwargs)
    598 r"""Compute value and slope on line.
    599 
    600 Mathematically, outputs
   (...)
    634       stepsize at the step.
    635 """
    636 step = otu.tree_add_scalar_mul(params, stepsize, updates)
--> 637 value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
    638 slope_step = otu.tree_vdot(grad_step, updates)
    639 return step, value_step, grad_step, slope_step

    [... skipping hidden 13 frame]

File python3.12/inspect.py:3273, in Signature.bind(self, *args, **kwargs)
   3268 def bind(self, /, *args, **kwargs):
   3269     """Get a BoundArguments object, that maps the passed `args`
   3270     and `kwargs` to the function's signature.  Raises `TypeError`
   3271     if the passed arguments can not be bound.
   3272     """
-> 3273     return self._bind(args, kwargs)

File python3.12/inspect.py:3186, in Signature._bind(self, args, kwargs, partial)
   3184                 msg = 'missing a required{argtype} argument: {arg!r}'
   3185                 msg = msg.format(arg=param.name, argtype=argtype)
-> 3186                 raise TypeError(msg) from None
   3187 else:
   3188     # We have a positional argument to process
   3189     try:

TypeError: missing a required argument: 'data'
