[Open] ddrous opened this issue 1 month ago

Hi all, I've edited the introductory Neural ODE example to highlight a problem I'm facing with two-level optimisation: the first (outer) level is with respect to the `model`, and the second (inner) level is with respect to a parameter `alpha`. JAX throws a `JaxStackTraceBeforeTransformation` error if I use `RecursiveCheckpointAdjoint`, but everything runs if I use `DirectAdjoint` instead. In line with the recommendations in the documentation, I'd love to use the former adjoint rule. Please help, thanks.
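Roughly, the structure is the following — a minimal sketch with illustrative names (`vector_field`, `model_param`, a toy scalar in place of the real neural network), not the full edited example:

```python
import diffrax
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Toy stand-in for the neural network vector field.
    model_param, alpha = args
    return model_param * y + alpha


def solve(model_param, alpha):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=jnp.array([1.0]),
        args=(model_param, alpha),
        adjoint=diffrax.RecursiveCheckpointAdjoint(),  # swapping in DirectAdjoint() runs fine
    )
    return sol.ys


def inner_loss(model_param, alpha):
    return jnp.mean(solve(model_param, alpha) ** 2)


def outer_loss(model_param, alpha):
    # Inner level: gradient with respect to alpha.
    dalpha = jax.grad(inner_loss, argnums=1)(model_param, alpha)
    return dalpha**2


# Outer level: gradient with respect to the model, i.e. reverse-over-reverse.
# In this setup, this is the line that raises JaxStackTraceBeforeTransformation.
jax.grad(outer_loss, argnums=0)(jnp.array(1.0), jnp.array(0.5))
```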
patrick-kidger commented:

You're actually bumping into something that I think is a bit of an open research problem. :) Namely, how to do second-order autodifferentiation whilst using checkpointing! In particular, what you're seeing here is that the backward pass for `RecursiveCheckpointAdjoint` is not itself reverse-mode autodifferentiable.
I do note that `alpha` appears to be a scalar. I've not thought through every detail, but for such cases it is usually more efficient to use `jax.jvp` to perform forward-mode autodifferentiation instead. Typically the goal is to frame the computation as a jvp-of-grad-of-loss. (Such 'forward over reverse' is usually most efficient overall.) This may allow you to sidestep this problem.
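Something like the following — a minimal sketch with a stand-in `loss`, just to illustrate the jvp-of-grad structure rather than the full Neural ODE:

```python
import jax
import jax.numpy as jnp


def loss(alpha, params):
    # Placeholder for the real loss.
    return jnp.sum((alpha * params) ** 2)


def grad_loss(alpha, params):
    # Reverse-mode gradient with respect to the parameters.
    return jax.grad(loss, argnums=1)(alpha, params)


params = jnp.array([1.0, 2.0])
alpha = jnp.array(0.5)

# Forward-mode derivative of that gradient with respect to alpha: a single
# jvp pass, since alpha is a scalar. No reverse-over-reverse required.
_, dgrad_dalpha = jax.jvp(lambda a: grad_loss(a, params), (alpha,), (jnp.ones_like(alpha),))
```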
Failing that, then using `DirectAdjoint` is probably the best option available here.
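That is, passing it explicitly to the solve — a sketch on a toy exponential-decay ODE:

```python
import diffrax
import jax.numpy as jnp

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: -y),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    adjoint=diffrax.DirectAdjoint(),  # instead of the default RecursiveCheckpointAdjoint()
)
```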
ddrous commented:

Thank you @patrick-kidger! It helps to know what the real problem is. Looking forward to any research/development on this in the future.

Using JVPs is not really an option for me, since my parameters are themselves neural nets (I turned `alpha` into a scalar just for the purposes of an MWE). So it looks like I'm gonna have to use `DirectAdjoint()`, even though I can barely handle its memory requirements (this after tweaking `max_steps`).
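For reference, the tweak in question — capping `max_steps` to bound `DirectAdjoint`'s memory use. The value of 1000 below is purely illustrative, and the ODE is again a toy stand-in:

```python
import diffrax
import jax.numpy as jnp

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: -y),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    adjoint=diffrax.DirectAdjoint(),
    max_steps=1000,  # smaller values trade solver headroom for memory
)
```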