martenlienen / torchode

A parallel ODE solver for PyTorch
https://torchode.readthedocs.io
MIT License

JAX implementation? Or calling from JAX code? #42

Closed: jucor closed this 2 months ago

jucor commented 2 months ago

Apologies if this is a stupid question, but do you have a JAX implementation of the "heterogeneous step within a batch" approach? If not, do you know how costly (performance-wise) calling torchode from within JAX code is likely to be, and which gotchas to look out for?

My use case, also described in https://github.com/patrick-kidger/diffrax/issues/500 : I need to evaluate a one-dimensional ODE on a fixed time grid for 100,000 different parameter values, inside a JAX-based Sequential Monte Carlo sampler with an ODE-based likelihood. Each parameter value is a Monte Carlo sample. 98% of my sampling time is spent calling the ODE solver for all 100,000 samples, so batching them in a GPU-based solver would be great. But of course the stiffness varies widely with the parameter values, so naive batching is likely to bite me.

Thanks for any hint to make the best use of your work!
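To make the "heterogeneous step within a batch" idea concrete, here is a minimal NumPy sketch. It is neither torchode's nor diffrax's code; it only shows the principle that each sample keeps its own time, step size and accept/reject decision, so one stiff sample does not force tiny steps on all the others. It uses a simple Heun-Euler embedded pair (orders 2/1). The function name `solve_batched_adaptive`, the test equation dy/dt = -k*y and the parameter values `ks` are hypothetical choices for illustration.

```python
import numpy as np


def solve_batched_adaptive(f, y0, params, ts, rtol=1e-6, atol=1e-9, max_iters=200_000):
    """Integrate dy/dt = f(t, y, p) for a batch of scalar ODEs (hypothetical helper).

    Every sample keeps its own time, step size and accept/reject decision
    (Heun-Euler embedded pair, orders 2/1). Returns y on the grid `ts`,
    shape (batch, len(ts)).
    """
    ts = np.asarray(ts, dtype=float)
    y = np.array(y0, dtype=float)
    p = np.asarray(params, dtype=float)
    n = y.shape[0]
    t = np.full(n, ts[0])
    dt = np.full(n, (ts[-1] - ts[0]) * 1e-3)   # per-sample step sizes
    out = np.empty((n, len(ts)))
    out[:, 0] = y
    idx = np.ones(n, dtype=int)                 # next grid point each sample must reach

    for _ in range(max_iters):
        active = idx < len(ts)
        if not active.any():
            return out
        target = ts[np.minimum(idx, len(ts) - 1)]
        clipped = dt >= target - t              # this step would reach the grid point
        h = np.where(clipped, target - t, dt)

        k1 = f(t, y, p)
        k2 = f(t + h, y + h * k1, p)
        y_new = y + 0.5 * h * (k1 + k2)         # Heun (2nd-order) solution
        err = 0.5 * h * np.abs(k2 - k1)         # |Heun - Euler| error estimate
        scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new))
        ratio = err / scale

        # Accept or reject each sample independently.
        accept = active & (ratio <= 1.0)
        t = np.where(accept, np.where(clipped, target, t + h), t)
        y = np.where(accept, y_new, y)

        hit = accept & clipped                  # samples that just landed on a grid point
        out[hit, idx[hit]] = y[hit]
        idx[hit] += 1

        # Per-sample step-size update (error of the order-1 estimate scales like h^2).
        factor = np.clip(0.9 * np.maximum(ratio, 1e-12) ** -0.5, 0.2, 5.0)
        dt = np.where(active, h * factor, dt)
    raise RuntimeError("max_iters exceeded before all samples finished")


# Usage: four hypothetical parameter values with very different stiffness.
ks = np.array([0.1, 1.0, 10.0, 1000.0])
ts = np.linspace(0.0, 1.0, 11)
ys = solve_batched_adaptive(lambda t, y, k: -k * y, np.ones(4), ks, ts)
```

A naive batched solver would share a single step size across the batch, so the k = 1000 sample would dictate the step for every sample. Here the loop still runs until the slowest sample finishes, but finished samples stop changing and the easy samples take large steps.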

martenlienen commented 2 months ago

Hey, I think diffrax should be capable of doing exactly what you want, as lockwo explained in the issue: you would set up the solver and then vmap it over all of your ODE problems.

jucor commented 2 months ago

Amazing, thanks a lot! Wowser, this is terrific. I can't believe how simple it is, yet it works :)