Damped Newton solve(?) / Scipy and Optimistix #75

Closed
djbower opened this issue Aug 16, 2024 · 6 comments

@djbower

djbower commented Aug 16, 2024

Similar to #74, I'm in the process of swapping in Optimistix for a previous solver. The system I'm solving is a relatively small chemical network with mass balance (think law of mass action plus conservation of moles/mass of elements). I was previously using the SciPy root finder, and then added JAX to provide the Jacobian as a callable. This approach generally works well and seems reasonably robust (I've had success with LM and Newton).

I'm now swapping in Optimistix and finding it challenging to solve the same system. In brief, the Chord solver works, but Newton (or even Dogleg or LM) blows up the solution guess after a handful of iterations, producing an infinity in a summation and then a NaN when a logarithm is taken further downstream. When I monitor the SciPy solver with the JAX-provided Jacobian, the progression to the solution is smooth and seemingly well-behaved. So I assume that somewhere along the line Optimistix takes an erroneous step from which it cannot recover. I noticed the arXiv paper compares performance against SciPy, so I'm wondering if you have any insights into which parameters/options I could tune to improve the behaviour of the Optimistix solver (or to bring it in line with the SciPy solver for a side-by-side comparison). Unfortunately my code is wrapped up in a larger package, so it's difficult for me to provide a MWE at present. Nevertheless, any pointers are much appreciated (and thanks for developing so many excellent JAX packages and making them available).

@patrick-kidger
Owner

I think I would need a MWE to see what's going on for you, I'm afraid!
(If it means anything, I have used Optimistix as part of modelling chemical networks, successfully.)

Probably a good first step would be to place some jax.debug.{print,breakpoint}s inside the solver and check how the iterations are proceeding -- to figure out why they are blowing up.
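As a rough illustration of that suggestion, the sketch below places a `jax.debug.print` inside a hand-written, undamped Newton step and applies it to a toy residual (not the system from this issue). The names `residual`, `newton_step`, and `iterates` are hypothetical; the only library calls assumed are `jax.debug.print`, `jax.jacfwd`, and `jnp.linalg.solve`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def residual(x):
    """Toy residual: arctan(x) has a root at 0 but flattens out far from it."""
    return jnp.arctan(x)


@jax.jit
def newton_step(x):
    """One undamped Newton step, logging the iterate and its residual."""
    r = residual(x)
    jac = jax.jacfwd(residual)(x)
    x_new = x - jnp.linalg.solve(jac, r)
    # Unlike print(), jax.debug.print reports runtime values inside jit
    jax.debug.print("x = {x}, residual = {r}, next x = {x_new}", x=x, r=r, x_new=x_new)
    return x_new


x = jnp.array([2.0])
iterates = [x]
for _ in range(4):
    x = newton_step(x)
    iterates.append(x)
iterates = jnp.stack(iterates)
```

Starting from 2.0, each printed iterate is larger in magnitude than the last, which is the kind of pattern to look for when tracing where a solve starts to diverge.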

@djbower
Author

djbower commented Aug 17, 2024

A minimum working example is below. You can comment/uncomment the relevant solvers in the main guard to compare SciPy and Optimistix. The system is formulated in terms of log10 (molecule) number densities, but since the elemental mass balance requires a summation, I'm using logsumexp to try to retain precision. Maybe you can identify a better way to scale the problem to close the gap between SciPy and Optimistix. The crux of the problem is that the Optimistix Newton solver blows up the solution, which causes a NaN.

Error below (reproduce by running the script):

Solving with Optimistix
solution in = [26. 26. 26. 26. 26. 26.]
residual out = [ 52.20882135 -11.52502446   2.60645764   0.10997     -1.14325911
  -0.49309319]
solution in = [26.71993707 43.1716189  39.70471743  6.69454232 19.76625234 20.64657795]
residual out = [3.55271368e-15 2.84217094e-14 3.55271368e-15 1.66161417e+01
 1.55387075e+01 1.28469107e+01]
solution in = [ 3.05567524e+16  2.55958745e+01  2.58966937e+01 -6.11135048e+16
  1.22227010e+17  3.05567524e+16]
residual out = [ 1.02088214e+01 -2.03332754e+01 -1.39354236e+00  3.05567524e+16
  1.22227010e+17  1.22227010e+17]
equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the ...

Minimum working example:

#!/usr/bin/env python
"""Minimum working example (MWE) for comparing Scipy and Optimistix solvers
"""
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
import optimistix as optx
from jax import Array
from jax.typing import ArrayLike
from scipy.constants import gas_constant
from scipy.optimize import OptimizeResult, root

# Scipy also fails if this is commented out. Evidently double precision is required regardless.
jax.config.update("jax_enable_x64", True)

# This should be kept at this temperature since the equilibrium constants for the reactions below
# are hard-coded for this temperature
temperature: float = 450
planet_surface_area: float = 510064471909788.25  # SI units
planet_surface_gravity: float = 9.819973426224687  # SI units

# MWE for reaction network / mass balance
# Species order is: H2, H2O, CO2, O2, CH4, CO

# Species molar masses in kg/mol
molar_masses_dict: dict[str, float] = {
    "H2": 0.002015882,
    "H2O": 0.018015287,
    "CO2": 0.044009549999999995,
    "O2": 0.031998809999999996,
    "CH4": 0.016042504000000003,
    "CO": 0.028010145,
}

# Element log10 number of total molecules constraints:
log10_oxygen_constraint: float = 45.58848007858896
log10_hydrogen_constraint: float = 46.96664792007732
log10_carbon_constraint: float = 45.89051326565627

# Initial solution guess: log10 number density (molecules/m^3)
initial_solution: Array = jnp.array([26, 26, 26, 26, 26, 26], dtype=jnp.float_)

# Reaction set is linearly independent (determined by Gaussian elimination in a previous step)
# log10 equilibrium constants
# 2 H2O = 2 H2 + 1 O2
reaction0_log10Kc: float = -26.208821352166428
# 4 H2 + 1 CO2 = 2 H2O + 1 CH4
reaction1_log10Kc: float = -40.474975543925524
# 1 H2 + 1 CO2 = 1 H2O + 1 CO
reaction2_log10Kc: float = -2.6064576440642178

# Coefficient matrix (reaction stoichiometry)
# Columns correspond to species: H2, H2O, CO2, O2, CH4, CO
# Rows refer to reactions (three in total)
coefficient_matrix: Array = jnp.array(
    [
        [2.0, -2.0, 0.0, 1.0, 0.0, 0.0],
        [-4.0, 2.0, -1.0, 0.0, 1.0, 0.0],
        [-1.0, 1.0, -1.0, 0.0, 0.0, 1.0],
    ]
)

# rhs constraints are the equilibrium constants of the reaction
rhs: Array = jnp.array([reaction0_log10Kc, reaction1_log10Kc, reaction2_log10Kc])

# For testing solvers, this is the known solution of the system
known_solution: dict[str, float] = {
    "H2": 26.950804260065272,
    "H2O": 26.109794057030303,
    "CO2": 11.303173861822636,
    "O2": -27.890841758236377,
    "CH4": 26.411827244097612,
    "CO": 9.537726420793389,
}

# np.float64 is used because the np.float_ alias was removed in NumPy 2.0
known_solution_array: npt.NDArray[np.float64] = np.array(list(known_solution.values()))


def solve_with_scipy(jacobian: bool = True) -> None:
    """Solve the system with Scipy"""

    if jacobian:
        jacobian_function: Callable | None = jax.jacobian(objective_function)
    else:
        jacobian_function = None

    print("Solving with SciPy")
    sol: OptimizeResult = root(objective_function, initial_solution, jac=jacobian_function)

    if sol.success and np.isclose(sol.x, known_solution_array).all():
        print("SciPy success and agrees with known solution. Steps = %d" % sol["nfev"])

    print(sol)


def solve_with_optimistix(method="Dogleg", tol: float = 1.0e-8) -> None:
    """Solve the system with Optimistix"""

    if method == "Dogleg":
        solver = optx.Dogleg(atol=tol, rtol=tol)
    elif method == "Newton":
        solver = optx.Newton(atol=tol, rtol=tol)
    else:
        raise ValueError(f"Unknown method: {method}")

    print("Solving with Optimistix")
    sol = optx.root_find(
        objective_function,
        solver,
        initial_solution,
        throw=True,
    )

    if optx.RESULTS[sol.result] == "" and np.isclose(sol.value, known_solution_array).all():
        print(
            "Optimistix success and agrees with known solution. Steps = %d"
            % sol.stats["num_steps"]
        )


def atmosphere_log10_molar_mass(solution: Array) -> Array:
    """Log10 of the molar mass of the atmosphere"""
    molar_masses: Array = jnp.array([value for value in molar_masses_dict.values()])
    molar_mass: Array = logsumexp_base10(solution, molar_masses) - logsumexp_base10(solution)

    return molar_mass


def atmosphere_log10_volume(solution: Array) -> Array:
    """Log10 of the volume of the atmosphere"""
    return (
        jnp.log10(gas_constant)
        + jnp.log10(temperature)
        - atmosphere_log10_molar_mass(solution)
        + jnp.log10(planet_surface_area)
        - jnp.log10(planet_surface_gravity)
    )


def objective_function(solution: Array, *args) -> Array:
    """Residual of the reaction network and mass balance"""
    jax.debug.print("solution in = {solution}", solution=solution)
    # Reaction network
    reaction_residual: Array = coefficient_matrix.dot(solution) - rhs

    log10_volume: Array = atmosphere_log10_volume(solution)

    # Mass balance residuals (stoichiometry coefficients are hard-coded for this MWE)
    oxygen_residual: Array = jnp.array(
        [
            solution[1],
            jnp.log10(2) + solution[2],
            jnp.log10(2) + solution[3],
            solution[5],
        ]
    )
    oxygen_residual = logsumexp_base10(oxygen_residual) - (log10_oxygen_constraint - log10_volume)

    hydrogen_residual: Array = jnp.array(
        [jnp.log10(2) + solution[0], jnp.log10(2) + solution[1], jnp.log10(4) + solution[4]]
    )
    hydrogen_residual = logsumexp_base10(hydrogen_residual) - (
        log10_hydrogen_constraint - log10_volume
    )

    carbon_residual: Array = jnp.array([solution[2], solution[4], solution[5]])
    carbon_residual = logsumexp_base10(carbon_residual) - (log10_carbon_constraint - log10_volume)

    residual: Array = jnp.concatenate(
        (
            reaction_residual,
            jnp.array([oxygen_residual]),
            jnp.array([hydrogen_residual]),
            jnp.array([carbon_residual]),
        )
    )

    jax.debug.print("residual out = {residual}", residual=residual)

    return residual


def logsumexp_base10(log_values: Array, prefactors: ArrayLike = 1) -> Array:
    """Log10 of sum(prefactors * 10**log_values), shifted by the maximum for stability"""
    max_log: Array = jnp.max(log_values)
    prefactors_: Array = jnp.asarray(prefactors)

    return max_log + jnp.log10(jnp.sum(prefactors_ * jnp.power(10, log_values - max_log)))


if __name__ == "__main__":

    # Solving with scipy and a numerical Jacobian in 54 steps
    # solve_with_scipy(jacobian=False)

    # Solving with scipy and a JAX provided Jacobian in 30 steps
    # solve_with_scipy(jacobian=True)

    # Solving with Optimistix Dogleg in 157 steps
    # solve_with_optimistix(method="Dogleg")

    # Solving with Optimistix Newton fails
    solve_with_optimistix(method="Newton")

@djbower
Author

djbower commented Aug 17, 2024

I should add that although Dogleg converges, the solution again blows up at the beginning, similar to Newton (but then recovers):

(.venv) (base) dan@Dans-MBP tests % ./simple_CHO_low_temperature.py
Solving with Optimistix
solution in = [26. 26. 26. 26. 26. 26.]
residual out = [ 52.20882135 -11.52502446   2.60645764   0.10997     -1.14325911
  -0.49309319]
solution in = [26.71993707 43.1716189  39.70471743  6.69454232 19.76625234 20.64657795]
residual out = [3.55271368e-15 2.84217094e-14 3.55271368e-15 1.66161417e+01
 1.55387075e+01 1.28469107e+01]
solution in = [ 3.05567524e+16  2.55958745e+01  2.58966937e+01 -6.11135048e+16
  1.22227010e+17  3.05567524e+16]
residual out = [ 1.02088214e+01 -2.03332754e+01 -1.39354236e+00  3.05567524e+16
  1.22227010e+17  1.22227010e+17]
solution in = [ 7.63918810e+15  3.56990697e+01  3.50241519e+01 -1.52783762e+16
  3.05567524e+16  7.63918810e+15]
residual out = [ 6.20882135e+00 -8.12688515e+00 -1.39354236e+00  7.63918810e+15
  3.05567524e+16  3.05567524e+16]
solution in = [ 1.90979703e+15  3.82248684e+01  3.73060164e+01 -3.81959405e+15
  7.63918810e+15  1.90979703e+15]
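The logs above show the third iterate jumping by many orders of magnitude. One common remedy, in the spirit of the "damped Newton" in the issue title, is to shorten the Newton step until the residual norm decreases. The sketch below is a generic illustration on a toy residual; it is not how SciPy or Optimistix implement their solvers, and `damped_newton` with its parameters is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def residual(x):
    """Toy residual on which the full Newton step overshoots from x0 = 2."""
    return jnp.arctan(x)


def damped_newton(fn, x0, tol=1e-10, max_steps=50, max_halvings=30):
    """Newton iteration that halves the step until the residual norm decreases.

    Uses Python control flow on concrete values, so it is not jit-compatible as written.
    """
    x = x0
    for _ in range(max_steps):
        r = fn(x)
        norm = jnp.linalg.norm(r)
        if norm < tol:
            break
        # Full Newton direction from the linearised system J dx = -r
        direction = jnp.linalg.solve(jax.jacfwd(fn)(x), -r)
        step = 1.0
        for _ in range(max_halvings):
            if jnp.linalg.norm(fn(x + step * direction)) < norm:
                break
            step *= 0.5
        x = x + step * direction
    return x


root = damped_newton(residual, jnp.array([2.0]))
```

The step halving only guards against overshoot along the Newton direction; it does not help when the Jacobian itself is singular or badly scaled.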

@patrick-kidger
Owner

Right, sorry for the delay, I'm just getting back around to this now.
Can you help by removing the extraneous pieces in this MWE? Right now you still have a lot of problem-specific stuff in this example -- temperature, planet_surface_gravity, log10_oxygen_constraint etc.

I'd like to help debug this for you but it's much harder to do so when the example is this large -- when it contains so many moving pieces that won't be required to reproduce the issue.

@djbower
Author

djbower commented Sep 2, 2024

I've been working on improving my main code so I might be able to provide a better example, or even a suite of examples, in the near future. I'll get back to you, and I appreciate the response.

@djbower
Author

djbower commented Dec 19, 2024

I've improved the code in several respects since this post, notably scaling, boundedness, and initial guesses. This issue is now superseded by a more targeted issue on implementing bounded solvers.

@djbower djbower closed this as completed Dec 19, 2024