skrsna opened this issue 4 years ago
Hi @skrsna, this sounds very exciting!
The most important part is that the external model for the code looks like the rest of JAX. Ideally we would have something like a `method` argument on `jax.experimental.ode.odeint` that allows for selecting different integrators, e.g., stiff vs. non-stiff solvers.
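A minimal sketch of what such an interface might look like. Note that `method` is *not* a real `odeint` parameter today, and `odeint_with_method` is a hypothetical wrapper; "bdf" stands in for the solver proposed in this issue:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def odeint_with_method(func, y0, t, *args, method="dopri5"):
    """Hypothetical dispatch on a `method` argument (sketch only)."""
    if method == "bdf":
        # The stiff BDF solver proposed in this issue would go here.
        raise NotImplementedError("stiff BDF solver proposed in this issue")
    # Fall back to the existing non-stiff Dormand-Prince integrator.
    return odeint(func, y0, t, *args)


# Integrate dy/dt = -y from t=0 to t=1 with the default method.
ys = odeint_with_method(lambda y, t: -y, jnp.array([1.0]), jnp.linspace(0.0, 1.0, 5))
```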
Yes, this would be a natural fit alongside the existing ODE solver in `jax.experimental`.
We already have the adjoint gradient implemented in JAX. The implementation is entirely agnostic to the details of how the ODE is solved (notice that it calls the top-level `odeint`), so I think the BDF solver could use it just as easily:
https://github.com/google/jax/blob/b813ae3aff3b3f99367956a653007703bbfe3703/jax/experimental/ode.py#L255
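That the adjoint is agnostic to the solver is visible from the outside: gradients flow through the top-level `odeint` call directly. A minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def f(y, t):
    return -y  # dy/dt = -y, so y(t) = y0 * exp(-t)


def final_value(y0):
    ts = jnp.array([0.0, 1.0])
    ys = odeint(f, y0, ts)
    return ys[-1]


# The adjoint (a custom VJP on odeint) gives d y(1) / d y0 = exp(-1).
g = jax.grad(final_value)(1.0)
```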
Hi @shoyer, thanks for your feedback. I'll open a PR once I get everything working in a functional programming model. I'll close this issue.
Hi @shoyer, sorry for closing and reopening the issue. I'm converting the BDF solver to a functional programming model and also adding some extra features from tfp, e.g. lazy Jacobian evaluation and LU factorization reuse. I'm using `lax.cond`, but you mentioned here that `lax.cond` is slow on accelerators, and I noticed it is in fact slow after I added it. Right now, my implementation uses named tuples registered as JAX pytrees to hold all internal solver state and passes them as the operand to `lax.cond`, so I don't think I can use `jnp.where` unless I come up with a very clever hack. Is there a jittable alternative to `lax.cond`? This feature is not essential but would be good to have, because the current implementation evaluates the Jacobian and LU factorization at almost every integration step, whereas most CPU-based stiff ODE solvers like VODE reuse Jacobians and actively avoid re-evaluation. Currently, for a small chemical system with ~60 ODEs, the JAX implementation is ~200 times slower with JIT than the VODE wrapper from SciPy.
You can use `tree_multimap(partial(jnp.where, condition), x, y)` as a substitute for `lax.cond` to handle pytrees.
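For reference, in newer JAX versions `tree_multimap` has been folded into `jax.tree_util.tree_map`, which accepts multiple trees. A minimal sketch of the suggested pattern, using toy state dicts with hypothetical field names (both branches are computed, but without `lax.cond`'s control-flow overhead on accelerators):

```python
from functools import partial

import jax.numpy as jnp
from jax import tree_util

# Two pytrees with the same structure, standing in for two solver states.
x = {"y": jnp.array([1.0, 2.0]), "step": jnp.array(0.1)}
y = {"y": jnp.array([3.0, 4.0]), "step": jnp.array(0.2)}

condition = jnp.array(True)  # a traced boolean from the solver

# Select elementwise between the two states, leaf by leaf.
selected = tree_util.tree_map(partial(jnp.where, condition), x, y)
```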
On Sat, Jul 11, 2020 at 9:30 AM Krishna Sirumalla notifications@github.com wrote:
Reopened #3686 https://github.com/google/jax/issues/3686.
Hi @shoyer,

I'm trying to write an adjoint gradient method for the BDF solver #3781, but it looks like I hit a pretty major blocker. Specifically, the BDF solver I wrote follows the `ode_fn(t, y)` convention, whereas the adjoint in `jax/experimental` has the form `ode_fn(y, t)`, to the best of my knowledge. I get size mismatch errors if I use the `ravel_first_arg_` function, and `jax.linear_util` is not documented for public use. I'd like to write something similar but for the second argument. I'm not familiar with any of the functions/transformations in `linear_util`, so any help is greatly appreciated. I can send you reproducible code, but it will involve installing JAX from that open PR.
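One low-tech workaround, assuming the only incompatibility is the argument order: wrap the `(t, y)`-convention right-hand side so it matches the `(y, t)` convention the adjoint expects. `flip_args` and `bdf_rhs` below are hypothetical names, not part of JAX or the PR:

```python
def flip_args(func):
    """Adapt an ode_fn(t, y, *args) right-hand side to the (y, t) convention."""
    def flipped(y, t, *args):
        return func(t, y, *args)
    return flipped


# Hypothetical (t, y)-convention right-hand side, as in the BDF solver:
def bdf_rhs(t, y):
    return -t * y


ode_fn = flip_args(bdf_rhs)  # now callable as ode_fn(y, t)
```

This avoids touching `linear_util` at all: the flipped function can be passed straight to code that expects the `(y, t)` order.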
Hi, I have a working version of a BDF solver in JAX with JIT support, which I ported from tensorflow_probability. I tested it against stiff chemical kinetics problems using the VODE wrapper in SciPy, and the results are promising. I would like to open a PR to add this to the official JAX repo, but I have some questions.
1. Does the code have to follow a functional programming model? Right now I used tensorflow_probability's model.
2. Where should I add this code, e.g. `jax/experimental`?
3. Is an adjoint gradient method mandatory to open a PR? If so, which model should I use: tensorflow_probability's, or JAX's `custom_vjp`/`custom_jvp`?

The current implementation is here. Feedback, suggestions and comments are welcome. Finally, thanks for such a great framework.
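For context, the `jax.custom_vjp` pattern mentioned in question 3 looks like this in miniature. This is a toy function, not the BDF solver; the existing `odeint` adjoint is built with the same mechanism:

```python
import jax


@jax.custom_vjp
def square(x):
    return x * x


def square_fwd(x):
    # Forward pass: return the primal output plus residuals for the backward pass.
    return square(x), x


def square_bwd(x, g):
    # Backward pass: d(x^2)/dx = 2x, scaled by the incoming cotangent g.
    return (2.0 * x * g,)


square.defvjp(square_fwd, square_bwd)

grad_val = jax.grad(square)(3.0)  # uses the hand-written rule: 2 * 3 = 6
```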