Using optimistix with an equinox model #56

Open
frostedoyster opened this issue May 7, 2024 · 27 comments
Labels
question User queries

Comments

@frostedoyster

Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.

Here is a minimal snippet which fails:

import jax 
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)


X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)

I'm getting TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>.

What am I doing wrong?
Thank you in advance.

@patrick-kidger
Owner

You have static and params the wrong way around.

(Caveat: I've not tried running the code, this was just what jumped out at me.)

@patrick-kidger added the question (User queries) label on May 7, 2024
@frostedoyster
Author

Indeed, that was the issue! Thanks a lot for the reply and for the library!

@raj-brown

raj-brown commented Dec 5, 2024

@patrick-kidger I ran the same code after fixing the params/static order, but I get the following error:

 File "/Users/raj/opt/anaconda3/envs/equinox/lib/python3.9/site-packages/optimistix/_misc.py", line 89, in __call__
    return self.fn(*args, **kwargs), None
  File "/Users/raj/WORK_RAJ/NEURONSE_DARPA/RANS_FSBL/CODES/test.py", line 24, in loss_fn
    model = eqx.combine(params, static)
  File "/Users/raj/opt/anaconda3/envs/equinox/lib/python3.9/site-packages/equinox/_filters.py", line 209, in combine
    return jtu.tree_map(_combine, *pytrees, is_leaf=_is_leaf)
ValueError: Custom node type mismatch: expected type: <class 'equinox.nn._mlp.MLP'>, value: None.

Could you please help me with this?

@patrick-kidger
Owner

Sorry, I can't reproduce this. Can you provide a code snippet demonstrating your problem?

@raj-brown

Hi @patrick-kidger, sure, here it is:

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

# Sample data
X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))


# Define the function
@jax.vmap
def function(x):
    return x[0] + x[1] ** 2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4] * x[5] + (x[6] * x[7]) ** 3


# Compute the output for y
y = function(X).reshape(-1, 1)  # This is a column vector with shape (2000, 1)

# Define the MLP model
model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

# Partition the parameters of the model (params) and static (non-trainable) parts
params, static = eqx.partition(model, eqx.is_inexact_array)


# Loss function
def loss_fn(params, static):
    model = eqx.combine(params, static)
    model_output = model(X)
    return jnp.sum((model_output - y) ** 2)


# Set up the solver
solver = optx.Newton(rtol=1e-5, atol=1e-5)

# Minimize the loss function
sol = optx.minimise(loss_fn, solver, params)

@johannahaffner
Contributor

johannahaffner commented Dec 6, 2024

Hi Raj,

I can get this to run by:

  • passing static through the args keyword of optx.minimise, which your code is missing. This is the source of the error you were seeing - the default value of args is None
  • adding a vmap to the model call
  • jacking up the number of steps - neural networks are typically not trained in your default 256 steps
  • switching the solver to the quasi-Newton BFGS, which does not have the same tendency to converge to stationary points

Here is the edited example:

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

# Sample data
X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

# Define the function
@jax.vmap
def function(x):
    return x[0] + x[1] ** 2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4] * x[5] + (x[6] * x[7]) ** 3

# Compute the output for y
y = function(X).reshape(-1, 1)  # This is a column vector with shape (2000, 1)

# Define the MLP model
model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

# Partition the parameters of the model (params) and static (non-trainable) parts
params, static = eqx.partition(model, eqx.is_inexact_array)

# Loss function
def loss_fn(params, static):
    model = eqx.combine(params, static)
    model_output = jax.vmap(model)(X)
    return jnp.sum((model_output - y) ** 2)

# Set up the solver
solver = optx.BFGS(rtol=1e-5, atol=1e-5)

# Minimize the loss function
sol = optx.minimise(loss_fn, solver, params, args=static, max_steps=2**14)

Note that you might get better performance if you return the residuals instead of their squared sum and use the Gauss-Newton or Levenberg-Marquardt solvers instead (haven't tested this).

@raj-brown

Thank you very much @johannahaffner! I have a question:

  1. How do I combine the params and static after minimization? This did not work:
    eqx.combine(params, static)

I will try with residuals and see how it performs. Thank you very much!

@raj-brown

@johannahaffner @patrick-kidger How do I print the value of the objective function at each iteration of BFGS? I checked the API and it seems that verbose is not an option. I would appreciate your help and suggestions.

@johannahaffner
Contributor

johannahaffner commented Dec 7, 2024

Hi Raj,

glad it helped you!

The values of your optimised parameters are accessible as sol.value; params will still point to your initial values. Have you tried passing sol.value to eqx.combine?

BFGS does not (yet) support a verbose option. You could try an interactive solve, following this example.

@patrick-kidger what do you think about adding my verbose version of BFGS? This feature is completely independent of constrained-anything and I am happy to spin this out in its own PR. But it does represent a break for users who have implemented something custom on top of BFGS and I do understand wanting to keep as many releases as possible break-free!

@raj-brown

Thank you very much @johannahaffner. I will try out the suggestion.

@patrick-kidger
Owner

what do you think about adding my verbose version of BFGS?

I think that'd be good!

@raj-brown

Hi @patrick-kidger and @johannahaffner: I have another question. Could you please let me know how to print the number of iterations taken by BFGS? Thank you!

@johannahaffner
Contributor

Hi Raj,

this is indicated in the stats attribute of the solution object. You can access it with sol.stats["num_steps"].

Happy solving :)

@raj-brown

Thank you very much @johannahaffner!

@raj-brown

Hi @johannahaffner, I have a question. Is there an LBFGS API in Optimistix? Or should I use the BFGS class as a basis and define a new class for LBFGS? Thank you!

@patrick-kidger
Owner

We haven't implemented LBFGS at the moment. We should probably do this :)

@raj-brown

@patrick-kidger and @johannahaffner: Happy New Year, guys! I have a question. If I want to terminate the iteration based on the number of steps rather than rtol or atol, which routine should I change? Thank you very much!

@patrick-kidger
Owner

Probably simplest is just to use optx.{minimise, root_find, etc}(..., max_steps=..., throw=False) :)

But if you want to customize the solver directly then you could subclass it and adjust its terminate method.

@raj-brown

Thank you very much @patrick-kidger .

@raj-brown

Hi @patrick-kidger and @johannahaffner, I need to add the following curvature condition alongside the Armijo condition:

  1. Armijo condition (sufficient decrease): ensures that the step reduces the function value sufficiently.

$$ f(x+\alpha p) \leq f(x)+c_1 \alpha \nabla f(x)^T p $$

  2. Curvature condition: ensures that the slope along $p$ at the new point has flattened sufficiently, which rules out steps that are too short.

$$ \nabla f(x+\alpha p)^T p \geq c_2 \nabla f(x)^T p $$

My Python code looks like

def step(
        self,
        first_step: Bool[Array, ""],
        y: Y,
        y_eval: Y,
        f_info: _FnInfo,
        f_eval_info: _FnEvalInfo,
        state: _BacktrackingState,
    ) -> tuple[Scalar, Bool[Array, ""], RESULTS, _BacktrackingState]:
        if not isinstance(
            f_info,
            (
                FunctionInfo.EvalGrad,
                FunctionInfo.EvalGradHessian,
                FunctionInfo.EvalGradHessianInv,
                FunctionInfo.ResidualJac,
            ),
        ):
            raise ValueError(
                "Cannot use `BacktrackingArmijo` with this solver. This is because "
                "`BacktrackingArmijo` requires gradients of the target function, but "
                "this solver does not evaluate such gradients."
            )

        y_diff = (y_eval**ω - y**ω).ω
        predicted_reduction = f_info.compute_grad_dot(y_diff)


        # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
        # must do better than its linear approximation:
        # `fn(y_eval) < fn(y) + grad•y_diff`
        f_min = f_info.as_min()
        f_min_eval = f_eval_info.as_min()
        f_min_diff = f_min_eval - f_min  # This number is probably negative

        lhs_wolfe = self.slope*f_info.compute_grad_dot(y_eval)
        rhs_wolfe = self.slope*f_info.compute_grad_dot(y)
        satisfies_armijo = f_min_diff <= self.slope * predicted_reduction
        ## Added Wolfe
        satisfies_wolfe = lhs_wolfe >= rhs_wolfe
        has_reduction = predicted_reduction <= 0

        accept = first_step | (satisfies_armijo & has_reduction & satisfies_wolfe)
        step_size = jnp.where(
            accept, self.step_init, self.decrease_factor * state.step_size
        )
        step_size = cast(Scalar, step_size)
        return (
            step_size,
            accept,
            RESULTS.successful,
            _BacktrackingState(step_size=step_size),
        )

It does not converge. Could you please help me with this? Thank you!

@raj-brown

@patrick-kidger @johannahaffner Specifically, I don't know how to compute

$$
f(x+\alpha p) \leq f(x)+c_1 \alpha \nabla f(x)^T p
$$

@johannahaffner
Contributor

Hi back!

And nice to see you tinkering with optimistix. Just looking at this, I think I have a hunch where the lack of convergence might come from, and how you could fix it. I think that implementing a Wolfe line search is probably going to require expanding as well as shrinking the step size.

In BacktrackingArmijo, we're chopping the step size in half until we get something that satisfies the Armijo condition, starting from a full (Newton) step. That means that we constrain the step size ($\alpha$) to be in $(0, 1]$.
However, if no step size in that interval meets both conditions, you're going to be stuck. So just adding the Wolfe conditions to a backtracking line search is not going to get you where you want to be (having a Wolfe line search), since you're not handling the case where satisfying the curvature condition is going to require a step size that is greater than one.
I'd have to look up the best way to handle this.
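For illustration only, here is a standalone sketch of one common textbook approach: a bisection-style search for the weak Wolfe conditions that doubles the step while no upper bound is known. This is not optimistix code and is not jit-compatible (it uses a plain Python loop). The function name `weak_wolfe_search` and all defaults are assumptions, and `p` is assumed to be a descent direction.

```python
import jax
import jax.numpy as jnp


def weak_wolfe_search(f, x, p, c1=1e-4, c2=0.9, alpha=1.0, max_iter=50):
    """Search for a step size satisfying the Armijo and curvature conditions."""
    grad_f = jax.grad(f)
    f0 = f(x)
    slope0 = jnp.vdot(grad_f(x), p)  # directional derivative at alpha = 0
    lo, hi = 0.0, float("inf")       # current bracket for the step size
    for _ in range(max_iter):
        x_new = x + alpha * p
        armijo = bool(f(x_new) <= f0 + c1 * alpha * slope0)
        curvature = bool(jnp.vdot(grad_f(x_new), p) >= c2 * slope0)
        if not armijo:
            hi = alpha               # step too long: shrink from above
        elif not curvature:
            lo = alpha               # step too short: grow from below
        else:
            return alpha
        # Double while no upper bound is known, otherwise bisect the bracket.
        alpha = 2.0 * lo if hi == float("inf") else 0.5 * (lo + hi)
    return alpha  # best effort if max_iter is exhausted


def f(x):
    return 0.5 * jnp.sum(x**2)


x = jnp.array([1.0])
p = -0.01 * jax.grad(f)(x)  # a deliberately short descent direction
alpha_found = weak_wolfe_search(f, x, p)
print(alpha_found)
```

In this example the initial step of one satisfies the Armijo condition but not the curvature condition, so a pure backtracking search that only shrinks the step could never satisfy both.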

Now, to your other question, about computing $f(x+\alpha p) \leq f(x)+c_1 \alpha \nabla f(x)^T p$, this is evaluated here:

satisfies_armijo = f_min_diff <= self.slope * predicted_reduction

and the predicted reduction is computed here:
predicted_reduction = f_info.compute_grad_dot(y_diff)

Note that no equivalent of the step_size $\alpha$ shows up in these computations because this is implicitly part of y_diff: y_eval has been scaled by the step size after the previous (rejected) step, and y_diff has changed accordingly.

(y_eval is updated at every step, and y is only overwritten when the step is accepted at whatever the current y_eval is. This section of the documentation has more info about this.)
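The following standalone sketch, using plain JAX rather than optimistix internals, shows why the step size need not appear explicitly: with `y_diff = alpha * p`, the Armijo test written in terms of `y_diff` agrees with the textbook form. The same substitution works for the curvature condition, since multiplying both sides by a positive `alpha` preserves the inequality. The quadratic `f` and the constants `c1`, `c2`, `alpha` are arbitrary choices; `c1` plays the role of `self.slope`.

```python
import jax
import jax.numpy as jnp


def f(y):
    return jnp.sum(y**2)


y = jnp.array([1.0, -2.0])
grad = jax.grad(f)(y)
p = -grad                 # search direction (full step)
alpha = 0.25              # current trial step size
y_eval = y + alpha * p    # trial point
y_diff = y_eval - y       # equals alpha * p

c1, c2 = 0.1, 0.9

# Form used in the line search: alpha is hidden inside y_diff.
predicted_reduction = jnp.vdot(grad, y_diff)  # = alpha * grad^T p
armijo_via_diff = f(y_eval) - f(y) <= c1 * predicted_reduction

# Textbook form with alpha written out.
armijo_explicit = f(y + alpha * p) <= f(y) + c1 * alpha * jnp.vdot(grad, p)

# Curvature condition in both forms; this needs the gradient at y_eval.
grad_eval = jax.grad(f)(y_eval)
curvature_via_diff = jnp.vdot(grad_eval, y_diff) >= c2 * predicted_reduction
curvature_explicit = jnp.vdot(grad_eval, p) >= c2 * jnp.vdot(grad, p)

print(armijo_via_diff, armijo_explicit, curvature_via_diff, curvature_explicit)
```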

@raj-brown

Thanks @johannahaffner. I have another question: how do I compute

$$
\nabla f(x+\alpha p)^T p \geq c_2 \nabla f(x)^T p
$$

given that f_eval does not have any gradient function?

@johannahaffner
Contributor

Then how to do it as f_eval does not have any gradient function.

You move this line

(grad,) = lin_to_grad(lin_fn, state.y_eval)

out of the accepted branch, and call it earlier, below this one

f_eval, lin_fn, aux_eval = jax.linearize(

then you can pass FunctionInfo.EvalGrad(...) instead when calling the search step, here

FunctionInfo.Eval(f_eval),

I think this should "just work", if not then let me know (in case you get complaints about search state discrepancies and such).
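As a rough standalone stand-in for what those lines do (the actual optimistix internals may differ), the gradient at the trial point can be recovered from the output of `jax.linearize` by transposing the linear map, without a second pass through the function. The function `f` and the point `y_eval` below are made up for the example.

```python
import jax
import jax.numpy as jnp


def f(y):
    return jnp.sum(jnp.sin(y) ** 2)


y_eval = jnp.array([0.3, -1.2, 2.0])

# One pass gives the function value and the linearised map v -> J(y_eval) v.
f_eval, lin_fn = jax.linearize(f, y_eval)

# Transposing that linear map and applying it to 1 yields the gradient.
(grad,) = jax.linear_transpose(lin_fn, y_eval)(jnp.ones_like(f_eval))

print(f_eval, grad)
```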

By the way, here is an implementation of a Wolfe line search in Julia that might offer some inspiration: https://github.com/JuliaNLSolvers/LineSearches.jl/blob/master/src/strongwolfe.jl

@raj-brown

Thanks @johannahaffner . I will try out and get back to you. Thank you very much!

@raj-brown

Hi @johannahaffner, I was able to implement the Wolfe criterion as below, but I did not get any improvement. The result seems similar to what I get with the Armijo condition only. Do you have any suggestions? Thank you!

        lhs_wolfe = f_eval_info.compute_grad_dot(y_diff)

        # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
        # must do better than its linear approximation:
        # `fn(y_eval) < fn(y) + grad•y_diff`
        f_min = f_info.as_min()
        f_min_eval = f_eval_info.as_min()
        f_min_diff = f_min_eval - f_min  # This number is probably negative

        rhs_wolfe = predicted_reduction

        satisfies_armijo = f_min_diff <= self.slope * predicted_reduction

        ## Added Wolfe
        satisfies_wolfe = lhs_wolfe >= 0.9 * rhs_wolfe
        print(f"wolfe: {satisfies_wolfe}")

        has_reduction = predicted_reduction <= 0

        accept = first_step | (satisfies_armijo & has_reduction & satisfies_wolfe)

@johannahaffner
Contributor

I can't see any step length increases, is that correct?
