rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Adaptive Computation #62

Closed · ghost closed this issue 5 years ago

ghost commented 5 years ago

I wanted to ask you two questions about adaptive computation :)

  1. Could you clarify whether the following is how neural ODEs adapt the amount of computation? Judging by the code snippet below from "odenet_mnist.py", the ODE solver seems to adapt the amount of computation not by adapting the terminal time t_1 as a function of the input, but by adapting the number and size of the numerical integration steps it takes from t_0 to a fixed terminal time t_1 (a paraphrased sketch of that block follows this list).

https://github.com/rtqichen/torchdiffeq/blob/a493d9adbe87624d419bbc57678a6ed5654e2a86/examples/odenet_mnist.py#L121

  2. For discrete-depth networks such as RNNs or ResNets, adaptive computation is done via secondary neural networks that decide the number of evaluations by stopping at a depth where the features become "good enough" for the downstream task, e.g. classification or segmentation. In other words, adapting the number of evaluations in that context affects the classification or segmentation accuracy. In the ODE context, on the other hand, adaptive computation seems to affect the accuracy of the numerical integration rather than the accuracy on the downstream classification or segmentation task. This got me wondering whether "easy" inputs, for which only a small number of ResNet evaluations is needed for the features to become "good enough" for a given downstream task, would also require only a small number of integration steps in a neural ODE for the hidden state to become "good enough" for the same task. If so, could there be a deeper link between these two forms of adaptive computation? Would love to get your thoughts on this.
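For reference on point 1, here is a rough paraphrase of the ODEBlock at the linked line; attribute names and tolerance handling in the actual script may differ slightly, so treat this as a sketch rather than the repo's exact code:

```python
import torch
from torchdiffeq import odeint

class ODEBlock(torch.nn.Module):
    # Paraphrase of the ODEBlock in examples/odenet_mnist.py (sketch only).
    def __init__(self, odefunc, tol=1e-3):
        super().__init__()
        self.odefunc = odefunc
        self.tol = tol
        # t_0 = 0 and t_1 = 1 are fixed for every input.
        self.integration_time = torch.tensor([0.0, 1.0])

    def forward(self, x):
        t = self.integration_time.to(x)
        # The adaptive solver decides how many internal steps (and of what
        # size) it needs to reach t_1 within rtol/atol; nothing here is a
        # learned, input-dependent stopping rule.
        out = odeint(self.odefunc, x, t, rtol=self.tol, atol=self.tol)
        return out[1]  # the state at the fixed terminal time t_1
```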
rtqichen commented 5 years ago

To clarify: yes, you're right that the terminal time is fixed and the adaptive compute comes from the ODE solver itself choosing different step sizes to adapt to the complexity of the ODE.
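A minimal, self-contained way to see this is to count the number of function evaluations (NFE) the solver spends on the fixed interval. The `CountedFunc` wrapper below is just for illustration (assuming the default `dopri5` solver); harder tolerances, or a stiffer learned vector field, force more and smaller steps:

```python
import torch
from torchdiffeq import odeint

class CountedFunc(torch.nn.Module):
    """Wraps a dynamics network and counts solver evaluations (NFE)."""
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.nfe = 0

    def forward(self, t, y):
        self.nfe += 1  # one increment per solver evaluation
        return self.net(y)

func = CountedFunc(torch.nn.Linear(2, 2))
t = torch.tensor([0.0, 1.0])  # fixed integration interval
y0 = torch.randn(1, 2)

for tol in (1e-2, 1e-6):
    func.nfe = 0
    odeint(func, y0, t, rtol=tol, atol=tol, method='dopri5')
    # Tighter tolerance -> more, smaller steps -> more compute.
    print(f'tol={tol:g}  NFE={func.nfe}')
```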

I think you can view the ODE solver as doing some form of approach 2, except that the residual block's parameters are themselves defined based on a real-valued "depth" or some measure of progress. It's just that the way to adapt the step size in order to precisely solve an ODE is known, and you can think of the ODE solver as a greedy algorithm that tries to minimize the difference between the predicted value and the true value at t_1. In contrast, approaches in 2 typically use RL or some form of indirect supervision that aims to directly minimize some loss function, but this comes at the cost of a lack of interpretability when it comes to "why did it choose this amount of compute?"
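To make the "greedy, local error" view concrete, here is a toy adaptive stepper built on an Euler/Heun embedded pair. This is not the torchdiffeq implementation (dopri5 uses a higher-order pair), only the generic accept/reject-and-rescale rule that such solvers follow:

```python
import torch

def adaptive_integrate(f, y0, t0, t1, rtol=1e-3, atol=1e-6, h=0.1):
    """Toy Euler/Heun embedded pair with a standard accept/reject rule."""
    t, y, nfe = t0, y0, 0
    while t < t1:
        h = min(h, t1 - t)                # don't step past t_1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        y_lo = y + h * k1                 # 1st-order (Euler) estimate
        y_hi = y + h / 2 * (k1 + k2)      # 2nd-order (Heun) estimate
        err = float((y_hi - y_lo).abs().max())   # local error estimate
        tol = atol + rtol * float(y.abs().max())
        if err <= tol:                    # accept: local error small enough
            t, y = t + h, y_hi
        # greedily rescale the step: grow when easy, shrink when hard
        h = h * min(2.0, max(0.2, 0.9 * (tol / (err + 1e-12)) ** 0.5))
    return y, nfe

# An input whose dynamics are smoother/slower needs fewer steps to reach t_1:
y1, nfe = adaptive_integrate(lambda t, y: -y, torch.tensor([1.0]), 0.0, 1.0)
```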

There definitely should be a simple way to marry these two approaches, and it could be interesting to use techniques from 2 to solve 1. However, note that any viable approach would have to be not only efficient but also sufficiently robust, because training the ODE relies on solving it correctly.