pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

A tutorial for integrating ODEs multiple times #1450

Closed by A2P2 3 months ago

A2P2 commented 2 years ago

Hi, I'm wondering whether a tutorial on integrating ODEs multiple times, each time with different initial conditions, would be a welcome contribution. Alternatively, the current Lotka-Volterra example could be extended.

I learned how to use jax.vmap for this purpose from these forum threads: https://forum.pyro.ai/t/parallelization-plate-and-odes/2763 https://forum.pyro.ai/t/making-a-for-loop-more-efficient/3437/22
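For readers following the thread, the jax.vmap idea mentioned above can be sketched roughly as follows. This is a minimal illustration, not the code from the linked forum posts or from the later gist: the parameter values, initial conditions, and the helper names `solve_one` and `solve_many` are made up for the example, and the solver tolerances are only a plausible choice.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dz_dt(z, t, theta):
    """Lotka-Volterra right-hand side; z = (prey, predator)."""
    u, v = z[0], z[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


def solve_one(z0, ts, theta):
    # Integrate a single trajectory starting from z0 (hypothetical helper).
    return odeint(dz_dt, z0, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)


# Map over the leading axis of the initial conditions only;
# the time grid and the parameters are shared by all trajectories.
solve_many = jax.vmap(solve_one, in_axes=(0, None, None))

# Illustrative values, not taken from the thread.
z0_batch = jnp.array([[10.0, 5.0], [20.0, 8.0], [5.0, 2.0]])
ts = jnp.linspace(0.0, 10.0, 50)
theta = jnp.array([1.0, 0.1, 1.5, 0.075])

trajectories = solve_many(z0_batch, ts, theta)  # shape: (3, 50, 2)
```

The design choice here is that only the initial condition is batched; if each trajectory also needed its own time grid or parameters, the corresponding `in_axes` entry would change from `None` to `0`.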

Is it of interest? Let me know what you think.

fehiepsi commented 2 years ago

The contribution would be awesome! Currently, we only have an example for Lotka-Volterra. A tutorial with motivations for your usage case would be super helpful for users.

zz100chan commented 1 year ago

@A2P2 hi! I think this is a good contribution, and I am running into somewhat similar issues now. I wonder if you have a tutorial already? :)

A2P2 commented 1 year ago

@zz100chan My apologies for the late reply and for generally ignoring the thread.

I do have a draft written; see the attached zip file: lotka_volterra_multiple.zip. I tried to follow the style of the current Lotka-Volterra tutorial. Let me know if it works for you and if anything is unclear. Once again, it's still in a draft state at the moment. [image: multiple_lotka_volterra]

I've tried to implement two things: 1) integrating with different initial conditions, and 2) allowing missing values in the data. I use regularly spaced time arrays, but that is not a requirement.
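One common way to handle the missing values mentioned above is to mask them out of the likelihood. The sketch below shows the idea in plain JAX, assuming missing observations are encoded as NaN and the noise is Gaussian; it is not necessarily how the attached draft does it, and `masked_log_likelihood` is a hypothetical name. In a NumPyro model the analogous tools would be the `.mask(...)` method on a distribution or `numpyro.handlers.mask`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def masked_log_likelihood(y_obs, y_pred, sigma):
    """Gaussian log-likelihood that skips entries of y_obs equal to NaN."""
    mask = ~jnp.isnan(y_obs)
    # Replace NaNs with a harmless value first, so that no NaN enters
    # the log-density (and therefore none enters its gradient).
    y_filled = jnp.where(mask, y_obs, y_pred)
    logp = norm.logpdf(y_filled, loc=y_pred, scale=sigma)
    # Missing entries contribute zero to the total.
    return jnp.sum(jnp.where(mask, logp, 0.0))


# Illustrative data: the second observation is missing.
y_obs = jnp.array([1.0, jnp.nan, 3.0])
y_pred = jnp.array([1.0, 2.0, 3.0])
total = masked_log_likelihood(y_obs, y_pred, 1.0)
```

Filling the NaNs before evaluating the density, rather than only zeroing the result afterwards, matters once gradients are taken, since a NaN in the forward pass can propagate into the gradient even when it is masked out of the sum.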

The MCMC is somewhat slow. I turned on numpyro.enable_x64() because I observed convergence problems without it. For my own work, I actually used solvers from the https://github.com/patrick-kidger/diffrax package instead of the ones from jax, also because of convergence problems. [image: ode_convergence]
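As a note on the precision setting mentioned above: JAX defaults to 32-bit floats, and double precision has to be switched on explicitly, ideally at program start before any arrays are created. The sketch below uses the underlying JAX configuration flag so that it depends on JAX alone; to the best of my understanding, numpyro.enable_x64() sets this same flag, but treat that equivalence as an assumption.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats; do this before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64 once the flag is enabled
```

Double precision makes each solver step slower, which may be part of why sampling feels slow, but it gives an adaptive ODE solver more room to meet tight tolerances.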

@fehiepsi I'd be glad to finish the tutorial properly if you can provide some feedback. Once again, my apologies for disappearing.

fehiepsi commented 1 year ago

Could you make a version of it on https://gist.github.com/?

A2P2 commented 1 year ago

@fehiepsi here you are https://gist.github.com/A2P2/5d2b3a15eafd5e0857ed1c49e4c1b1f4

fehiepsi commented 1 year ago

Thank you! I'll take a look later this week.

fehiepsi commented 1 year ago

Hi @A2P2, the example looks great. In the tutorial, it would be nice to include some introduction to the model, what the dataset looks like, and the motivation for integrating ODEs multiple times.

A2P2 commented 1 year ago

@fehiepsi thanks for checking it quickly. I've added a short motivation to the gist, I'll add the dataset description later. Shall I make a pull request with it?

fehiepsi commented 1 year ago

Could you turn it into a (probably short?) tutorial, rather than an example? We have several tutorials here (those without the "Example:" prefix).

A2P2 commented 1 year ago

Will do!

A2P2 commented 10 months ago

@fehiepsi here is the tutorial version on colab, let me know what you think. https://gist.github.com/A2P2/ae09ed99f84372ff346b0d352ca7b4ed

fehiepsi commented 10 months ago

super cool, @A2P2 !