titu1994 / tfdiffeq

Tensorflow implementation of Ordinary Differential Equation Solvers with full GPU support
MIT License
218 stars 52 forks

Enable adjoint method #3

Closed eozd closed 4 years ago

eozd commented 4 years ago

Fixes #2

This PR proposes a way to make the adjoint method work with the TensorFlow custom_gradient interface. The main changes are in tfdiffeq/adjoint.py and can be summarized as:

  1. Don't pass the ODE parameters to the OdeintAdjointMethod function. Instead, we recover these parameters from the variables keyword argument of the grad function.
  2. tf.custom_gradient requires the grad function to return two sets of gradients as a pair: i. the gradients with respect to the inputs of OdeintAdjointMethod (x0 and t in our case), and ii. the gradients with respect to the parameters, which are the tf.Variable objects stored in our ODE object.
  3. To prevent collecting all the tf.Variable objects created by the adams optimizer, we mark them as non-trainable. However, there still seems
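The variables mechanism described in points 1 and 2 can be sketched with a toy function (a minimal sketch assuming TensorFlow 2.x; the names scale and forward are illustrative, not tfdiffeq code):

```python
import tensorflow as tf

# Illustrative variable captured by the decorated function, standing in for
# the tf.Variable parameters stored in the ODE object.
scale = tf.Variable(3.0, name="scale")

@tf.custom_gradient
def forward(x):
    y = scale * x

    def grad(dy, variables=None):
        # i. Gradients w.r.t. the inputs of forward (here just x) ...
        input_grads = [dy * scale]
        # ii. ... and w.r.t. the captured tf.Variables, recovered through
        # the `variables` keyword argument rather than being passed in.
        variable_grads = [dy * x]
        return input_grads, variable_grads

    return y, grad

with tf.GradientTape() as tape:
    x = tf.constant(2.0)
    tape.watch(x)
    y = forward(x)

dx, dscale = tape.gradient(y, [x, scale])
print(float(dx), float(dscale))  # 3.0 2.0
```

OdeintAdjointMethod follows the same two-tuple contract, with x0 and t playing the role of the inputs.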

Caveats: I wasn't able to make the method work with the adams solver (so the adams-adjoint test is not enabled either). The problem is that the elements of the tuple returned from the augmented_dynamics function have different shapes, which causes problems at adams.py:138
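For reference, the gradient computation that augmented_dynamics implements can be sketched numerically. This is a plain-Python Euler integrator for a scalar linear ODE, not tfdiffeq's code; the function name and the toy problem are illustrative:

```python
def adjoint_grads(z0, theta, T, n=20000):
    """Adjoint-method gradients for dz/dt = theta * z with loss L = z(T)."""
    dt = T / n
    # Forward pass: store the trajectory. (The real adjoint method instead
    # re-solves z backward in time for O(1) memory; storing keeps this short.)
    z = [z0]
    for _ in range(n):
        z.append(z[-1] + dt * theta * z[-1])
    # Backward pass over the augmented state (a, dL/dtheta), starting from
    # a(T) = dL/dz(T) = 1 for the loss L = z(T).
    a = 1.0
    dtheta = 0.0
    for i in range(n, 0, -1):
        dtheta += dt * a * z[i]  # dL/dtheta = integral of a * df/dtheta; df/dtheta = z
        a += dt * a * theta      # Euler step backward in time for da/dt = -a * df/dz
    return z[-1], a, dtheta      # z(T), dL/dz0, dL/dtheta

zT, dz0, dth = adjoint_grads(z0=1.5, theta=0.8, T=1.0)
# Analytic values: z(T) = z0*e^{theta*T}, dL/dz0 = e^{theta*T},
# dL/dtheta = z0*T*e^{theta*T}; the numbers above match them up to
# Euler discretization error.
print(zT, dz0, dth)
```

In the augmented system the state z, the adjoint a, and the per-parameter accumulators all evolve together, which is why their differing shapes trip up a solver that assumes a uniform state tensor.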

titu1994 commented 4 years ago

This solution is ingenious! I completely missed that I could recover the parameters from variables. Thank you so very much for your help with this.

As for the Adams-Bashforth implementation, there seem to be certain issues with the current version, which I am closely following in the PyTorch discussions.

As the dopri tests pass, I will be glad to merge this PR upon your go-ahead.

eozd commented 4 years ago

If the idea looks good to you, then by all means please go ahead. By the way, I would also like to thank you for the original implementation. As I will be working with tfdiffeq in the immediate future, I will make sure to report any issues I find with the changes I introduced.

titu1994 commented 4 years ago

Merged! I do advise wrapping the callable portion of the ODE function, call(u, t), inside a tf.function block to see some noticeable speedups. There are some performance bottlenecks I'd like to look into, and I hope to eventually implement the universal ordinary differential equations paper, if I ever get around to parsing the Julia codebase
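The tf.function tip above can be sketched as follows (a hedged example assuming TensorFlow 2.x; the LinearODE class and its argument order are illustrative, not a tfdiffeq requirement):

```python
import tensorflow as tf

class LinearODE(tf.Module):
    """Toy ODE dynamics du/dt = A @ u, standing in for a user's ODE object."""

    def __init__(self):
        self.A = tf.Variable([[-0.1, 2.0], [-2.0, -0.1]], dtype=tf.float64)

    @tf.function  # compile the dynamics to a graph instead of re-running eagerly
    def __call__(self, t, u):
        return tf.linalg.matvec(self.A, u)

ode = LinearODE()
u0 = tf.constant([1.0, 0.0], dtype=tf.float64)
du = ode(tf.constant(0.0, dtype=tf.float64), u0)
print(du.numpy())  # derivative of u at t = 0
```

Since the solver evaluates the dynamics at every internal step, tracing the callable once with tf.function and reusing the compiled graph is where the speedup comes from.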