patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Non-finite values in the root function are not handled well #44

Closed FFroehlich closed 9 months ago

FFroehlich commented 9 months ago

Running into NaN values during root-finding in optimistix raises an error rather than decreasing the step size. This should probably also be fixed in lineax, which does not appear to check its inputs (the vector, and potentially the operator) for non-finiteness.

Can be reproduced using the following code:

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optimistix as optx
import equinox as eqx
import lineax as lx

class Model(eqx.Module):
    atol: float = eqx.static_field()
    rtol: float = eqx.static_field()
    maxsteps: int = eqx.static_field()

    def __init__(self):
        self.atol = 1e-6
        self.rtol = 1e-4
        self.maxsteps = int(1e5)

    def xdot(self, _, x, __, p, k):
        dxdt = jnp.array(
            [
                jnp.exp(p[2] - x[0] - x[1] + p[0]) - jnp.exp(p[0]),
                jnp.exp(k[0] - x[1] + p[1]) - jnp.exp(p[1]),
            ]
        )
        jax.debug.print(
            "t: {t}, x: {x}, dxdt: {dxdt}", t=_, x=x, dxdt=dxdt, ordered=True
        )
        return dxdt

    @jax.value_and_grad
    def loss(self, p):
        mapped_simulate = jax.vmap(
            self.simulate,
            in_axes=(None, 1, 1, 2),
        )
        n_repeat = 1

        k_range_ss = 5
        k_sss = (
            jrandom.uniform(jrandom.PRNGKey(1), shape=(1, n_repeat))
            * k_range_ss
            * 2
            - k_range_ss
        )
        y0s = jrandom.uniform(jrandom.PRNGKey(0), shape=(2, n_repeat))
        yms = jrandom.uniform(
            jrandom.PRNGKey(0),
            shape=(1, 2, n_repeat),
        )
        r = mapped_simulate(p, k_sss, y0s, yms)
        return jnp.sqrt(jnp.mean(jnp.square(r)))

    def simulate(
        self,
        p: jnp.ndarray,
        ts: jnp.ndarray,
        k_ss: jnp.ndarray,
        k_sim: jnp.ndarray,
        y0: jnp.ndarray,
        ym: jnp.ndarray,
    ):
        xdot_ss = eqx.Partial(self.xdot, 0.0, k=k_ss, p=p)

        solver_ss = optx.Newton(
            rtol=self.rtol,
            atol=self.atol,
            linear_solver=lx.AutoLinearSolver(well_posed=None),
        )

        sol_ss = optx.root_find(
            fn=xdot_ss,
            y0=y0,
            solver=solver_ss,
            max_steps=self.maxsteps,
        )

        return sol_ss.value - ym

if __name__ == "__main__":
    model = Model()
    p = jnp.array([-50, 50, 1.0])
    print(model.loss(p))
```

produces

```
Traceback (most recent call last):
  File ".../test.py", line 57, in loss
    r = mapped_simulate(p, k_sss, y0s, yms)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../test.py", line 75, in simulate
    sol_ss = optx.root_find(
             ^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.
If you are trying to solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.
If you *were* expecting this solver to work with this operator, then it may be because:
(a) the operator is singular, and your code has a bug; or
(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or
(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that it does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
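The likely mechanism: with `p = [-50, 50, 1]` the residual contains terms like `exp(k[0] - x[1] + 50)`, and float64 `exp` overflows to `inf` above roughly 709.78, after which differences of `inf` produce NaN. The iterate value below is hypothetical, purely to illustrate the arithmetic:

```python
import numpy as np

# suppose a large Newton step has pushed x[1] to around -1000 (hypothetical)
with np.errstate(over='ignore', invalid='ignore'):
    big = np.exp(50.0 - (-1000.0))   # exp(1050) overflows float64 -> inf
    print(big)                       # inf
    print(big - np.exp(1050.0))      # inf - inf -> nan
```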
patrick-kidger commented 9 months ago

I'm afraid this example doesn't run: TypeError: Model.simulate() missing 2 required positional arguments: 'y0' and 'ym'.

For something like this, would you be able to write a fairly short (10 lines?) MWE demonstrating the sort of behaviour you're after?

FFroehlich commented 9 months ago

The implementation of Newton's method doesn't actually implement a line search, and any implementation is bound to run into the same problems as https://github.com/patrick-kidger/diffrax/issues/368, so avoiding non-finite values in the first place is the way to go.
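The step-size-reduction behaviour the report asks for can be sketched as a damped Newton iteration that backtracks whenever a trial point evaluates to a non-finite value. This is a one-dimensional NumPy toy illustrating the idea, not optimistix's internals:

```python
import numpy as np

def f(x):
    return np.log(x) - 1.0   # root at x = e; NaN for x <= 0

def fprime(x):
    return 1.0 / x

def damped_newton(x, steps=50):
    with np.errstate(invalid='ignore'):
        for _ in range(steps):
            step = f(x) / fprime(x)
            trial = x - step
            # backtrack: halve the step while it lands on a non-finite value
            for _ in range(60):
                if np.isfinite(f(trial)):
                    break
                step *= 0.5
                trial = x - step
            x = trial
    return x

# from x0 = 10 the full Newton step overshoots to a negative trial point,
# where log is NaN; one halving recovers, and the iteration converges to e
print(damped_newton(10.0))  # ~2.718281828
```

An undamped iteration from the same starting point would evaluate `f` at a negative argument and propagate NaN through every subsequent step, which is the analogue of the error in the traceback above.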