To clarify: yes, you're right that the terminal time is fixed, and the adaptive compute comes from the ODE solver itself choosing different step sizes to match the complexity of the dynamics.
I think you can view an ODE solver as doing some form of approach 2, except that the residual block's parameters are themselves defined in terms of a real-valued "depth", or some other measure of progress. It's just that how to adapt the step size to solve an ODE accurately is known: you can think of the ODE solver as a greedy algorithm that tries to keep the difference between the predicted value and the true value at t_1 below a tolerance. In contrast, approaches in 2 typically use RL or some form of indirect supervision that directly minimizes a loss function, but this comes at the cost of interpretability: it's hard to answer "why did it choose this amount of compute?"
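To make the "greedy error control" point concrete, here's a toy sketch of my own (not torchdiffeq's actual controller; real solvers like dopri5 use an embedded Runge–Kutta pair and a more careful step-size rule). It estimates the local error by comparing one Euler step against two half steps, and greedily shrinks or grows the step size to keep that estimate below a tolerance:

```python
import numpy as np

def adaptive_euler(f, y0, t0, t1, tol=1e-4, h=0.1):
    """Integrate dy/dt = f(t, y) from t0 to t1, adapting h greedily
    so the local error estimate stays below tol."""
    t, y, steps = t0, y0, 0
    while t < t1:
        h = min(h, t1 - t)                     # don't overshoot t1
        y_full = y + h * f(t, y)               # one full step of size h
        y_half = y + 0.5 * h * f(t, y)         # two half steps of size h/2
        y_half = y_half + 0.5 * h * f(t + 0.5 * h, y_half)
        err = np.max(np.abs(y_full - y_half))  # local error estimate
        if err < tol:                          # accept: advance in time
            t, y = t + h, y_half
            steps += 1
            h *= 1.5                           # easy region: grow the step
        else:
            h *= 0.5                           # hard region: shrink the step
    return y, steps

# Harder (stiffer) dynamics force smaller steps, i.e. more compute:
_, n_easy = adaptive_euler(lambda t, y: -y, np.array([1.0]), 0.0, 1.0)
_, n_hard = adaptive_euler(lambda t, y: -50.0 * y, np.array([1.0]), 0.0, 1.0)
print(n_easy, n_hard)  # the stiff ODE takes many more steps
```

The point is that the amount of compute falls out of an explicit, interpretable error criterion rather than a learned policy.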
There should be a simple way to marry these two approaches, and it could be interesting to use techniques from 2 to solve 1. However, note that any viable approach would have to be not only efficient but also sufficiently robust, because training the ODE model relies on solving it accurately.
I wanted to ask you two questions about adaptive computation :)
https://github.com/rtqichen/torchdiffeq/blob/a493d9adbe87624d419bbc57678a6ed5654e2a86/examples/odenet_mnist.py#L121
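For reference, the linked line is part of the MNIST example in the repo. A minimal sketch of the NFE-counting pattern that example uses to measure how much compute the adaptive solver actually spends (assuming torchdiffeq is installed; the network, tolerance values, and names here are my own placeholders):

```python
import torch
from torchdiffeq import odeint

class CountingODEFunc(torch.nn.Module):
    """Wraps the dynamics f(t, y) and counts function evaluations (NFE),
    the same bookkeeping trick the MNIST example uses."""
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)  # stand-in dynamics
        self.nfe = 0

    def forward(self, t, y):
        self.nfe += 1
        return self.net(y)

func = CountingODEFunc()
y0 = torch.randn(1, 2)
t = torch.tensor([0.0, 1.0])  # fixed terminal time

for tol in (1e-2, 1e-5):
    func.nfe = 0
    odeint(func, y0, t, rtol=tol, atol=tol, method='dopri5')
    print(f"rtol=atol={tol}: {func.nfe} function evaluations")
```

Tightening the tolerance (or making the dynamics harder) increases the NFE even though the integration interval is fixed, which is exactly the adaptive-compute behavior discussed above.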