patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Damped Newton solve(?) / Scipy and Optimistix #75

Open djbower opened 3 months ago

djbower commented 3 months ago

Similar to #74, I'm in the process of swapping in Optimistix for a previous solver. The system I'm solving is a relatively small chemical network with mass balance (think law of mass action plus conservation of moles/mass of elements). I was previously using SciPy's root finder, and later added JAX to supply the Jacobian to it as a callable. This approach generally works well and seems reasonably robust (I've had success with LM and Newton).
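The SciPy-plus-JAX pattern described above can be sketched on a hypothetical toy network (dimerisation 2 A <=> A2 with made-up constants K = 10 and total A = 1; this is not the reporter's system). It shows the same ingredients: a mass-action residual and a conservation residual, both in log10 variables, with `jax.jacobian` passed to `scipy.optimize.root` as `jac`. Outputs are converted to NumPy arrays explicitly as a precaution.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import root

jax.config.update("jax_enable_x64", True)

LOG10_K = 1.0      # hypothetical: K = [A2]/[A]^2 = 10
LOG10_TOTAL = 0.0  # hypothetical: [A] + 2[A2] = 1 (total A conserved)


def residual(y):
    """Toy dimerisation 2 A <=> A2, unknowns y = [log10 [A], log10 [A2]]."""
    log_a, log_a2 = y[0], y[1]
    # Law of mass action in log10 form: log10[A2] - 2 log10[A] = log10 K
    mass_action = log_a2 - 2.0 * log_a - LOG10_K
    # Conservation of A: [A] + 2[A2] = total
    balance = jnp.log10(10.0**log_a + 2.0 * 10.0**log_a2) - LOG10_TOTAL
    return jnp.array([mass_action, balance])


jac = jax.jacobian(residual)

# SciPy works on NumPy arrays, so convert the JAX outputs explicitly
sol = root(
    lambda y: np.asarray(residual(y)),
    x0=np.zeros(2),
    jac=lambda y: np.asarray(jac(y)),
)
print(sol.success, 10.0**sol.x)
```

For these constants the analytic solution is [A] = 0.2 and [A2] = 0.4 (from 20x² + x - 1 = 0), which gives a quick sanity check on the solver output.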

I'm now swapping in Optimistix and am struggling to solve the same system. In brief, the Chord solver works, but Newton (and even Dogleg or LM) blows up the solution guess after a handful of iterations: a summation overflows to infinity, and taking the logarithm further downstream then produces NaN. When I monitor the SciPy solver with the JAX-provided Jacobian, its progression to the solution is smooth and seemingly well behaved, so I assume Optimistix takes an erroneous step somewhere that it cannot recover from.

I noticed the arXiv paper compares performance against SciPy, so I'm wondering if you have any insights into which parameters/options I could tune to improve the behaviour of the Optimistix solver (or perhaps bring it in line with the SciPy solver for a side-by-side comparison). Unfortunately my code is wrapped up in a larger package, so it's difficult for me to provide an MWE at present. Nevertheless, any pointers are much appreciated (and thanks for developing so many excellent JAX packages and making them available).

patrick-kidger commented 3 months ago

I think I would need a MWE to see what's going on for you, I'm afraid! (If it means anything, I have used Optimistix as part of modelling chemical networks, successfully.)

Probably a good first step would be to place some jax.debug.{print,breakpoint}s inside the solver and check how the iterations are proceeding -- to figure out why they are blowing up.
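Along with printing the iterates, another quantity worth printing is the conditioning of the Jacobian: a Newton step solves J Δy = -r, so a nearly singular J can turn a modest residual into an enormous step. A minimal sketch, with a hypothetical toy residual (`toy_residual`, `jacobian_condition` and `report` are illustrative names, not Optimistix API). It uses `jax.debug.print`, which, unlike Python's `print`, shows runtime values inside jitted code. The toy illustrates a general fact: when one term dominates a log-sum-exp, that row's gradient is nearly one-hot, so two rows sharing the same dominant term become almost parallel.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)


def toy_residual(y):
    """Hypothetical 3x3 system: two log-sum-exp 'balance' rows that share y[0]."""
    balance_a = logsumexp(jnp.stack([y[0], y[1]]))
    balance_b = logsumexp(jnp.stack([y[0], y[2]]))
    return jnp.stack([balance_a, balance_b, y[1] + y[2]])


def jacobian_condition(fn, y):
    """Condition number of the Jacobian of fn at y (very large => nearly singular)."""
    return jnp.linalg.cond(jax.jacfwd(fn)(y))


def report(fn, y):
    """Print the iterate, residual norm and Jacobian conditioning; works under jit."""
    cond = jacobian_condition(fn, y)
    jax.debug.print("y = {y}, |r| = {r}, cond(J) = {c}", y=y, r=jnp.linalg.norm(fn(y)), c=cond)
    return cond


report_jit = jax.jit(report, static_argnums=0)

# Balanced terms: the Jacobian is well conditioned
report_jit(toy_residual, jnp.zeros(3))
# y[0] dominates both log-sum-exps, so their gradient rows are almost identical
report_jit(toy_residual, jnp.array([30.0, 0.0, 0.0]))
```

On a real problem one would call `report` with the actual residual and an iterate copied from a `solution in = ...` log line. That has not been done for the system in this thread.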

djbower commented 3 months ago

A minimum working example is below. You can comment/uncomment the relevant solvers in the main guard to compare SciPy and Optimistix. The system is formulated in terms of log10 (molecule) number densities, but because the elemental mass balance requires a summation, I'm using logsumexp to try to retain precision. Maybe you can identify a better way to scale the problem and close the gap between SciPy and Optimistix. The crux of the problem is that the Optimistix Newton solver blows up the solution, which causes a NaN:

Error below (reproduce by running the script):

Solving with Optimistix
solution in = [26. 26. 26. 26. 26. 26.]
residual out = [ 52.20882135 -11.52502446   2.60645764   0.10997     -1.14325911
  -0.49309319]
solution in = [26.71993707 43.1716189  39.70471743  6.69454232 19.76625234 20.64657795]
residual out = [3.55271368e-15 2.84217094e-14 3.55271368e-15 1.66161417e+01
 1.55387075e+01 1.28469107e+01]
solution in = [ 3.05567524e+16  2.55958745e+01  2.58966937e+01 -6.11135048e+16
  1.22227010e+17  3.05567524e+16]
residual out = [ 1.02088214e+01 -2.03332754e+01 -1.39354236e+00  3.05567524e+16
  1.22227010e+17  1.22227010e+17]
equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the ...
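For reference, the residual vector that `objective_function` in the MWE below evaluates can be written out as follows (reconstructed from the code). Here y_j = log10 n_j for species j in (H2, H2O, CO2, O2, CH4, CO); ν_ij are the entries of `coefficient_matrix`; a_ej is the hard-coded number of atoms of element e in species j; N_e is the total that `log10_*_constraint` gives in log10 form; m_j are the molar masses; and R, T, A, g are the gas constant, temperature, planet surface area and surface gravity. Reading V as surface area times the scale height RT/(μ̄g) is an interpretation of the code, not something stated in the thread.

```latex
\begin{aligned}
r_i &= \sum_j \nu_{ij}\, y_j - \log_{10} K_i, \qquad i = 0, 1, 2 \\
r_e &= \log_{10}\Big(\sum_j a_{ej}\, 10^{y_j}\Big) - \big(\log_{10} N_e - \log_{10} V\big), \qquad e \in \{\mathrm{O}, \mathrm{H}, \mathrm{C}\} \\
V &= \frac{R\, T\, A}{\bar{\mu}\, g}, \qquad \bar{\mu} = \frac{\sum_j m_j\, 10^{y_j}}{\sum_j 10^{y_j}} \\
\log_{10} \sum_j b_j\, 10^{y_j} &= M + \log_{10} \sum_j b_j\, 10^{y_j - M}, \qquad M = \max_j y_j
\end{aligned}
```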

Minimum working example:

#!/usr/bin/env python
"""Minimum working example (MWE) for comparing Scipy and Optimistix solvers
"""
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
import optimistix as optx
from jax import Array
from jax.typing import ArrayLike
from scipy.constants import gas_constant
from scipy.optimize import OptimizeResult, root

# Scipy also fails if this is commented out. Evidently double precision is required regardless.
jax.config.update("jax_enable_x64", True)

# This should be kept at this temperature since the equilibrium constants for the reactions below
# are hard-coded for this temperature
temperature: float = 450
planet_surface_area: float = 510064471909788.25  # SI units
planet_surface_gravity: float = 9.819973426224687  # SI units

# MWE for reaction network / mass balance
# Species order is: H2, H2O, CO2, O2, CH4, CO

# Species molar masses in kg/mol
molar_masses_dict: dict[str, float] = {
    "H2": 0.002015882,
    "H2O": 0.018015287,
    "CO2": 0.044009549999999995,
    "O2": 0.031998809999999996,
    "CH4": 0.016042504000000003,
    "CO": 0.028010145,
}

# Elemental mass-balance constraints: log10 of the total number of atoms of each element
log10_oxygen_constraint: float = 45.58848007858896
log10_hydrogen_constraint: float = 46.96664792007732
log10_carbon_constraint: float = 45.89051326565627

# Initial guess for the log10 number densities (molecules/m^3)
initial_solution: Array = jnp.array([26, 26, 26, 26, 26, 26], dtype=jnp.float64)

# Reaction set is linearly independent (determined by Gaussian elimination in a previous step)
# log10 equilibrium constants
# 2 H2O = 2 H2 + 1 O2
reaction0_log10Kc: float = -26.208821352166428
# 4 H2 + 1 CO2 = 2 H2O + 1 CH4
reaction1_log10Kc: float = -40.474975543925524
# 1 H2 + 1 CO2 = 1 H2O + 1 CO
reaction2_log10Kc: float = -2.6064576440642178

# Coefficient matrix (reaction stoichiometry)
# Columns correspond to species: H2, H2O, CO2, O2, CH4, CO
# Rows refer to reactions (three in total)
coefficient_matrix: Array = jnp.array(
    [
        [2.0, -2.0, 0.0, 1.0, 0.0, 0.0],
        [-4.0, 2.0, -1.0, 0.0, 1.0, 0.0],
        [-1.0, 1.0, -1.0, 0.0, 0.0, 1.0],
    ]
)

# RHS of the reaction residuals: the log10 equilibrium constants of the three reactions
rhs: Array = jnp.array([reaction0_log10Kc, reaction1_log10Kc, reaction2_log10Kc])

# For testing solvers, this is the known solution of the system
known_solution: dict[str, float] = {
    "H2": 26.950804260065272,
    "H2O": 26.109794057030303,
    "CO2": 11.303173861822636,
    "O2": -27.890841758236377,
    "CH4": 26.411827244097612,
    "CO": 9.537726420793389,
}

known_solution_array: npt.NDArray[np.float64] = np.array([val for val in known_solution.values()])

def solve_with_scipy(jacobian: bool = True) -> None:
    """Solve the system with Scipy"""

    if jacobian:
        jacobian_function: Callable | None = jax.jacobian(objective_function)
    else:
        jacobian_function = None

    print("Solving with SciPy")
    sol: OptimizeResult = root(objective_function, initial_solution, jac=jacobian_function)

    if sol.success and np.isclose(sol.x, known_solution_array).all():
        print("SciPy success and agrees with known solution. Function evaluations = %d" % sol["nfev"])

    print(sol)

def solve_with_optimistix(method: str = "Dogleg", tol: float = 1.0e-8) -> None:
    """Solve the system with Optimistix"""

    if method == "Dogleg":
        solver = optx.Dogleg(atol=tol, rtol=tol)
    elif method == "Newton":
        solver = optx.Newton(atol=tol, rtol=tol)
    else:
        raise ValueError(f"Unknown method: {method!r}")

    print("Solving with Optimistix")
    sol = optx.root_find(
        objective_function,
        solver,
        initial_solution,
        throw=True,
    )

    if optx.RESULTS[sol.result] == "" and np.isclose(sol.value, known_solution_array).all():
        print(
            "Optimistix success and agrees with known solution. Steps = %d"
            % sol.stats["num_steps"]
        )

def atmosphere_log10_molar_mass(solution: Array) -> Array:
    """Log10 of the mean molar mass of the atmosphere (number-density weighted, kg/mol)"""
    molar_masses: Array = jnp.array([value for value in molar_masses_dict.values()])
    molar_mass: Array = logsumexp_base10(solution, molar_masses) - logsumexp_base10(solution)

    return molar_mass

def atmosphere_log10_volume(solution: Array) -> Array:
    """Log10 of the atmosphere volume, R T A / (mu g): surface area times scale height"""
    return (
        jnp.log10(gas_constant)
        + jnp.log10(temperature)
        - atmosphere_log10_molar_mass(solution)
        + jnp.log10(planet_surface_area)
        - jnp.log10(planet_surface_gravity)
    )

def objective_function(solution: Array, *args) -> Array:
    """Residual of the reaction network and mass balance"""
    jax.debug.print("solution in = {solution}", solution=solution)
    # Reaction network
    reaction_residual: Array = coefficient_matrix.dot(solution) - rhs

    log10_volume: Array = atmosphere_log10_volume(solution)

    # Mass balance residuals (stoichiometry coefficients are hard-coded for this MWE)
    oxygen_residual: Array = jnp.array(
        [
            solution[1],
            jnp.log10(2) + solution[2],
            jnp.log10(2) + solution[3],
            solution[5],
        ]
    )
    oxygen_residual = logsumexp_base10(oxygen_residual) - (log10_oxygen_constraint - log10_volume)

    hydrogen_residual: Array = jnp.array(
        [jnp.log10(2) + solution[0], jnp.log10(2) + solution[1], jnp.log10(4) + solution[4]]
    )
    hydrogen_residual = logsumexp_base10(hydrogen_residual) - (
        log10_hydrogen_constraint - log10_volume
    )

    carbon_residual: Array = jnp.array([solution[2], solution[4], solution[5]])
    carbon_residual = logsumexp_base10(carbon_residual) - (log10_carbon_constraint - log10_volume)

    residual: Array = jnp.concatenate(
        (
            reaction_residual,
            jnp.array([oxygen_residual]),
            jnp.array([hydrogen_residual]),
            jnp.array([carbon_residual]),
        )
    )

    jax.debug.print("residual out = {residual}", residual=residual)

    return residual

def logsumexp_base10(log_values: Array, prefactors: ArrayLike = 1) -> Array:
    """Return log10(sum(prefactors * 10**log_values)), shifting by the max for stability"""
    max_log: Array = jnp.max(log_values)
    prefactors_: Array = jnp.asarray(prefactors)

    return max_log + jnp.log10(jnp.sum(prefactors_ * jnp.power(10, log_values - max_log)))

if __name__ == "__main__":

    # Solving with SciPy and a numerical Jacobian (54 function evaluations)
    # solve_with_scipy(jacobian=False)

    # Solving with SciPy and a JAX-provided Jacobian (30 function evaluations)
    # solve_with_scipy(jacobian=True)

    # Solving with Optimistix Dogleg in 157 steps
    # solve_with_optimistix(method="Dogleg")

    # Solving with Optimistix Newton fails
    solve_with_optimistix(method="Newton")
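The MWE hand-rolls a base-10 log-sum-exp. One possible alternative is to reuse `jax.scipy.special.logsumexp`, whose `b` argument plays the role of `prefactors`, via the identity log10(Σ b·10^x) = logsumexp(x·ln 10, b)/ln 10. This is only a sketch: it has not been tested against the full system in this thread, and it is not claimed to fix the divergence.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)

LN10 = jnp.log(10.0)


def logsumexp_base10(log_values, prefactors=1.0):
    """log10(sum(prefactors * 10**log_values)) via JAX's natural-log logsumexp.

    Uses the identity log10(sum b * 10**x) = logsumexp(x * ln 10, b=b) / ln 10.
    """
    log_values = jnp.asarray(log_values)
    b = jnp.broadcast_to(jnp.asarray(prefactors, dtype=log_values.dtype), log_values.shape)
    return logsumexp(log_values * LN10, b=b) / LN10


# 2*10 + 1*100 + 0.5*1000 = 620
print(logsumexp_base10(jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 1.0, 0.5])))
# 10**400 overflows float64, but the max-shifted computation does not
print(logsumexp_base10(jnp.array([400.0, 399.0])))
```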
djbower commented 3 months ago

I should add that although Dogleg converges, the solution again blows up at the start, as with Newton (but then recovers):

(.venv) (base) dan@Dans-MBP tests % ./simple_CHO_low_temperature.py
Solving with Optimistix
solution in = [26. 26. 26. 26. 26. 26.]
residual out = [ 52.20882135 -11.52502446   2.60645764   0.10997     -1.14325911
  -0.49309319]
solution in = [26.71993707 43.1716189  39.70471743  6.69454232 19.76625234 20.64657795]
residual out = [3.55271368e-15 2.84217094e-14 3.55271368e-15 1.66161417e+01
 1.55387075e+01 1.28469107e+01]
solution in = [ 3.05567524e+16  2.55958745e+01  2.58966937e+01 -6.11135048e+16
  1.22227010e+17  3.05567524e+16]
residual out = [ 1.02088214e+01 -2.03332754e+01 -1.39354236e+00  3.05567524e+16
  1.22227010e+17  1.22227010e+17]
solution in = [ 7.63918810e+15  3.56990697e+01  3.50241519e+01 -1.52783762e+16
  3.05567524e+16  7.63918810e+15]
residual out = [ 6.20882135e+00 -8.12688515e+00 -1.39354236e+00  7.63918810e+15
  3.05567524e+16  3.05567524e+16]
solution in = [ 1.90979703e+15  3.82248684e+01  3.73060164e+01 -3.81959405e+15
  7.63918810e+15  1.90979703e+15]
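The thread title asks about a damped Newton solve. Dogleg is a trust-region method, so it controls step length, which is consistent with it recovering after the initial jump. A plain Newton iteration, by contrast, accepts whatever step the linear solve returns. As an illustration of the damping idea only, here is a plain-NumPy sketch (not Optimistix's implementation; `damped_newton` is a hypothetical name) that backtracks along the Newton direction until an Armijo decrease condition on ½‖f‖² holds. It is tested on the classic arctan example, where undamped Newton diverges from |x0| > ~1.39.

```python
import numpy as np


def damped_newton(f, jac, x0, tol=1e-10, max_steps=100, c=1e-4, shrink=0.5, min_alpha=1e-10):
    """Newton's method with Armijo backtracking on phi(x) = 0.5 * ||f(x)||^2 (a sketch)."""
    x = np.asarray(x0, dtype=float)
    for step in range(max_steps):
        r = f(x)
        phi = 0.5 * r @ r
        if np.sqrt(2.0 * phi) < tol:
            return x, step
        dx = np.linalg.solve(jac(x), -r)  # full Newton direction
        alpha = 1.0
        # Along the Newton direction the slope of phi is -2*phi, so Armijo reads
        # phi(x + alpha*dx) <= (1 - 2*c*alpha) * phi(x); otherwise shrink the step.
        while True:
            r_new = f(x + alpha * dx)
            if 0.5 * r_new @ r_new <= (1.0 - 2.0 * c * alpha) * phi or alpha < min_alpha:
                break
            alpha *= shrink
        x = x + alpha * dx
    raise RuntimeError("damped Newton did not converge")


def f(x):
    return np.arctan(x)


def jac(x):
    return np.diag(1.0 / (1.0 + x**2))


x0 = np.array([3.0, -2.0])

# Undamped Newton overshoots: |x| grows with every step
x_undamped = x0.copy()
for _ in range(3):
    x_undamped = x_undamped - np.linalg.solve(jac(x_undamped), f(x_undamped))
print("undamped after 3 steps:", x_undamped)

root_estimate, n_steps = damped_newton(f, jac, x0)
print("damped:", root_estimate, "steps:", n_steps)
```

Backtracking on the residual norm is one common globalisation strategy. Trust regions, as in Dogleg, are another.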
patrick-kidger commented 2 months ago

Right, sorry for the delay, I'm just getting back around to this now. Can you help by removing the extraneous pieces in this MWE? Right now you still have a lot of problem-specific stuff in this example -- temperature, planet_surface_gravity, log10_oxygen_constraint etc.

I'd like to help debug this for you but it's much harder to do so when the example is this large -- when it contains so many moving pieces that won't be required to reproduce the issue.

djbower commented 2 months ago

I've been working on improving my main code so I might be able to provide a better example, or even a suite of examples, in the near future. I'll get back to you, and I appreciate the response.