patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Paper - Correcting auto-differentiation in neural-ODE training #265

adam-hartshorne commented 1 year ago

This may be of interest.

Does the use of auto-differentiation yield reasonable updates to deep neural networks that represent neural ODEs? Through mathematical analysis and numerical evidence, we find that when the neural network employs high-order forms to approximate the underlying ODE flows (such as the Linear Multistep Method (LMM)), brute-force computation using auto-differentiation often produces non-converging artificial oscillations. In the case of Leapfrog, we propose a straightforward post-processing technique that effectively eliminates these oscillations, rectifies the gradient computation and thus respects the updates of the underlying flow.

https://arxiv.org/abs/2306.02192
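
For concreteness, here is a minimal JAX sketch of the setting the abstract describes (a toy weight-tied vector field and step count of my own choosing, not code from the paper): the two-step Leapfrog recursion is differentiated end-to-end with jax.grad, which is the "brute-force" computation said to produce the oscillations.

```python
import jax
import jax.numpy as jnp


def leapfrog_solve(f, theta, u0, dt, num_steps):
    # Two-step Leapfrog recursion: u_{n+1} = u_{n-1} + 2*dt*f(theta, u_n).
    # The second history value is bootstrapped with a single Euler step.
    u_prev = u0
    u_curr = u0 + dt * f(theta, u0)

    def step(carry, _):
        u_prev, u_curr = carry
        u_next = u_prev + 2 * dt * f(theta, u_curr)
        return (u_curr, u_next), None

    (_, u_final), _ = jax.lax.scan(step, (u_prev, u_curr), None, length=num_steps - 1)
    return u_final


def f(theta, u):
    # Toy weight-tied vector field standing in for a neural network.
    return jnp.tanh(theta @ u)


def loss(theta, u0, target):
    # "Brute-force" differentiation straight through the solver.
    u1 = leapfrog_solve(f, theta, u0, dt=0.01, num_steps=100)
    return jnp.sum((u1 - target) ** 2)


theta = 0.1 * jnp.eye(3)
grads = jax.grad(loss)(theta, jnp.ones(3), jnp.zeros(3))
```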

patrick-kidger commented 1 year ago

This is a really nice paper -- thank you for sharing it!

At least in the weight-tied regime (=common for "true" neural ODEs, uncommon for ResNets) this seems to suggest that we can halve the cost of backpropagation when using Leapfrog -- we only really need to track the branch with non-negligible gradients.
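
Roughly, something like the following (only a sketch -- using jax.lax.stop_gradient to drop the u_{n-1} history branch is my guess at what "not tracking" a branch would look like, not the paper's actual post-processing):

```python
import jax


def leapfrog_step_one_branch(f, theta, u_prev, u_curr, dt):
    # Hypothetical "track only one branch" variant: cut the gradient flowing
    # back through the u_{n-1} history term, so reverse mode only follows the
    # f(theta, u_n) branch.  Which branch actually carries the non-negligible
    # gradients (and whether this matches the paper's fix) is an assumption.
    return jax.lax.stop_gradient(u_prev) + 2 * dt * f(theta, u_curr)
```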

I wonder how this problem would generalise to other linear multistep methods -- most practical usage of autodiff+diffeqsolves has involved Runge--Kutta methods.

adam-hartshorne commented 1 year ago

They state on p. 3:

Fundamentally, the oscillation arises from the fact that Leapfrog calls for two history steps, while in contrast, mismatch used in the back-propagation in auto-differentiation can provide only one value at the final step t = 1. As a consequence, the chain-rule-type gradient computation provides a non-physical assignment to the final two data points as the initialization for the backpropagation. We provide calculations only for Leapfrog, but the same phenomenon is expected to hold for general LMM type neural ODEs.

But as far as I can see, no more information is given.
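
For what it's worth, my reading of that passage, written out as a sketch rather than anything stated in the paper:

```latex
% Leapfrog is a two-step recursion in the learned vector field f_\theta:
\[
  u_{n+1} = u_{n-1} + 2h\,f_\theta(u_n), \qquad n = 1, \dots, N-1.
\]
% Reverse-mode autodiff therefore carries cotangents for two history states:
% \bar{u}_{n+1} feeds both \bar{u}_{n-1} (identity term) and \bar{u}_{n}
% (through 2h\,\partial_u f_\theta(u_n)^\top).  The loss at t = 1 supplies
% only \bar{u}_N, so whatever is implicitly assigned to \bar{u}_{N-1} is the
% "non-physical" initialisation of the final two values that the quote
% describes.
```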