Using optimistix with an equinox model #56
You have `params` and `static` mixed up. (Caveat: I've not tried running the code, this was just what jumped out at me.)
Indeed, that was the issue! Thanks a lot for the reply and for the library!
@patrick-kidger I ran the same code after fixing the params/static mix-up, but I get the following errors. Could you please help me with this?
Sorry, I can't reproduce this. Can you provide a code snippet demonstrating your problem?
Hi @patrick-kidger, sure, here it is:
Hi Raj, I can get this to run with a few edits. Here is the edited example:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

# Sample data
X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

# Define the target function
@jax.vmap
def function(x):
    return x[0] + x[1] ** 2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4] * x[5] + (x[6] * x[7]) ** 3

# Compute the output for y
y = function(X).reshape(-1, 1)  # This is a column vector with shape (2000, 1)

# Define the MLP model
model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

# Partition the model into trainable parameters (params) and static (non-trainable) parts
params, static = eqx.partition(model, eqx.is_inexact_array)

# Loss function: recombine params with the static parts, then evaluate
def loss_fn(params, static):
    model = eqx.combine(params, static)
    model_output = jax.vmap(model)(X)
    return jnp.sum((model_output - y) ** 2)

# Set up the solver
solver = optx.BFGS(rtol=1e-5, atol=1e-5)

# Minimise the loss function
sol = optx.minimise(loss_fn, solver, params, args=static, max_steps=2**14)
```

Note that you might get better performance if you return the residuals instead of their squared sum and use the Gauss-Newton or Levenberg-Marquardt solvers instead (haven't tested this); see the sketch below.
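A minimal sketch of that residuals variant (an editor's addition, untested here just like the suggestion above, reusing `X`, `y`, `params`, and `static` from the example): return per-sample residuals and call `optx.least_squares` with a Levenberg-Marquardt solver.

```python
# Sketch only: the least-squares solvers work on the residual vector directly.
def residual_fn(params, static):
    model = eqx.combine(params, static)
    return jax.vmap(model)(X) - y  # per-sample residuals, shape (2000, 1)

lm_solver = optx.LevenbergMarquardt(rtol=1e-5, atol=1e-5)
sol = optx.least_squares(residual_fn, lm_solver, params, args=static, max_steps=2**14)
```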
Thank you very much @johannahaffner! I have a question:
@johannahaffner @patrick-kidger How do I print the value of the objective function at each iteration of BFGS? I checked the API and it seems like verbose is not an option? I would appreciate your help and suggestions.
Hi Raj, glad it helped you! The values of your optimised parameters are accessible from the returned solution. BFGS does not (yet) support a verbose option. You could try an interactive solve, following this example. @patrick-kidger what do you think about adding my verbose version of BFGS? This feature is completely independent of constrained-anything and I am happy to spin it out into its own PR. But it does represent a break for users who have implemented something custom on top of BFGS, and I do understand wanting to keep as many releases as possible break-free!
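As a lighter-weight alternative to the interactive solve suggested above (an editor's addition, not from the thread): print the loss from inside the objective with `jax.debug.print`, which also works under `jit`. Note this fires on every objective evaluation, including line-search trials, not only on accepted steps.

```python
# Sketch only: reuses X and y from the earlier example.
def loss_fn(params, static):
    model = eqx.combine(params, static)
    loss = jnp.sum((jax.vmap(model)(X) - y) ** 2)
    jax.debug.print("loss = {l}", l=loss)  # printed at every evaluation
    return loss
```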
Thank you very much @johannahaffner. I will try out the suggestion.

I think that'd be good!

Hi @patrick-kidger and @johannahaffner: I had another question. Could you please let me know how to print out the number of iterations taken by BFGS? Thank you!
Hi Raj, this is indicated in the solution's statistics. Happy solving :)
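A minimal sketch (an editor's addition, assuming the stats key is named `"num_steps"` as in the optimistix documentation, and reusing the earlier example's setup):

```python
sol = optx.minimise(loss_fn, solver, params, args=static, max_steps=2**14)
print(sol.stats["num_steps"])  # assumed key name; number of steps the solver took
```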
Thank you very much @johannahaffner!

Hi @johannahaffner, I had a question. Is there any LBFGS solver?

We haven't implemented LBFGS at the moment. We should probably do this :)
@patrick-kidger and @johannahaffner: Happy New Year guys! I had a question. If I want to terminate the iteration on the basis of the number of steps rather than rtol or atol, which routine should I change? Thank you very much!
Probably simplest is just to use `max_steps` (together with `throw=False`); see the sketch below. But if you want to customize the solver directly then you could subclass it and adjust its `terminate` method.
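A minimal sketch of the `max_steps` route (an editor's addition, reusing `loss_fn`, `params`, and `static` from the earlier example): cap the step count and pass `throw=False` so hitting the cap is reported in `sol.result` rather than raised as an error.

```python
solver = optx.BFGS(rtol=1e-12, atol=1e-12)  # tolerances tight enough that max_steps bites first
sol = optx.minimise(loss_fn, solver, params, args=static, max_steps=100, throw=False)
print(sol.result)  # indicates whether the solve stopped at max_steps
```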
Thank you very much @patrick-kidger.
Hi @patrick-kidger and @johannahaffner, I needed to add the following curvature condition along with the Armijo condition. My Python code looks like this, but it does not converge. Could you please help me with this? Thank you!
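For reference (an editor's addition; the exact condition in the original comment was not preserved), the standard strong Wolfe conditions pair the Armijo sufficient-decrease rule with a curvature condition:

$$f(x_k + \alpha p_k) \le f(x_k) + c_1 \alpha \, \nabla f(x_k)^\top p_k \qquad \text{(Armijo, sufficient decrease)}$$

$$\left|\nabla f(x_k + \alpha p_k)^\top p_k\right| \le c_2 \left|\nabla f(x_k)^\top p_k\right| \qquad \text{(strong curvature condition)}$$

with $0 < c_1 < c_2 < 1$ (commonly $c_1 = 10^{-4}$ and $c_2 = 0.9$ for quasi-Newton methods).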
@patrick-kidger @johannahaffner Specifically, I don't know how to compute it.
Hi back! And nice to see you tinkering with optimistix. Just looking at this, I think I have a hunch where the lack of convergence might come from, and how you could fix it. I think that implementing a Wolfe line search is probably going to require expanding as well as shrinking the step size.

Now, to your other question, about computing that quantity: the predicted reduction is computed here:

Note that no equivalent of the step_size (…)
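If the quantity in question is the gradient-direction inner product needed by the curvature condition, here is a generic sketch (an editor's own helper, not optimistix internals) of computing it over arbitrary pytrees:

```python
import jax
import jax.numpy as jnp

def directional_derivative(f, y, direction):
    # grad(f)(y)^T direction, computed leaf-wise and summed over the pytree.
    grads = jax.grad(f)(y)
    dots = jax.tree_util.tree_map(lambda g, d: jnp.vdot(g, d), grads, direction)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, dots)
```

The curvature check would then compare this value at the trial point against $c_2$ times its value at the current iterate.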
Thanks @johannahaffner. I had another question: if I have to compute: $$…$$
You move this line (optimistix/optimistix/_solver/bfgs.py, line 224 at dcafc48) out of the block it currently sits in (line 208), then you can pass it along (line 216). I think this should "just work"; if not then let me know (in case you get complaints about search state discrepancies and such). By the way, here is an implementation of a Wolfe line search in Julia that might offer some inspiration: https://github.com/JuliaNLSolvers/LineSearches.jl/blob/master/src/strongwolfe.jl
Thanks @johannahaffner. I will try it out and get back to you. Thank you very much!

Hi @johannahaffner, I was able to implement the Wolfe criteria, like below, but I did not get any improvement. It seems like the result is similar to what I get with the Armijo condition only. Do you have any suggestions? Thank you!
I can't see any step length increases, is that correct?
Hi everyone, thanks for the great library and apologies in advance for this basic question.

I'm trying to find the true minimum of a small neural network, and I thought of using a solver from `optimistix` together with an `equinox` model. However, I haven't been able to make the two work together. Here is a minimal snippet which fails:

I'm getting `TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>`. What am I doing wrong? Thank you in advance.
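A hypothetical reconstruction of the failing pattern (an editor's addition; the original snippet was not preserved): passing the whole model as the optimisation variable makes the solver treat every pytree leaf, including the non-array activation function, as a parameter, which is consistent with the dtype error above. The `eqx.partition`/`eqx.combine` approach earlier in the thread is the fix.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2,
                   activation=jax.nn.silu, key=jax.random.PRNGKey(0))

def loss_fn(model, args):
    x = jnp.ones((10, 8))
    return jnp.sum(jax.vmap(model)(x) ** 2)

solver = optx.BFGS(rtol=1e-5, atol=1e-5)
# Fails: `model` contains non-array leaves such as jax.nn.silu.
sol = optx.minimise(loss_fn, solver, model)
```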