nwlambert opened 10 months ago

I am a total beginner with JAX and Diffrax, so I am not sure whether this is a bug or expected behaviour: if I try to find the second or higher derivative of a solution from diffeqsolve() I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.

The error returned is:

    print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
                                  ^^^^^^^^^^^
    ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

Minimal working example (using the default ODE example from the Diffrax introduction):
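(The snippet itself was not preserved in this copy of the thread, so the following is a reconstruction along those lines: the names z and d2rhozdz are taken from the error message above, while rhoz and drhozdz are assumed.)

```python
import jax
import diffrax

def rhoz(z):
    # The introductory Diffrax ODE, dy/dt = -y, solved from t=0 to t=1.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    solver = diffrax.Dopri5()
    sol = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=z)
    return sol.ys[0]  # only the value at t1 is saved by default

z = 2.0
drhozdz = jax.grad(rhoz)      # first derivative: works with the default adjoint
d2rhozdz = jax.grad(drhozdz)  # reverse-over-reverse second derivative

print("found 1st derivative ", drhozdz(z))
print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
```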
Yup, I'm afraid this is expected. RecursiveCheckpointAdjoint does some smart things under the hood to be very efficient when computing first-order gradients specifically, but unfortunately this also makes it incompatible with certain kinds of higher-order autodiff.
First of all, when looking to compute a Hessian, it is usually more efficient to use forward-over-reverse rather than reverse-over-reverse (and indeed this is what jax.hessian does). RecursiveCheckpointAdjoint should actually be compatible with that in most cases.
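For example, as a sketch reusing the rhoz reconstruction above: jax.hessian composes forward-mode over reverse-mode internally, and the same composition can be written directly.

```python
import jax

# Forward-over-reverse: forward-mode differentiation (jacfwd) of the
# reverse-mode gradient (grad). This is the composition jax.hessian uses,
# and it should stay compatible with RecursiveCheckpointAdjoint.
d2rhozdz_fwd = jax.jacfwd(jax.grad(rhoz))
print("found 2nd derivative ", d2rhozdz_fwd(2.0))
```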
But nonetheless, in the general case, using DirectAdjoint is indeed the appropriate fix. (Handling edge cases like this is the reason it exists.)
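Concretely, and again only as a sketch against the reconstruction above, this is just the adjoint keyword of diffeqsolve:

```python
def rhoz_direct(z):
    term = diffrax.ODETerm(lambda t, y, args: -y)
    solver = diffrax.Dopri5()
    sol = diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=z,
        # Swap the default RecursiveCheckpointAdjoint for the catch-all
        # DirectAdjoint, which also supports reverse-over-reverse autodiff.
        adjoint=diffrax.DirectAdjoint(),
    )
    return sol.ys[0]

d2rhozdz = jax.grad(jax.grad(rhoz_direct))  # now works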
You might also like the example on second-order sensitivities in the documentation.
I'm going to tag this under "refactor" as this could probably do with a more informative error message.
Thanks for the quick reply. I had missed that documentation, and it was very helpful.

Playing around a bit with a more complex example that I have been struggling with, I see what you mean: doing forward-over-reverse with RecursiveCheckpointAdjoint() works, and it seems both faster and more memory-efficient than using DirectAdjoint(). That was extremely useful, thanks!