[Question] Interaction with optax LBFGS optimizer #796
Comments
From your description it seems possible that you've gotten two arguments the wrong way around, so that one is being passed where the other is expected; I suspect that is the source of the error.
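For concreteness, here is a minimal sketch of the calling convention optax.lbfgs expects (the toy data and names are illustrative, not from this thread): the line search calls value_fn with the candidate parameter pytree as its only positional argument, so any data has to be closed over rather than passed through that argument.

import jax
import jax.numpy as jnp
import optax

# Toy least-squares problem on made-up data.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (8, 2))
Y = X @ jnp.array([2.0, -3.0])

def value_fn(params):                       # the line search passes candidate params here
    return jnp.mean((X @ params - Y) ** 2)  # the data is captured, not an argument

params = jnp.zeros(2)
opt = optax.lbfgs()
state = opt.init(params)
value, grad = jax.value_and_grad(value_fn)(params)
updates, state = opt.update(grad, state, params,
                            value=value, grad=grad, value_fn=value_fn)
params = optax.apply_updates(params, updates)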
Hmm... having some trouble here still (and also, this is not using the cached ...). Here is something that shows the current issues I am seeing. Consider:
(1) To replicate: run the following code, then comment/uncomment to swap out the optimizer used (near the top of the file) and you will see the error:

# Takes the baseline version and uses vmap, adds in a learning rate scheduler
import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
# LLS loss function with vmap
def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals(model, X, Y):
    batched_residuals = jax.vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(model, X, Y))
# SWITCH OPTIMIZERS HERE!!!!
# reinitialize
optimizer = optax.sgd(0.001)
#optimizer = optax.lbfgs()
N = 500 # samples
M = 2
sigma = 0.0001
key = random.PRNGKey(42)
key, *subkey = random.split(key, num=4)
theta = random.normal(subkey[0], (M,))
X = random.normal(subkey[1], (N, M))
Y = X @ theta + sigma * random.normal(subkey[2], (N,)) # Adding noise
# Hypothesis Class: will start with a linear function, which is randomly initialized
# model is a variable of all parameters, and supports model(X) calls
key, subkey = random.split(key)
model = eqx.nn.Linear(M, 1, use_bias = False, key = subkey)
# Needs to remove the non-differentiable parts of the "model" object
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))
@eqx.filter_jit
def make_step(model, opt_state, X, Y):
    def step_residuals(model):
        return residuals(model, X, Y)
    loss_value, grads = eqx.filter_value_and_grad(step_residuals)(model)
    updates, opt_state = optimizer.update(grads, opt_state, model, value = loss_value, grad = grads, value_fn = step_residuals)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value
num_epochs = 20
batch_size = 1024#64
dataset = jdl.ArrayDataset(X,Y)
train_loader = DataLoaderJAX(dataset, batch_size = batch_size, shuffle = True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    if epoch % 2 == 0:
        print(f"Epoch {epoch}, ||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")
# ## Custom equinox type, like in pytorch
class MyLinear(eqx.Module):
    weight: jax.Array

    def __init__(self, in_size, out_size, key):
        self.weight = jax.random.normal(key, (out_size, in_size))

    # Equivalent to Pytorch's forward
    def __call__(self, x):
        return self.weight @ x
model = MyLinear(M, 1, key = subkey)
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, ||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")
model = eqx.nn.MLP(M, 1, width_size=128, depth=3, key = subkey)
opt_state = optimizer.init(eqx.filter(model,eqx.is_inexact_array))
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
    if epoch % 2 == 0:
        print(f"Epoch {epoch}, train_loss={train_loss}")

This particular error says: ...
I ran into the same issue. After some investigation, I may have some hints for what might be causing it. The problematic line in the error trace above (which I was able to replicate) is:

diff_params = otu.tree_sub(params, state.params)

The issue is that params and state.params do not share the same PyTree structure:

params = MLP(
    layers=[
        Linear(....),
        <wrapped function relu>,
        ...
    ]
)
state.params = MLP(
    layers=[
        Linear(....),
        None,
        ...
    ]
)

Therefore, one of the PyTrees has a <wrapped function relu> leaf where the other has None, and the subtraction fails. So I thought to fix it by filtering the model right before the update (see the sketch after the traceback below for a related workaround). I realize this is probably quite inefficient, but before I could even address that I received the following, fairly recondite, error traceback:

.... (my code)
model_filtered = eqx.filter(model, eqx.is_array) #!
--> updates, opt_state = optimizer.update(grads, opt_state, model_filtered,
value = loss_value, grad = grads, value_fn = loss)
# Apply the updates to the model
model = eqx.apply_updates(model, updates)
File python3.12/site-packages/optax/transforms/_combining.py:73, in chain.<locals>.update_fn(updates, state, params, **extra_args)
71 new_state = []
72 for s, fn in zip(state, update_fns):
---> 73 updates, new_s = fn(updates, s, params, **extra_args)
74 new_state.append(new_s)
75 return updates, tuple(new_state)
File python3.12/site-packages/optax/_src/linesearch.py:1493, in scale_by_zoom_linesearch.<locals>.update_fn(updates, state, params, value, grad, value_fn, **extra_args)
1484 stepsize_guess = state.learning_rate
1485 init_state = init_ls(
1486 updates,
1487 params,
(...)
1490 stepsize_guess=stepsize_guess,
1491 )
-> 1493 final_state = jax.lax.while_loop(
1494 cond_step_ls,
1495 functools.partial(
1496 step_ls, value_and_grad_fn=value_and_grad_fn, fn_kwargs=fn_kwargs
1497 ),
1498 init_state,
1499 )
1500 learning_rate = final_state.stepsize
1501 scaled_updates = otu.tree_scalar_mul(learning_rate, updates)
[... skipping hidden 9 frame]
File python3.12/site-packages/optax/_src/linesearch.py:1166, in zoom_linesearch.<locals>.step_fn(state, value_and_grad_fn, fn_kwargs)
1159 def step_fn(
1160 state: ZoomLinesearchState,
1161 *,
1162 value_and_grad_fn: Callable[..., tuple[chex.Numeric, base.Updates]],
1163 fn_kwargs: dict[str, Any],
1164 ) -> ZoomLinesearchState:
1165 """Makes a step of the linesearch."""
-> 1166 new_state = jax.lax.cond(
1167 state.interval_found,
1168 functools.partial(
1169 _zoom_into_interval,
1170 value_and_grad_fn=value_and_grad_fn,
1171 fn_kwargs=fn_kwargs,
1172 ),
1173 functools.partial(
1174 _search_interval,
1175 value_and_grad_fn=value_and_grad_fn,
1176 fn_kwargs=fn_kwargs,
1177 ),
1178 state,
1179 )
1180 new_state = jax.lax.cond(
1181 new_state.failed,
1182 _try_safe_step,
1183 lambda x: x,
1184 new_state
1185 )
1186 return new_state
[... skipping hidden 9 frame]
File python3.12/site-packages/optax/_src/linesearch.py:933, in zoom_linesearch.<locals>._zoom_into_interval(state, value_and_grad_fn, fn_kwargs)
929 middle = jnp.where(use_bisection, middle_bisection, middle)
931 # Check if new point is good
932 _, value_middle, grad_middle, slope_middle = (
--> 933 _value_and_slope_on_line(
934 value_and_grad_fn, params, middle, updates, fn_kwargs
935 )
936 )
938 decrease_error = _decrease_error(
939 middle, value_middle, slope_middle, value_init, slope_init
940 )
941 curvature_error = _curvature_error(slope_middle, slope_init)
File python3.12/site-packages/optax/_src/linesearch.py:637, in zoom_linesearch.<locals>._value_and_slope_on_line(value_and_grad_fn, params, stepsize, updates, fn_kwargs)
598 r"""Compute value and slope on line.
599
600 Mathematically, outputs
(...)
634 stepsize at the step.
635 """
636 step = otu.tree_add_scalar_mul(params, stepsize, updates)
--> 637 value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
638 slope_step = otu.tree_vdot(grad_step, updates)
639 return step, value_step, grad_step, slope_step
[... skipping hidden 13 frame]
File python3.12/inspect.py:3273, in Signature.bind(self, *args, **kwargs)
3268 def bind(self, /, *args, **kwargs):
3269 """Get a BoundArguments object, that maps the passed `args`
3270 and `kwargs` to the function's signature. Raises `TypeError`
3271 if the passed arguments can not be bound.
3272 """
-> 3273 return self._bind(args, kwargs)
File python3.12/inspect.py:3186, in Signature._bind(self, args, kwargs, partial)
3184 msg = 'missing a required{argtype} argument: {arg!r}'
3185 msg = msg.format(arg=param.name, argtype=argtype)
-> 3186 raise TypeError(msg) from None
3187 else:
3188 # We have a positional argument to process
3189 try:
TypeError: missing a required argument: 'data'
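For what it's worth, here is a minimal, self-contained sketch of one pattern that avoids both problems at once (the toy data, model sizes, and variable names are illustrative, not taken from this thread): partition the model into its array leaves and its static parts with eqx.partition, run optax.lbfgs purely on the array pytree, and give the line search a single-argument value_fn that recombines with eqx.combine and closes over the batch. That keeps the params passed to optimizer.update structurally identical to the ones the state was initialized with (no <wrapped function relu> vs None mismatch), and there is no missing 'data' argument because the line search only ever passes the candidate parameters.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# Toy data, standing in for the X_batch / Y_batch above.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
X = jax.random.normal(k1, (32, 2))
Y = X @ jnp.array([1.0, -2.0])

model = eqx.nn.MLP(2, 1, width_size=8, depth=2, key=k2)

# Split the model into its array leaves (optimized) and everything else
# (activation functions and other static fields), so optax only ever sees
# one consistent array pytree.
params, static = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(p, x, y):
    m = eqx.combine(p, static)          # rebuild the full model
    pred = jax.vmap(m)(x)[:, 0]
    return jnp.mean((pred - y) ** 2)

optimizer = optax.lbfgs()
opt_state = optimizer.init(params)

@eqx.filter_jit
def make_step(params, opt_state, x, y):
    # value_fn takes the parameter pytree as its only argument; the batch
    # is closed over rather than passed positionally.
    value_fn = lambda p: loss_fn(p, x, y)
    value, grads = jax.value_and_grad(value_fn)(params)
    updates, opt_state = optimizer.update(
        grads, opt_state, params,
        value=value, grad=grads, value_fn=value_fn)
    params = optax.apply_updates(params, updates)
    return params, opt_state, value

for i in range(10):
    params, opt_state, value = make_step(params, opt_state, X, Y)

model = eqx.combine(params, static)     # final model, if needed

This is only a sketch under those assumptions; an eqx.filter / eqx.apply_updates variant of the same idea should work equally well.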
Hi,
I am having trouble implementing optax LBFGS with equinox types. I am trying to run a linear regression model using this notebook https://github.com/ubcecon/ECON622/blob/master/lectures/lectures/examples/linear_regression_jax_equinox.py and changing the optimizer to optax.lbfgs(). Here is the core of the code:
I get the following error:
TypeError: unsupported operand type(s) for @: 'DynamicJaxprTracer' and 'Linear'
From the traceback, the error shows up when evaluating the resids_closure function; in particular, when model(x) is evaluated and it tries to multiply the weights by x.
Does anybody have an idea on how to tackle this? Thanks!
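One hedged guess at the mechanism, since resids_closure itself is not shown above: if that closure takes the data as its argument and captures the model, the line search will substitute a candidate parameter pytree into the data slot, and model(x) then evaluates weight @ <Linear ...>, which is exactly an unsupported '@' operand pairing. A tiny, purely hypothetical illustration of that operand mismatch (not the original code):

import jax
import equinox as eqx

key = jax.random.PRNGKey(0)
lin = eqx.nn.Linear(2, 1, use_bias=False, key=key)

# Calling the model on another module instead of on an array triggers the
# same kind of failure: weight @ <Linear> has no matmul defined.
try:
    lin(lin)
except TypeError as e:
    print(e)  # e.g. unsupported operand type(s) for @: ... and 'Linear'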