DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0
1.35k stars, 125 forks

Reversible ODEs? #134

Closed: StephenHogg closed this issue 2 years ago

StephenHogg commented 2 years ago

Thanks for creating this package; I see a lot of potential uses for it. One thing I'm very interested in is defining an ODE and integrating it both forwards and backwards in time. For instance, how would I alter (or use) the model in the quickstart notebook to achieve this? Specifically, if I wanted that model to act like an autoencoder instead of a classifier, what would change? Hope this isn't a bother.

It looks to me like it's just a matter of making the model that represents the vector field use the t argument somehow, and then feeding t_span either forwards or in reverse. Is a time-dependent vector field actually necessary, or is reversing t_span enough? And if it is necessary, what's a sane way to change the model to allow this?

StephenHogg commented 2 years ago

...actually am I just thinking about normalising flows here?

Zymrael commented 2 years ago

Hi @StephenHogg! To run a Neural ODE in reverse, you only need to flip the t_span; the vector field can be either time-varying (i.e., it uses t in some way) or not. odeint detects backward integration as long as t_span is flipped, and the sign of t inside the vector field calls is reversed automatically.
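A minimal plain-Python sketch of the idea, assuming a hand-written fixed-step RK4 solver (`rk4_odeint`, hypothetical, not torchdyn's `odeint`) and a made-up time-varying 2D vector field. "Encoding" integrates over t_span from 0 to 1; "decoding" integrates the same field over the flipped t_span, which is the autoencoder-style round trip asked about above. The only thing that changes for the backward pass is the sign of the step h.

```python
import math


def rk4_odeint(f, x0, t_span):
    """Fixed-step RK4 on the grid t_span; returns the final state.

    The step h = t1 - t0 is negative when t_span is decreasing, so a
    flipped grid integrates the same ODE backward in time.
    """
    x = list(x0)
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        h = t1 - t0
        k1 = f(t0, x)
        k2 = f(t0 + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k1)])
        k3 = f(t0 + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k2)])
        k4 = f(t1, [xi + h * ki for xi, ki in zip(x, k3)])
        x = [xi + h / 6 * (a + 2 * b + 2 * c + d)
             for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
    return x


def vector_field(t, x):
    # Hypothetical time-varying 2D field: a rotation whose angular speed
    # depends on t, plus a bounded nonlinear drift.
    omega = 1.0 + t
    return [-omega * x[1] + math.tanh(x[0]),
            omega * x[0] + 0.5 * math.tanh(x[1])]


t_span = [i / 100 for i in range(101)]  # 0.0 -> 1.0
flipped_t_span = t_span[::-1]           # 1.0 -> 0.0

points = [[0.5, -1.0], [2.0, 0.3], [-1.2, 0.8]]
encoded = [rk4_odeint(vector_field, p, t_span) for p in points]          # "encoder"
decoded = [rk4_odeint(vector_field, z, flipped_t_span) for z in encoded]  # "decoder"

max_err = max(abs(a - b) for p, q in zip(points, decoded) for a, b in zip(p, q))
print(f"max reconstruction error: {max_err:.2e}")
```

Because decoding runs the same flow backwards, the reconstruction error comes only from the solver's discretization and floating-point rounding; nothing extra has to be learned for the inverse.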

StephenHogg commented 2 years ago

Thanks for clarifying, @Zymrael! When I do this with the toy example, a call like:

model(z_S_coords, flipped_t_span)

gives a t_eval object that looks like this:

tensor([-1.0000, -0.9899, -0.9798, -0.9697, -0.9596, -0.9495, -0.9394, -0.9293,
         -0.9192, -0.9091, -0.8990, -0.8889, -0.8788, -0.8687, -0.8586, -0.8485,
         -0.8384, -0.8283, -0.8182, -0.8081, -0.7980, -0.7879, -0.7778, -0.7677,
         -0.7576, -0.7475, -0.7374, -0.7273, -0.7172, -0.7071, -0.6970, -0.6869,
         -0.6768, -0.6667, -0.6566, -0.6465, -0.6364, -0.6263, -0.6162, -0.6061,
         -0.5960, -0.5859, -0.5758, -0.5657, -0.5556, -0.5455, -0.5354, -0.5253,
         -0.5152, -0.5051, -0.4949, -0.4848, -0.4747, -0.4646, -0.4545, -0.4444,
         -0.4343, -0.4242, -0.4141, -0.4040, -0.3939, -0.3838, -0.3737, -0.3636,
         -0.3535, -0.3434, -0.3333, -0.3232, -0.3131, -0.3030, -0.2929, -0.2828,
         -0.2727, -0.2626, -0.2525, -0.2424, -0.2323, -0.2222, -0.2121, -0.2020,
         -0.1919, -0.1818, -0.1717, -0.1616, -0.1515, -0.1414, -0.1313, -0.1212,
         -0.1111, -0.1010, -0.0909, -0.0808, -0.0707, -0.0606, -0.0505, -0.0404,
         -0.0303, -0.0202, -0.0101, -0.0000], grad_fn=<_ODEProblemFuncBackward>)

Is there anything to be worried about here?

Zymrael commented 2 years ago

Nothing to worry about: the minus sign is simply an internal indicator that the ODE is being solved backwards!
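The negative t_eval is consistent with one standard way of solving backwards: substitute s = -t and y(s) = x(-s), so that integrating dx/dt = f(t, x) from t = 1 down to t = 0 becomes integrating dy/ds = -f(-s, y) forwards from s = -1 up to s = 0. The plain-Python sketch below (hypothetical helper names; whether torchdyn does exactly this internally is an assumption) checks that both routes give the same answer. Note that s_span runs from -1.0 to -0.0, just like the tensor above.

```python
import math


def rk4_odeint(f, x0, t_span):
    """Scalar fixed-step RK4 on the grid t_span; returns the final state."""
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        h = t1 - t0  # negative when t_span is decreasing
        k1 = f(t0, x)
        k2 = f(t0 + h / 2, x + h / 2 * k1)
        k3 = f(t0 + h / 2, x + h / 2 * k2)
        k4 = f(t1, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


def f(t, x):
    # Hypothetical time-varying scalar field dx/dt = f(t, x).
    return math.sin(3 * t) - 0.5 * x


def negative_time_field(s, y):
    # With s = -t and y(s) = x(-s): dy/ds = -f(-s, y).
    return -f(-s, y)


x_at_t1 = 0.7                                       # state at t = 1
flipped_t_span = [1 - i / 100 for i in range(101)]  # t: 1.0 -> 0.0
s_span = [-t for t in flipped_t_span]               # s: -1.0 -> -0.0, increasing

direct = rk4_odeint(f, x_at_t1, flipped_t_span)                       # negative steps
via_negative_time = rk4_odeint(negative_time_field, x_at_t1, s_span)  # positive steps
print(direct, via_negative_time, s_span[0], s_span[-1])
```

Both calls evaluate f at the same times and states, so the solver only ever has to step forwards; the sign on the reported times is the only visible trace of the reversal.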

StephenHogg commented 2 years ago

@Zymrael great, thanks! I'm going to try something like the following, just to demonstrate it to myself:

Let me know if that seems sane. If so, I'd be happy to contribute it as a notebook. The ability to use ODEs like this seems a little undersold to me. Hope this isn't a bother.

Zymrael commented 2 years ago

Would you be using the adjoint method? Do post an update when you get it working; I'm curious about the results and about which use cases you see for this.

StephenHogg commented 2 years ago

I'm not necessarily attached to it. Do you see any gotchas with regard to adjoint sensitivity in this case?
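One gotcha often raised with adjoint sensitivity: the standard adjoint method does not store the forward trajectory but recomputes the state by integrating the ODE backward from the final time, so it relies on exactly the reversibility discussed here. For strongly contracting dynamics, that backward reconstruction can be badly inaccurate. The plain-Python sketch below (hypothetical helper and vector field, fixed-step RK4 rather than an adaptive solver) shows the effect on the state alone; it does not model torchdyn's adjoint implementation.

```python
def rk4_odeint(f, x0, t_span):
    """Scalar fixed-step RK4 on the grid t_span; returns the final state."""
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        h = t1 - t0  # negative when t_span is decreasing
        k1 = f(t0, x)
        k2 = f(t0 + h / 2, x + h / 2 * k1)
        k3 = f(t0 + h / 2, x + h / 2 * k2)
        k4 = f(t1, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


def contracting_field(k):
    # Hypothetical field dx/dt = -k (x - 1): all trajectories decay toward x = 1,
    # and the gap |x - 1| shrinks by roughly exp(-k) over one time unit.
    return lambda t, x: -k * (x - 1.0)


t_span = [i / 100 for i in range(101)]  # 0.0 -> 1.0
flipped_t_span = t_span[::-1]           # 1.0 -> 0.0
x0 = 0.0

reconstruction_error = {}
for k in (1.0, 40.0):
    field = contracting_field(k)
    x1 = rk4_odeint(field, x0, t_span)              # forward pass
    x0_rec = rk4_odeint(field, x1, flipped_t_span)  # state rebuilt backward
    reconstruction_error[k] = abs(x0_rec - x0)
    print(f"k={k}: |x0_rec - x0| = {reconstruction_error[k]:.3g}")
```

For k = 40 the remaining gap exp(-40) is far below float64 resolution near 1.0, so the forward pass effectively erases x0 and the backward pass cannot recover it; for k = 1 the round trip is accurate. Checkpointing the forward trajectory, or plain backpropagation through the solver, avoids relying on this reconstruction.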

StephenHogg commented 2 years ago

Bumping this question. Hope it's not a bother!

StephenHogg commented 2 years ago

Closing this, I managed to get a reversible model going. Thanks for the help!