jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.57k stars 2.81k forks

Support reverse time integration with ODE. #7269

Open: ibraheem-moosa opened this issue 3 years ago

ibraheem-moosa commented 3 years ago

We are trying to implement the Neural ODE paper as part of the JAX/Flax Community Week at Hugging Face. While implementing the Continuous Normalizing Flow part of the paper, we realized we need to run the ODE solver backward in time.

The current odeint in jax.experimental.ode does not support this, because its time values must be strictly increasing. In contrast, scipy and torchdiffeq both support integrating over decreasing time values.

It would be great if jax.experimental.ode supported this feature.

mattjj commented 3 years ago

Good idea!

You might be able to use the same trick the vjp rule uses: to integrate xdot = f(x, t) from t=0 to -1, you instead integrate xdot = -f(x, -t) from t=0 to 1. I think that's essentially what scipy.integrate.odeint does, which is why its time values must be monotonically increasing or decreasing.
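Written out as a change of variables (a short sketch, not taken from the JAX source), with x solving xdot = f(x, t) from t=0 to t=-1 and y defined as the time-flipped trajectory:

```latex
\begin{aligned}
y(s) &= x(-s), \qquad s \in [0, 1], \\
\frac{dy}{ds}(s) &= -\dot{x}(-s) = -f\bigl(x(-s), -s\bigr) = -f\bigl(y(s), -s\bigr), \\
y(0) &= x(0), \qquad y(1) = x(-1).
\end{aligned}
```

So integrating y forward over s from 0 to 1 with the stock solver recovers x at t=-1. More generally, a decreasing grid t_0 > t_1 > ... maps to the increasing grid s_i = -t_i, with x(t_i) = y(s_i).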

mattjj commented 3 years ago

More concretely, I think odeint(f, x0, t), where t has monotonically decreasing values, would produce the same values as odeint(lambda x, t: -f(x, -t), x0, -t), where in this call the time values are monotonically increasing. I didn't work out the change-of-variable argument, so that claim may be buggy, but something like that should work.
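A minimal sketch of this workaround as a reusable helper, assuming t is a 1-D array of strictly decreasing times. odeint_backward is a hypothetical name, not part of jax.experimental.ode; it only wraps the existing odeint(func, y0, t, *args, ...) call:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def odeint_backward(f, x0, t, *args, **kwargs):
    """Hypothetical helper: solve dx/dt = f(x, t, *args) on a DECREASING grid t.

    Substitutes y(s) = x(-s), which satisfies dy/ds = -f(y, -s, *args), and
    integrates y over the increasing grid s = -t with the stock odeint.
    Row i of the result is y(-t[i]) = x(t[i]).
    """
    t = jnp.asarray(t)

    def reversed_f(y, s, *a):
        # Time-reversed dynamics: negate both the output and the time argument.
        return -f(y, -s, *a)

    return odeint(reversed_f, x0, -t, *args, **kwargs)


# Example: dx/dt = x + t, integrated from x(1) = 1 back to t = 0.
ts = jnp.array([1.0, 0.5, 0.0])
xs = odeint_backward(lambda x, t: x + t, jnp.array([1.0]), ts)
print(xs[:, 0])
```

For this example the exact solution is x(t) = 3e^(t-1) - t - 1, so the output can be checked directly. Because the helper only negates the grid and the dynamics, gradients still flow through odeint's own VJP rule.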

ibraheem-moosa commented 3 years ago

Thanks @mattjj.

I tried a few dx/dt functions: x+t, x-t, x/t, and 1/(x+t).

Using odeint(lambda x, t: -f(x, -t), x0, -t) seems to work.
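A sketch of that kind of check, assuming SciPy is available and using scipy.integrate.odeint, which accepts a decreasing time grid directly, as the reference. The grid runs from t=2 down to t=1 so that x/t stays away from its singularity at t=0; test_fns, x_start and max_abs_diff are illustrative names:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.ode import odeint
from scipy.integrate import odeint as scipy_odeint

# The right-hand sides tried above, each written as f(x, t).
test_fns = {
    "x+t": lambda x, t: x + t,
    "x-t": lambda x, t: x - t,
    "x/t": lambda x, t: x / t,
    "1/(x+t)": lambda x, t: 1.0 / (x + t),
}

# Decreasing grid from t=2 down to t=1; avoids the x/t singularity at t=0.
t_dec = np.linspace(2.0, 1.0, 5)
x_start = np.array([1.0])  # value of x at t=2


def max_abs_diff(f):
    # Reference: SciPy integrates over the decreasing grid directly.
    ref = scipy_odeint(f, x_start, t_dec)
    # JAX: integrate -f(x, -s) over the increasing grid s = -t.
    out = odeint(lambda x, s: -f(x, -s), jnp.asarray(x_start), jnp.asarray(-t_dec))
    return float(np.max(np.abs(np.asarray(out) - ref)))


for name, f in test_fns.items():
    print(f"{name}: max |jax - scipy| = {max_abs_diff(f):.2e}")
```

With JAX's default float32, agreement is limited to roughly single precision, so a tolerance around 1e-4 is a reasonable pass criterion.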

jatentaki commented 3 years ago

@ibraheem-moosa, is your code available somewhere? I was recently trying to implement CNFs as well and ran into #7142. Are you running into that issue too?

ibraheem-moosa commented 3 years ago

@jatentaki You can find my code here and here. It is very much a WIP, so I'm not sure how helpful it will be; I have been struggling with it for the past few days.

sw32-seo commented 3 years ago

@jatentaki Does Flax have the same interaction issue with odeint as Haiku does? I think I have a similar problem here, but in my code odeint is called outside of the model. I use -f(x, -t) with t = [10, 0], as @mattjj suggested. I compared the trace at line 85 to the original code and confirmed that they match. However, odeint behaves very differently from the original.
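For the CNF case specifically, a minimal sketch of computing log p(x) by integrating the augmented state (z, a) backward with the same trick. Assumptions, none of them from the linked code: a standard-normal base density at t0, dynamics f(z, t) on a 1-D array z, and an exact Jacobian trace from jax.jacfwd, which is only practical in low dimensions (FFJORD-style models use a stochastic Hutchinson trace estimator instead). cnf_log_prob is a hypothetical name:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.scipy.stats import norm


def cnf_log_prob(f, x, t0=0.0, t1=1.0):
    """Hypothetical helper: log p(x) for a CNF with dz/dt = f(z, t), N(0, I) at t0.

    Augmented state (z, a) with da/dt = tr(df/dz) and a(t1) = 0, integrated
    backward from t1 (data) to t0 (base) via the time-reversal trick.
    """
    def aug_dynamics(state, t):
        z, _ = state
        # Exact Jacobian trace; cost grows quickly with the dimension of z.
        jac = jax.jacfwd(f)(z, t)
        return f(z, t), jnp.trace(jac)

    def reversed_dynamics(state, s):
        # y(s) = state(-s) satisfies dy/ds = -aug_dynamics(y, -s).
        dz, da = aug_dynamics(state, -s)
        return -dz, -da

    s = jnp.array([-t1, -t0])  # increasing grid, since t1 > t0
    zs, accum = odeint(reversed_dynamics, (x, jnp.zeros(())), s)
    z0, a0 = zs[-1], accum[-1]
    # a(t0) = -integral_{t0}^{t1} tr(df/dz) dt, so
    # log p1(x) = log p0(z0) - integral = log p0(z0) + a(t0).
    return jnp.sum(norm.logpdf(z0)) + a0
```

A linear f(z, t) = A z with diagonal A makes a handy sanity check: for t0 = 0 and t1 = 1, z0 = exp(-A) x elementwise and log p(x) = log N(z0; 0, I) - tr(A).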

mandelbrot-walker commented 3 years ago

@jatentaki You can find a working CNF by @sw32-seo here.