adam-hartshorne opened this issue 1 year ago
This is a really nice paper -- thank you for sharing it!
At least in the weight-tied regime (=common for "true" neural ODEs, uncommon for ResNets) this seems to suggest that we can halve the cost of backpropagation when using Leapfrog -- we only really need to track the branch with non-negligible gradients.
I wonder how this problem would generalise to other linear multistep methods -- most practical usage of autodiff+diffeqsolves has involved Runge--Kutta methods.
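If it helps, here is a minimal toy sketch of the two interleaved gradient branches I mean, written in JAX with a fixed linear map standing in for the weight-tied network and my own naming throughout (nothing here is from the paper). It unrolls the Leapfrog update x_{k+1} = x_{k-1} + 2h f(x_k) and reads off the adjoint of each intermediate state by differentiating with respect to a zero perturbation added at that step:

```python
import jax
import jax.numpy as jnp

h, n_steps = 0.01, 100
W = -jnp.eye(2)  # toy weight-tied "network": f(x) = W @ x

def f(x):
    return W @ x

def loss(x0, eps):
    # eps[k] is a zero perturbation added to the state produced at step k,
    # so d(loss)/d(eps[k]) is exactly the adjoint of that intermediate state.
    x_prev, x_curr = x0, x0 + h * f(x0)  # bootstrap the two-step scheme with one Euler step
    for k in range(n_steps):
        x_prev, x_curr = x_curr, x_prev + 2.0 * h * f(x_curr) + eps[k]
    return jnp.sum(x_curr ** 2)  # loss depends only on the final state

x0 = jnp.array([1.0, 0.0])
adjoints = jax.grad(loss, argnums=1)(x0, jnp.zeros((n_steps, 2)))

# Near the final time the adjoints of neighbouring states alternate between
# an O(1) branch and an O(h) branch:
print(jnp.linalg.norm(adjoints[-6:], axis=1))
```

In this toy example roughly half of the intermediate adjoints come out O(h), which is what makes me think the even/odd branches could be tracked separately in the weight-tied case.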
They state on p. 3:

> Fundamentally, the oscillation arises from the fact that Leapfrog calls for two history steps, while in contrast, mismatch used in the back-propagation in auto-differentiation can provide only one value at the final step t = 1. As a consequence, the chain-rule-type gradient computation provides a non-physical assignment to the final two data points as the initialization for the backpropagation. We provide calculations only for Leapfrog, but the same phenomenon is expected to hold for general LMM type neural ODEs.
But as far as I can see, no more information is given.
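For what it's worth, here is how I read that passage (my notation, not the paper's). Unrolling $x_{k+1} = x_{k-1} + 2h\,f(x_k,\theta)$ and differentiating a loss $L(x_N)$ gives the discrete adjoint recursion

$$
\lambda_k = \lambda_{k+2} + 2h\,\Big(\frac{\partial f}{\partial x}(x_k,\theta)\Big)^{\top}\lambda_{k+1},
\qquad
\lambda_N = \frac{\partial L}{\partial x_N},
\qquad
\lambda_{N-1} = 2h\,\Big(\frac{\partial f}{\partial x}(x_{N-1},\theta)\Big)^{\top}\lambda_N .
$$

That recursion is itself a two-step (Leapfrog-like) scheme for the continuous adjoint $\dot\lambda = -(\partial f/\partial x)^\top\lambda$, so it needs two starting values. A "physical" start would have $\lambda_{N-1} \approx \lambda(t_{N-1}) = O(1)$, but autodiff hands it the $O(h)$ value above, which excites the parasitic oscillatory mode of the two-step recursion -- at least, that is how I read it.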
This may be of interest.
https://arxiv.org/abs/2306.02192