patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

AbstractGaussNewton now supports reverse-autodiff for Jacobians. #51

Closed · patrick-kidger closed this 2 months ago

patrick-kidger commented 3 months ago

In particular, this is useful when the underlying function only supports reverse-mode autodifferentiation because of a jax.custom_vjp; see https://github.com/patrick-kidger/optimistix/issues/50.
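
For context, here is a minimal illustration of the situation (a stand-in sketch, not code from that issue): a function wrapped in jax.custom_vjp can be differentiated with jacrev, but forward-mode (jacfwd/jvp) fails on it, so a Gauss-Newton-style solver needs to be able to build its Jacobian in reverse mode.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def residual(x):
    return jnp.sin(x) - 0.5

def residual_fwd(x):
    return residual(x), x

def residual_bwd(x, g):
    return (g * jnp.cos(x),)

residual.defvjp(residual_fwd, residual_bwd)

x = jnp.ones(3)
jax.jacrev(residual)(x)   # fine
# jax.jacfwd(residual)(x) # TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.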

johannahaffner commented 2 months ago

Hi Patrick,

I tried this branch, and it did not work for my use case (parameter estimation for an ODE). I tried both LevenbergMarquardt and GaussNewton in combination with DirectAdjoint or RecursiveCheckpointAdjoint, and I get

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

with both. I have an MWE if that would be useful for you.

During the install, this is the commit it resolved to:

Switched to a new branch 'gauss-newton-jacrev'
branch 'gauss-newton-jacrev' set up to track 'origin/gauss-newton-jacrev'.
Resolved https://github.com/patrick-kidger/optimistix to commit 776820485fd9df320d3089bcd302f2f69124cf14

Hope you had a nice weekend!

Johanna

johannahaffner commented 2 months ago

Never mind, it works! I checked out #50 and realized that I need to pass options=dict(jac="bwd") to least_squares.
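
For anyone finding this later, the call pattern that worked for me looks roughly like the following (the residual function here is just a stand-in, not my actual ODE model):

import jax.numpy as jnp
import optimistix

def residuals(params, args):
    data = args
    return params * jnp.arange(3.0) - data  # stand-in for residuals that are only reverse-mode differentiable

solver = optimistix.LevenbergMarquardt(rtol=1e-3, atol=1e-6)
sol = optimistix.least_squares(
    residuals,
    solver,
    jnp.array(1.0),                  # initial parameter guess
    args=jnp.array([0.0, 2.0, 4.0]),
    options=dict(jac="bwd"),         # build the Jacobian with reverse-mode autodiff
)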

patrick-kidger commented 2 months ago

Awesome, I'm glad to hear it! I've just merged this in, so this will appear in the next release of Optimistix. :)

johannahaffner commented 1 month ago

Hi! Not sure where else to post this, but I wanted to note it somewhere: on my problem, I see a dramatic performance reduction when using reverse-mode autodiff inside least_squares.

The documentation does note that forward-mode is usually more efficient (a least-squares problem typically has many residuals and only a few parameters, and the cost of a forward-mode Jacobian scales with the number of parameters rather than the number of residuals), so this might just be further confirmation of that observation :)

The solver I used is LevenbergMarquardt(atol=1e-06, rtol=1e-03).

patrick-kidger commented 1 month ago

Hey there!

Ah, indeed this might have been the case. It was hard to guess which was the greater overhead: DirectAdjoint + forward mode, or RecursiveCheckpointAdjoint + reverse mode.

FWIW this is motivating me to consider adding a forward-mode-specific "adjoint". (I'd been holding off on this in the hope that support could be added to RecursiveCheckpointAdjoint, but that requires changes in JAX itself.) Let me see if I can throw something together and we can see how it runs.

patrick-kidger commented 1 month ago

Okay, completely untested, but something like the following should probably work:

class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state

Use it by passing diffeqsolve(..., adjoint=ForwardMode()); the resulting diffeqsolve should then be forward-mode autodifferentiable.
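
As a quick sanity check (equally untested, and assuming _is_none = lambda x: x is None together with the usual jax/diffrax imports are in scope), something like this should exercise the forward-mode path:

import jax
import diffrax

def solve(k):
    term = diffrax.ODETerm(lambda t, y, args: -args * y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0, t1=1.0, dt0=0.01, y0=10.0,
        args=k,
        adjoint=ForwardMode(),  # the class defined above
    )
    return sol.ys

jax.jacfwd(solve)(0.5)  # sensitivity of the solution with respect to k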

Let me know how this goes!

johannahaffner commented 1 month ago

Awesome, thank you! I'm happy to give it a proper go tomorrow. So far I am getting this error:

.../site-packages/diffrax/_integrate.py:457, in loop()
    456     return new_state
--> 458 final_state = outer_while_loop(
    459     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    460 )
    462 def _save_t1(subsaveat, save_state):

.../contextlib.py:80, in inner()
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

.../site-packages/equinox/internal/_loop/loop.py:102, in while_loop()
    102 del cond_fun, body_fun, init_val
--> 103 _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
    104 return final_val

JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for 
lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using 
fori_loop with static start/stop.

I'm a little unsure where the reverse-mode differentiation is coming from here (this is with options=dict(jac="fwd") in least_squares). But also happy to dig into it a little :)

BTW, I am assuming that _is_none = lambda x: x is None, as given as an example here.

johannahaffner commented 1 month ago

Ok, mini update: I haven't been able to dig into it any further, but I did get better benchmarks on my problem. I use jit(vmap(...)) to parallelize over the individuals, so the runtime is determined by the trajectories that take longest to fit. The sum of the maximum number of steps taken in each iteration is quite close to the total runtime divided by the time it takes to simulate the data (376 vs. 380).
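
(For reference, the batching pattern is roughly the following; fit_one and the argument names are placeholders rather than my actual code, and estimate_parameters, model and solver are like the helpers in the MWE I post further down.)

import jax

@jax.jit
@jax.vmap
def fit_one(initial_guess, data):
    # One least-squares fit per individual; vmap batches over the leading axis.
    return estimate_parameters(initial_guess, model, data, solver).value

# all_guesses: shape (n_individuals,); all_data: shape (n_individuals, n_timepoints)
# fitted = fit_one(all_guesses, all_data)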

It looks like forward-mode differentiation on DirectAdjoint is already super fast, at least as far as optimistix is concerned. So unless the ForwardModeAdjoint also speeds up the ODE solving, I would not expect too many performance gains to come from this.

patrick-kidger commented 1 month ago

So unless the ForwardModeAdjoint also speeds up the ODE solving

Indeed it does! (Granted, possibly not by that much. Depends on your problem.)

Do you have a quick MWE demonstrating the above crash?

johannahaffner commented 1 month ago

Ah, even faster ODEs! Nice, that would be cool. I tried this on my real data; I'll make an MWE after the weekend!

johannahaffner commented 1 month ago

Here comes the MWE! The adjoint works fine in diffrax. But optimistix attempts to reverse-mode differentiate through it.

import jax.numpy as jnp

import equinox as eqx
import diffrax
import optimistix

import functools

_is_none = lambda x: x is None  # This function is needed in the forward adjoint

class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax") 
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state

class ToyModel(eqx.Module):
    """Toy model that provides a simple interface to generate an ODE solution,
    subject to its parameters.
    """
    _term: diffrax.ODETerm
    _t0: float
    _t1: float
    _dt0: float
    _y0: float
    _saveat: diffrax.SaveAt
    _solver: diffrax.AbstractERK
    _adjoint: diffrax.AbstractAdjoint

    def __init__(self, ode_model, initial_state, times, solver, adjoint):
        self._term = diffrax.ODETerm(ode_model)
        self._y0 = initial_state

        self._t0 = times[0]
        self._t1 = times[-1]
        self._dt0 = 0.01
        self._saveat = diffrax.SaveAt(ts=times)

        self._solver = solver
        self._adjoint = adjoint

    def __call__(self, param):
        sol = diffrax.diffeqsolve(
            self._term, 
            self._solver, 
            self._t0, self._t1, self._dt0, self._y0, 
            args=param, 
            saveat=self._saveat, 
            adjoint=self._adjoint,
        )
        return sol.ys

def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac='fwd')):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optimistix.least_squares(
        residuals, 
        solver, 
        initial_guess,
        args = (model, data),
        options = solver_options,
    )
    return sol

# Create the model
def dydt(t, y, k):  # Toy ODE
    return - k * y
t = jnp.linspace(0, 10, 50)  
y0 = 10.
model = ToyModel(dydt, y0, t, diffrax.Tsit5(), ForwardMode())

# Solve ODE
k = 0.5  # True value
ode_solution = model(k)  # This runs without issue

# Now try solving for the parameters
k0 = 0.1
solver = optimistix.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
lm_solution = estimate_parameters(k0, model, ode_solution, solver)  # This fails

patrick-kidger commented 1 month ago

Can you try https://github.com/patrick-kidger/optimistix/pull/61? Hopefully that should fix things :)

(It's also revealed an unrelated bug with complex numbers, but I don't think that should affect you.)

johannahaffner commented 1 month ago

Hi Patrick,

Amazing, thank you for the quick fix! I tried it on my real data; it works like a charm and delivers a handsome 3x speedup.

johannahaffner commented 1 month ago

delivers a handsome 3x speedup.

whoops, my bad, I forgot to control for whether my laptop was plugged in :)

Using DirectAdjoint is 52% slower than using ForwardMode for the parameter estimation, and solving just the ODE is 42% slower with DirectAdjoint.

Still pretty nice! I'll keep using the ForwardMode above.
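
(For anyone wanting to reproduce this kind of comparison, here is a rough sketch using the toy MWE above; my 52%/42% figures come from the real data, so the toy problem will give different numbers.)

import timeit
import jax

model_direct = ToyModel(dydt, y0, t, diffrax.Tsit5(), diffrax.DirectAdjoint())
model_forward = ToyModel(dydt, y0, t, diffrax.Tsit5(), ForwardMode())

for name, m in [("DirectAdjoint", model_direct), ("ForwardMode", model_forward)]:
    fit = lambda: jax.block_until_ready(estimate_parameters(k0, m, ode_solution, solver).value)
    fit()  # run once so that compilation time is excluded
    print(name, timeit.timeit(fit, number=10) / 10)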