jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BDF solver for stiff ODEs #3686

Open skrsna opened 4 years ago

skrsna commented 4 years ago

Hi, I have a working version of a BDF solver in JAX with JIT support, which I ported from tensorflow_probability. I tested it on stiff chemical kinetics problems against the VODE wrapper in SciPy, and the results are promising. I would like to open a PR to add this to the official JAX repo, but I have some questions.

  1. Does the code have to follow a functional programming model? Right now it follows tensorflow_probability's model.

  2. Where should I add this code, e.g. jax/experimental?

  3. Is an adjoint gradient method mandatory to open a PR? If so, which model should I use: tensorflow_probability's or JAX's custom_vjp/jvp?

The current implementation is here. Feedback, suggestions and comments are welcome. Finally, thanks for such a great framework.

shoyer commented 4 years ago

Hi @skrsna, this sounds very exciting!

  1. The most important part is that the external interface of the code looks like the rest of JAX. Ideally we would have something like a method argument on jax.experimental.ode.odeint that allows selecting different integrators, e.g., stiff vs. non-stiff solvers (a rough sketch of such a dispatch follows at the end of this list).

  2. Yes, this would be a natural fit alongside the existing ODE solver in jax.experimental

  3. We already have the adjoint gradient implemented in JAX. The implementation is actually entirely agnostic to the details of how the ODE is solved (notice that it calls the top-level odeint), so I think the BDF solver could use it just as easily: https://github.com/google/jax/blob/b813ae3aff3b3f99367956a653007703bbfe3703/jax/experimental/ode.py#L255
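
As a rough illustration of the method-argument idea in item 1, here is a sketch only, not the actual jax.experimental.ode API; bdf_odeint is a hypothetical placeholder name for the proposed stiff solver:

```python
from jax.experimental import ode

def odeint_with_method(func, y0, t, *args, method="dopri5",
                       rtol=1.4e-8, atol=1.4e-8):
    # Hypothetical wrapper that dispatches on a `method` argument.
    if method == "dopri5":
        # existing non-stiff adaptive Runge-Kutta solver
        return ode.odeint(func, y0, t, *args, rtol=rtol, atol=atol)
    elif method == "bdf":
        # hypothetical stiff BDF solver (placeholder name, not in JAX)
        return bdf_odeint(func, y0, t, *args, rtol=rtol, atol=atol)
    raise ValueError(f"unknown method {method!r}")
```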

skrsna commented 4 years ago

Hi @shoyer, thanks for your feedback. I'll open a PR once I get everything working in a functional programming model. I'll close this issue.

skrsna commented 4 years ago

Hi @shoyer, sorry for closing and reopening the issue. I'm trying to convert the BDF solver to a functional programming model and also add some extra features from tfp, e.g. lazy Jacobian evaluation and LU factorization. I'm using lax.cond, but you mentioned here that lax.cond is slow on accelerators, and I noticed that it is in fact slow after I added it. Right now my implementation uses named tuples registered as JAX pytrees to hold all internal solver data and passes them as the operand to lax.cond, so I don't think I can use jnp.where unless I come up with a very clever hack. Is there an alternative to lax.cond that is jittable?

This feature is not essential, but it would be good to have, because the current implementation evaluates the Jacobian and the LU factorization at almost every step of the integration, whereas most CPU-based stiff ODE solvers like VODE reuse Jacobians and actively try to avoid re-evaluating them. Currently, for a small chemical system with ~60 ODEs, the JAX implementation is ~200 times slower with JIT than the VODE wrapper from SciPy.
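
For reference, a minimal sketch of the factorization-reuse idea described above (not taken from the linked implementation; h, gamma, and jac are assumed to come from the solver state):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

def factor_newton_matrix(jac, h, gamma):
    # Factor the Newton iteration matrix (I - h*gamma*J) once, when the
    # Jacobian is (re)evaluated.
    n = jac.shape[0]
    return lu_factor(jnp.eye(n) - h * gamma * jac)

def newton_solve(lu_and_piv, residual):
    # Reuse the cached factorization across Newton iterations and steps
    # instead of refactorizing every time.
    return lu_solve(lu_and_piv, -residual)
```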

shoyer commented 4 years ago

You can use tree_multimap(partial(jnp.where, condition), x, y) as a substitute for lax.cond to handle pytrees.
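
For concreteness, a minimal sketch of this substitution (the helper name tree_where is an assumption; the thread uses the older tree_multimap name, while current jax.tree_util.tree_map accepts multiple trees):

```python
from functools import partial
import jax.numpy as jnp
from jax import tree_util

def tree_where(condition, x, y):
    # Select leaf-by-leaf between two pytrees with the same structure
    # (e.g. named tuples registered as pytrees).
    return tree_util.tree_map(partial(jnp.where, condition), x, y)

# Roughly: instead of lax.cond(pred, true_fn, false_fn, state), compute
# both branches and select:
#     new_state = tree_where(pred, true_fn(state), false_fn(state))
# Note that, unlike lax.cond, both branches are evaluated.
```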

skrsna commented 3 years ago

Hi @shoyer,

I'm trying to write an adjoint gradient method for the BDF solver in #3781, but it looks like I've hit a pretty major blocker. Specifically, the BDF solver I wrote follows the ode_fn(t, y) convention, whereas the adjoint in jax/experimental expects ode_fn(y, t), to the best of my knowledge. I get size-mismatch errors if I use the ravel_first_arg_ function, and jax.linear_util is not documented for public use. I'd like to write something similar but for the second argument. I'm not familiar with any of the functions/transformations in linear_util, so any help is greatly appreciated. I can send you reproducible code, but it will involve installing JAX from that open PR.
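
One possible workaround, sketched here only as a hypothetical helper (flip_time_and_state is not part of JAX): adapt the argument order first, then reuse the existing first-argument raveling machinery.

```python
def flip_time_and_state(ode_fn):
    # Wrap an ode_fn(t, y, *args) (the BDF solver's convention) so it
    # matches the ode_fn(y, t, *args) convention that the adjoint in
    # jax/experimental/ode.py expects.
    def wrapped(y, t, *args):
        return ode_fn(t, y, *args)
    return wrapped
```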