Hi Patrick,
I tried this branch, and it did not work for my use case (parameter estimation for an ODE). I tried both LevenbergMarquardt and GaussNewton, in combination with DirectAdjoint or RecursiveCheckpointAdjoint, and with every combination I get
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
I have an MWE if that would be useful for you.
During the install, this is the commit it resolved to:
Switched to a new branch 'gauss-newton-jacrev'
branch 'gauss-newton-jacrev' set up to track 'origin/gauss-newton-jacrev'.
Resolved https://github.com/patrick-kidger/optimistix to commit 776820485fd9df320d3089bcd302f2f69124cf14
Hope you had a nice weekend!
Johanna
Never mind, it works! I checked out #50 and realized that I need to pass options=dict(jac="bwd") to least_squares; now it works.
Awesome, I'm glad to hear it! I've just merged this in, so this will appear in the next release of Optimistix. :)
Hi! Not sure where else to post this, but I wanted to note it somewhere: on my problem, I get dramatic performance reductions when using reverse-mode autodiff inside of least_squares.
Specifically: with dfx.DirectAdjoint, I get 0.286 s per trajectory when passing options=dict(jac="fwd"), and considerably more with options=dict(jac="bwd"); the reverse-mode time drops again when using dfx.RecursiveCheckpointAdjoint instead of DirectAdjoint.
The documentation does note that forward-mode is usually more efficient, so this might just be further confirmation of that observation :)
The solver I used is LevenbergMarquardt(atol=1e-06, rtol=1e-03).
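For concreteness, the combinations I compared look roughly like this (just a sketch: a toy exponential-decay ODE stands in for my real model, and the step size and save times here are made up):

```python
import jax.numpy as jnp
import diffrax as dfx
import optimistix as optx

ts = jnp.linspace(0.0, 10.0, 50)
term = dfx.ODETerm(lambda t, y, k: -k * y)  # toy vector field

def solve(k, adjoint):
    sol = dfx.diffeqsolve(
        term, dfx.Tsit5(), 0.0, 10.0, 0.01, 10.0,
        args=k, saveat=dfx.SaveAt(ts=ts), adjoint=adjoint,
    )
    return sol.ys

data = solve(0.5, dfx.RecursiveCheckpointAdjoint())  # synthetic "measurements"

def fit(adjoint, jac):
    residuals = lambda k, _: data - solve(k, adjoint)
    solver = optx.LevenbergMarquardt(atol=1e-06, rtol=1e-03)
    return optx.least_squares(
        residuals, solver, jnp.array(0.1), options=dict(jac=jac)
    )

fit(dfx.DirectAdjoint(), "fwd")               # fast case
fit(dfx.DirectAdjoint(), "bwd")               # much slower
fit(dfx.RecursiveCheckpointAdjoint(), "bwd")  # reverse mode with checkpointing
```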
Hey there!
Ah, indeed this might have been the case. It was hard to guess which was the greater overhead: DirectAdjoint + forward mode, or RecursiveCheckpointAdjoint + reverse mode.
FWIW this is motivating me to consider adding a forward-mode specific "adjoint". (I'd been holding off on this in the hopes that support could be added to RecursiveCheckpointAdjoint, but that's something which requires changes in JAX itself.) Let me see if I can throw something together and we can see how it runs.
Okay, completely untested, but something like the following should probably work:
class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state
This would then be used by passing diffeqsolve(..., adjoint=ForwardMode()). The resulting diffeqsolve should then be forward-mode autodifferentiable.
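i.e. roughly this kind of usage (untested sketch; the toy ODE and the jax.jacfwd call are just for illustration, and it assumes the ForwardMode class above is in scope):

```python
import jax
import jax.numpy as jnp
import diffrax

term = diffrax.ODETerm(lambda t, y, k: -k * y)  # toy vector field

def solve(k):
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(), t0=0.0, t1=10.0, dt0=0.01, y0=10.0,
        args=k,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 50)),
        adjoint=ForwardMode(),  # the class defined above
    )
    return sol.ys

# Forward-mode differentiation through the whole solve:
dys_dk = jax.jacfwd(solve)(0.5)
```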
Let me know how this goes!
Awesome, thank you! I'm happy to give it a proper go tomorrow. So far I am getting this error:
.../site-packages/diffrax/_integrate.py:458, in loop()
456 return new_state
--> 458 final_state = outer_while_loop(
459 cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
460 )
462 def _save_t1(subsaveat, save_state):
.../contextlib.py:81, in inner()
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
.../site-packages/equinox/internal/_loop/loop.py:103, in while_loop()
102 del cond_fun, body_fun, init_val
--> 103 _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
104 return final_val
JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for
lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using
fori_loop with static start/stop.
I'm a little unsure where the reverse-mode differentiation is coming from here (this is with options=dict(jac="fwd") in least_squares). But also happy to dig into it a little :)
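(For reference, the error message itself is just the generic JAX restriction on reverse-mode differentiation of lax.while_loop; e.g. this tiny standalone snippet, unrelated to diffrax, raises the same ValueError:)

```python
import jax
from jax import lax

def f(x):
    # lax.while_loop has no transpose rule, so any reverse-mode
    # differentiation through it raises this error.
    cond = lambda carry: carry[0] < 5.0
    body = lambda carry: (carry[0] + 1.0, carry[1] * 2.0)
    return lax.while_loop(cond, body, (0.0, x))[1]

jax.grad(f)(3.0)  # ValueError: Reverse-mode differentiation does not work for lax.while_loop ...
```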
BTW, I am assuming that _is_none = lambda x: x is None, as given as an example here.
Ok, mini update: I haven't been able to dig into it any further, but I did get better benchmarks on my problem.
I use jit(vmap(...)) to parallelize over the individuals, so the runtime is determined by the trajectories that take longest to fit.
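(Roughly this pattern; the analytic residual below is just a stand-in for my ODE model:)

```python
import jax
import jax.numpy as jnp
import optimistix as optx

t = jnp.linspace(0.0, 10.0, 50)

def residuals(k, data):
    # Toy exponential-decay residual standing in for the ODE-based one.
    return data - 10.0 * jnp.exp(-k * t)

solver = optx.LevenbergMarquardt(atol=1e-06, rtol=1e-03)

def fit_one(data):
    sol = optx.least_squares(
        residuals, solver, jnp.array(0.1), args=data, options=dict(jac="fwd")
    )
    return sol.value

# One least-squares fit per individual; the slowest trajectory in the batch
# dominates the runtime of the whole vmapped call.
fit_all = jax.jit(jax.vmap(fit_one))

ks_true = jnp.linspace(0.2, 0.8, 8)
trajectories = 10.0 * jnp.exp(-ks_true[:, None] * t[None, :])
estimates = fit_all(trajectories)
```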
The sum over iterations of the maximum number of steps taken works out to be quite close to the total runtime divided by the time it takes to simulate the data (376 vs. 380).
It looks like forward-mode differentiation on DirectAdjoint is already super fast, at least as far as optimistix is concerned. So unless the ForwardModeAdjoint also speeds up the ODE solving, I would not expect too many performance gains to come from this.
> So unless the ForwardModeAdjoint also speeds up the ODE solving
Indeed it does! (Granted, possibly not by that much. Depends on your problem.)
Do you have a quick MWE demonstrating the above crash?
Ah even faster ODEs! Nice, that would be cool. I tried this on my real data, will make an MWE after the weekend!
Here comes the MWE! The adjoint works fine in diffrax, but optimistix attempts to reverse-mode differentiate through it.
import jax.numpy as jnp
import equinox as eqx
import diffrax
import optimistix
import functools

_is_none = lambda x: x is None  # This function is needed in the forward adjoint


class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state


class ToyModel(eqx.Module):
    """Toy model that provides a simple interface to generate an ODE solution,
    subject to its parameters.
    """

    _term: diffrax.ODETerm
    _t0: float
    _t1: float
    _dt0: float
    _y0: float
    _saveat: diffrax.SaveAt
    _solver: diffrax.AbstractERK
    _adjoint: diffrax.AbstractAdjoint

    def __init__(self, ode_model, initial_state, times, solver, adjoint):
        self._term = diffrax.ODETerm(ode_model)
        self._y0 = initial_state
        self._t0 = times[0]
        self._t1 = times[-1]
        self._dt0 = 0.01
        self._saveat = diffrax.SaveAt(ts=times)
        self._solver = solver
        self._adjoint = adjoint

    def __call__(self, param):
        sol = diffrax.diffeqsolve(
            self._term,
            self._solver,
            self._t0, self._t1, self._dt0, self._y0,
            args=param,
            saveat=self._saveat,
            adjoint=self._adjoint,
        )
        return sol.ys


def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac='fwd')):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optimistix.least_squares(
        residuals,
        solver,
        initial_guess,
        args=(model, data),
        options=solver_options,
    )
    return sol


# Create the model
def dydt(t, y, k):  # Toy ODE
    return -k * y


t = jnp.linspace(0, 10, 50)
y0 = 10.
model = ToyModel(dydt, y0, t, diffrax.Tsit5(), ForwardMode())

# Solve ODE
k = 0.5  # True value
ode_solution = model(k)  # This runs without issue

# Now try solving for the parameters
k0 = 0.1
solver = optimistix.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
lm_solution = estimate_parameters(k0, model, ode_solution, solver)  # This fails
Can you try https://github.com/patrick-kidger/optimistix/pull/61? Hopefully that should fix things :)
(It's also revealed an unrelated bug with complex numbers, but I don't think that should affect you.)
Hi Patrick,
Amazing, thank you for the quick fix! I tried it on my real data; it works like a charm and delivers a handsome 3x speedup.
> delivers a handsome 3x speedup.
Whoops, my bad, I forgot to control for whether my laptop was plugged in :)
Using DirectAdjoint is 52% slower than using ForwardMode for the parameter estimation, and solving just the ODE is 42% slower with DirectAdjoint.
Still pretty nice! I'll keep using the ForwardMode above.
In particular this is useful when the underlying function only supports reverse-mode autodifferentiation due to a jax.custom_vjp; see https://github.com/patrick-kidger/optimistix/issues/50.
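For anyone unfamiliar with that restriction, here is a minimal standalone illustration of the underlying JAX behaviour (nothing Optimistix-specific):

```python
import jax

@jax.custom_vjp
def f(x):
    return x**2

def f_fwd(x):
    return f(x), x

def f_bwd(x, g):
    return (2.0 * x * g,)

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(3.0)    # reverse mode: fine
jax.jacfwd(f)(3.0)  # TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
```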