Open · ibraheem-moosa opened this issue 3 years ago
Good idea!
You might be able to use the same trick used in the vjp rule, where to integrate xdot = f(x, t) from t=0 to -1, you instead integrate xdot = -f(x, -t) from t=0 to 1. I think that's essentially what scipy.odeint does, which is why the time values must be monotonically increasing or decreasing.
More concretely, I think odeint(f, x0, t), where t has monotonically decreasing values, would produce the same values as odeint(lambda x, t: -f(x, -t), x0, -t), where in that call the time values are monotonically increasing. I didn't work out the change-of-variable argument, so that claim may be buggy, but something like that should work.
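To sanity-check this equivalence, here is a small script using SciPy's odeint, which natively accepts monotonically decreasing time values, so direct backward integration can be compared against the time-reversed formulation. The right-hand side x - t is an illustrative choice, not from the thread; the same change of variables should carry over to jax.experimental.ode.odeint.

```python
import numpy as np
from scipy.integrate import odeint

# Illustrative right-hand side (an assumption, not fixed by the thread).
def f(x, t):
    return x - t

x0 = 1.0
t = np.linspace(1.0, 0.0, 11)  # monotonically decreasing times

# SciPy's odeint accepts decreasing t directly.
direct = odeint(f, x0, t)

# Time-reversed trick: integrate xdot = -f(x, -t) over the increasing times -t.
trick = odeint(lambda x, s: -f(x, -s), x0, -t)

print(np.allclose(direct, trick, atol=1e-6))  # prints True
```

The substitution behind this is s = -t with y(s) = x(-s), which gives dy/ds = -f(y, -s), so the solution values at the flipped time points coincide.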
Thanks @mattjj.
I tried a few dxdt functions, like x + t, x - t, x / t, and 1/(x + t). Using odeint(lambda x, t: -f(x, -t), x0, -t) seems to work.
@ibraheem-moosa is your code available somewhere? I was recently trying to implement CNFs as well and ran into #7142. Are you seeing any issues there?
@jatentaki Does Flax have the same interaction issue with odeint that Haiku does? I think I have a similar problem here, although in my case odeint is called outside the model. I use -f(x, -t) with t = [10, 0] as @mattjj suggested. I then compared the trace at line 85 to the original code and confirmed that they match. However, odeint still gives very different behavior from the original.
@jatentaki you can find a working CNF here by @sw32-seo .
We are trying to implement the Neural ODE paper as part of the JAX/Flax Community Week at Hugging Face. While implementing the Continuous Normalizing Flow part of the paper, we realized we have to run the ODE solver backward in time. This feature is not supported by the current odeint in jax.experimental.ode, although it is supported by both SciPy and torchdiffeq. It would be great if jax.experimental.ode supported it as well.
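In the meantime, a thin wrapper can paper over the restriction. The sketch below uses a hypothetical helper name, odeint_reversed, and is demonstrated with SciPy's odeint so it runs as-is; the same wrapper should apply verbatim to jax.experimental.ode.odeint, which requires increasing time values.

```python
import numpy as np
from scipy.integrate import odeint  # stand-in; jax.experimental.ode.odeint has the same call shape

def odeint_reversed(f, x0, t):
    """Integrate dx/dt = f(x, t) over monotonically decreasing times t
    using a solver that only accepts increasing times (hypothetical helper).
    """
    t = np.asarray(t)
    # Change of variables s = -t: y(s) = x(-s) satisfies dy/ds = -f(y, -s),
    # and -t is monotonically increasing when t is decreasing.
    return odeint(lambda y, s: -f(y, -s), x0, -t)

# Usage: integrate dx/dt = x - t from t = 10 back to t = 0,
# echoing the t = [10, 0] setup mentioned above.
f = lambda x, t: x - t
t = np.linspace(10.0, 0.0, 101)
sol = odeint_reversed(f, 1.0, t)
print(np.allclose(sol, odeint(f, 1.0, t), atol=1e-6))  # prints True
```

With JAX, the lambda would simply close over any extra arguments to f, and the wrapper stays differentiable since it is just a reparameterized odeint call.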