patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Incompatibility of least_squares and custom_vjp #50

Open ahwillia opened 8 months ago

ahwillia commented 8 months ago

I'm running into some trouble applying optimistix.least_squares(fn, LevenbergMarquardt(...), x0) to certain problems. From the error message below, my understanding is that the root cause is that forward-mode autodiff cannot be applied to a function defined with jax.custom_vjp. In my case I am using diffrax to solve an ODE within fn(...), which I think might be causing the problem.

Is my basic understanding correct? Are there specific constraints / assumptions that fn(...) must follow for optimistix.least_squares to work (e.g. cannot use jax.custom_vjp)? Is there any way around this?

The error I get is:

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
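The same error can be reproduced without Diffrax or Optimistix. The sketch below uses a toy `square` function with a hand-written `jax.custom_vjp` rule as a hypothetical stand-in for `diffrax.diffeqsolve`: reverse mode (`jax.jacrev`) goes through the custom rule, while forward mode (`jax.jacfwd`, built on `jax.jvp`) raises the TypeError above.

```python
import jax
import jax.numpy as jnp

# Toy function with a custom reverse-mode rule (hypothetical stand-in for a
# library call, such as an ODE solve, that defines its own custom_vjp).
@jax.custom_vjp
def square(x):
    return x ** 2

def square_fwd(x):
    return square(x), x  # primal output, plus x saved as the residual

def square_bwd(x, g):
    return (2.0 * x * g,)  # cotangent with respect to x

square.defvjp(square_fwd, square_bwd)

x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode uses square_bwd and works.
jac_rev = jax.jacrev(square)(x)

# Forward mode has no rule for a custom_vjp function, so JAX raises TypeError.
try:
    jax.jacfwd(square)(x)
    forward_mode_ok = True
except TypeError as err:
    forward_mode_ok = False
    print(err)
```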

The full code to reproduce the error is below. By the way, I get the same problem when trying to use jaxopt.LevenbergMarquardt on this problem.

# === imports === #
import jax
jax.config.update("jax_enable_x64", True)  # use float64 throughout
import jax.numpy as jnp
import diffrax
from diffrax import ODETerm, Dopri5, SaveAt
import optimistix
from optimistix import LevenbergMarquardt

# === functions defining flow field and residuals === #
def geodesic_vector_field(P):
    jacP = jax.jacobian(P)
    def vector_field(t, state, args):
        x, v = state
        Pdx = jacP(x)  # Pdx[a, b, c] = dP[a, b] / dx[c]
        # Geodesic equation in index form: P(x) @ dvdt = q1 - q2
        q1 = 0.5 * jnp.einsum("jki,j,k->i", Pdx, v, v)
        q2 = jnp.einsum("ilp,l,p->i", Pdx, v, v)
        dxdt = v
        dvdt = jnp.linalg.solve(P(x), q1 - q2)
        return (dxdt, dvdt)
    return vector_field

def exponential_map(x0, v0, term, solver):
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=(x0, v0),
        saveat=SaveAt(t0=False, t1=True)
    ).ys[0].ravel()

def shooting_method_resids(x0, x1, term, solver):
    return jax.jit(
        lambda v0, args: (x1 - exponential_map(x0, v0, term, solver)).ravel()
    )

# === try solving the boundary value problem === #
term = ODETerm(geodesic_vector_field(lambda x: jnp.eye(2)))
solver = Dopri5()

# At the time of this issue, this call raised the TypeError shown above.
optimistix.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    LevenbergMarquardt(1e-3, 1e-3),  # rtol, atol
    -1 * jnp.ones(2)
)

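For reference, the geodesic equation that `vector_field` is meant to integrate for a metric P(x) is written below in index form (this is the standard Christoffel-symbol form, stated as background rather than taken from the thread). The shooting residual r(v_0) measures how far the endpoint of `exponential_map` misses x_1. For the identity metric used in the reproduction, every derivative of P vanishes, geodesics are straight lines, and the exact root is v_0 = x_1 - x_0 = (1, 1), which is a handy sanity check once the solver runs.

```latex
\sum_{l} P_{il}(x)\,\ddot{x}^{l}
  = \frac{1}{2}\sum_{j,k} \partial_i P_{jk}(x)\, v^{j} v^{k}
  - \sum_{l,p} \partial_p P_{il}(x)\, v^{l} v^{p},
  \qquad v = \dot{x}

r(v_0) = x_1 - x(1;\, x_0, v_0),
  \qquad P \equiv I \;\Rightarrow\; x(1;\, x_0, v_0) = x_0 + v_0
  \;\Rightarrow\; v_0^{\star} = x_1 - x_0
```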
patrick-kidger commented 8 months ago

Yup, you're completely correct in your diagnosis: Diffrax uses a jax.custom_vjp to autodifferentiate through diffeqsolve. JAX doesn't support forward-mode autodiff through a custom_vjp, and forward mode is what optx.LevenbergMarquardt uses to compute its Jacobians.
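To see where the Jacobian enters, here is a bare-bones Levenberg-Marquardt loop in plain JAX. It is a hypothetical sketch, not Optimistix's implementation: it uses a fixed damping factor (practical solvers adapt it between steps) and a `jac` switch that picks `jax.jacfwd` or `jax.jacrev`. The residual wraps a toy `custom_vjp` function as a stand-in for a Diffrax solve, so only the reverse-mode path works.

```python
import jax
import jax.numpy as jnp

# Toy custom_vjp function standing in for diffeqsolve (hypothetical).
@jax.custom_vjp
def square(y):
    return y ** 2

def square_fwd(y):
    return square(y), y

def square_bwd(y, g):
    return (2.0 * y * g,)

square.defvjp(square_fwd, square_bwd)

def residual(y, target):
    return square(y) - target

def levenberg_marquardt(fn, y0, args, jac="fwd", damping=1e-3, steps=50):
    """Fixed-damping LM loop; `jac` selects forward- or reverse-mode Jacobians."""
    jac_fn = jax.jacfwd(fn) if jac == "fwd" else jax.jacrev(fn)
    y = y0
    for _ in range(steps):
        r = fn(y, args)
        J = jac_fn(y, args)  # shape (m, n): d residual / d y
        # Damped normal equations: (J^T J + damping * I) step = -J^T r
        A = J.T @ J + damping * jnp.eye(y.shape[0])
        step = jnp.linalg.solve(A, -J.T @ r)
        y = y + step
    return y

target = jnp.array([4.0, 9.0])
y_bwd = levenberg_marquardt(residual, jnp.ones(2), target, jac="bwd")  # approx [2., 3.]

# The forward-mode path hits the same TypeError as in the issue.
try:
    levenberg_marquardt(residual, jnp.ones(2), target, jac="fwd")
    fwd_failed = False
except TypeError:
    fwd_failed = True
```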

We have essentially two possible fixes: offer a way for Diffrax to use forward-mode autodifferentiation, or offer a way for Optimistix to use reverse-mode.
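The two fixes differ in how the Jacobian is assembled. For a residual r: R^n to R^m, forward mode builds J one column at a time from n Jacobian-vector products, while reverse mode builds it one row at a time from m vector-Jacobian products. The sketch below constructs both by hand with `jax.jvp` and `jax.vjp` (`jacobian_fwd` and `jacobian_bwd` are hypothetical helpers for illustration) and checks that they agree for an ordinary function with no custom rules.

```python
import jax
import jax.numpy as jnp

def jacobian_fwd(f, x):
    """Build J column by column: one JVP per input dimension."""
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    push = lambda t: jax.jvp(f, (x,), (t,))[1]   # J @ t
    return jax.vmap(push, out_axes=1)(basis)    # stack columns along axis 1

def jacobian_bwd(f, x):
    """Build J row by row: one VJP per output dimension."""
    y, pull = jax.vjp(f, x)
    basis = jnp.eye(y.shape[0], dtype=y.dtype)
    return jax.vmap(lambda u: pull(u)[0])(basis)  # J^T @ u, stacked as rows

def f(x):
    # R^3 -> R^2 with no custom derivative rules, so both modes apply.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2]) + x[0] ** 2])

x = jnp.array([1.0, 2.0, 0.5])
J_fwd = jacobian_fwd(f, x)
J_bwd = jacobian_bwd(f, x)
```

In the shooting problem from this thread n = m = 2, so either mode needs two passes; the practical difference is that only the reverse-mode route goes through Diffrax's custom_vjp rule.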

For now I've added the latter in #51. Try using Optimistix from that branch and see if it solves your problem! You'll need to pass optx.least_squares(..., options=dict(jac="bwd")).

(I'd like to add better forward-mode support for Diffrax, but the best way of doing this really depends on JAX itself adding direct support for jvp-of-custom_vjp. I have a draft of that here, but it still seems to be buggy, so I haven't gotten around to finishing it.)
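Until jvp-of-custom_vjp is supported directly, one known workaround (a different technique from the draft mentioned above, which isn't shown here) is the "double-VJP" trick: the pullback `u -> J^T u` is linear in `u`, so differentiating it in reverse mode with cotangent `t` returns `J t`, a JVP computed using only reverse mode. This sketch assumes the custom backward rule is itself differentiable; `jvp_via_double_vjp` is a hypothetical helper, and it costs roughly two reverse passes per JVP.

```python
import jax
import jax.numpy as jnp

# Toy custom_vjp function (hypothetical stand-in for a Diffrax solve).
@jax.custom_vjp
def square(x):
    return x ** 2

def square_fwd(x):
    return square(x), x

def square_bwd(x, g):
    return (2.0 * x * g,)  # differentiable in g, which the trick relies on

square.defvjp(square_fwd, square_bwd)

def jvp_via_double_vjp(f, x, t):
    """Return (f(x), J @ t) using two reverse-mode passes instead of jax.jvp."""
    y, pullback = jax.vjp(f, x)  # pullback(u) == (J^T @ u,)
    u0 = jnp.zeros_like(y)       # any point works, since pullback is linear in u
    _, pullback_of_pullback = jax.vjp(lambda u: pullback(u)[0], u0)
    (jt,) = pullback_of_pullback(t)  # (J^T)^T @ t == J @ t
    return y, jt

x = jnp.array([1.0, 2.0, 3.0])
t = jnp.array([1.0, 1.0, 0.5])
y, jt = jvp_via_double_vjp(square, x, t)  # jt == 2 * x * t == [2., 4., 3.]
```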

ahwillia commented 8 months ago

Amazing, works as intended (at least for the simple example I've tried)!