ahwillia opened 8 months ago
Yup, you're completely correct in your diagnosis: Diffrax has a `jax.custom_vjp` for the autodifferentiation through `diffeqsolve`, and this doesn't support forward-mode autodiff, which is what `optx.LevenbergMarquardt` uses to compute its Jacobians.
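To illustrate the restriction being described (this is a minimal standalone sketch, not Diffrax's actual internals): reverse-mode autodiff happily uses a `jax.custom_vjp` rule, but forward-mode autodiff over the same function raises an error, which is what a `jacfwd`-based Jacobian hits.

```python
import jax

@jax.custom_vjp
def f(x):
    return x ** 2

def f_fwd(x):
    # Save the primal input as the residual for the backward pass.
    return x ** 2, x

def f_bwd(x, g):
    # d/dx (x**2) = 2x
    return (2 * x * g,)

f.defvjp(f_fwd, f_bwd)

jax.jacrev(f)(3.0)  # fine: reverse mode uses the custom VJP rule
jax.jacfwd(f)(3.0)  # TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function
```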
We have essentially two possible fixes: offer a way for Diffrax to use forward-mode autodifferentiation, or offer a way for Optimistix to use reverse-mode.
For now I've just added the latter in #51. Try using Optimistix from that branch and see if it solves your problem! You'll need to pass `optx.least_squares(..., options=dict(jac="bwd"))`.
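As a minimal sketch of how that call might look on the #51 branch (the toy residual function, tolerances, and data here are made up for illustration):

```python
import jax.numpy as jnp
import optimistix as optx

def fn(y, args):
    # Toy residual function; in practice this is wherever the
    # custom_vjp (e.g. a diffrax solve) lives.
    target = args
    return y ** 2 - target

solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
sol = optx.least_squares(
    fn,
    solver,
    y0=jnp.array([1.0, 1.0]),
    args=jnp.array([2.0, 3.0]),
    options=dict(jac="bwd"),  # build the Jacobian with reverse-mode autodiff
)
print(sol.value)
```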
(I'd like to add better forward-mode support for Diffrax, but the best way of doing this really depends on JAX directly adding support for `jvp`-of-`custom_vjp`, which I have a draft of here, but it still seems to be buggy, so I haven't gotten around to finishing it.)
Amazing, works as intended (at least for the simple example I've tried)!
I'm running into some trouble applying `optimistix.least_squares(fn, LevenbergMarquardt(...), x0)` to certain problems. From the error message below, my understanding of the root cause is that forward-mode autodiff cannot be used on `jax.custom_vjp`. In my case I am using `diffrax` to solve an ODE within `fn(...)`, which I think might be causing the problem.

Is my basic understanding correct? Are there specific constraints / assumptions that `fn(...)` must follow for `optimistix.least_squares` to work (e.g. cannot use `jax.custom_vjp`)? Is there any way around this?

The error I get is:
The full code to reproduce the error is below. By the way, I get the same problem when trying to use `jaxopt.LevenbergMarquardt` on this problem.
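(The full reproduction isn't included here. As a purely hypothetical sketch of the kind of setup being described, where the residual function runs `diffrax.diffeqsolve` internally and therefore inherits its `custom_vjp`:)

```python
import diffrax
import jax.numpy as jnp

def fn(params, args):
    # Hypothetical residual function: fit a decay rate to observed data
    # by solving an ODE with diffrax inside the least-squares objective.
    ts, ys_observed = args

    def vector_field(t, y, p):
        return -p * y

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=0.1,
        y0=jnp.array(1.0),
        args=params,
        saveat=diffrax.SaveAt(ts=ts),
    )
    # diffeqsolve carries a custom_vjp, so forward-mode Jacobians of fn fail.
    return sol.ys - ys_observed
```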