pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Reduce JAX sampler memory usage #7311

Open andrewdipper opened 1 month ago

andrewdipper commented 1 month ago

- Change blackjax sampling to retain only the relevant sampling info, which reduces sampling memory requirements.
- Change _postprocess_samples to reuse the input arrays, which reduces postprocessing memory requirements.

Description

By default, the current pymc blackjax sampler accumulates all the info provided by blackjax, only to delete it afterwards. This makes memory usage several times higher than expected, because some info fields are num_samples * num_vars in size. This change stores only what is used, so memory scales similarly to the numpyro jax sampler. Note that blackjax.window_adaptation also uses excessive memory, but that needs to be fixed in blackjax: https://github.com/blackjax-devs/blackjax/issues/667. Until then, memory usage will still be excessive unless tune is set sufficiently small.
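To illustrate the idea of keeping only the info that is actually used, here is a minimal pure-JAX sketch. It is not the blackjax API: the toy random-walk Metropolis kernel, StepInfo, and all function names are hypothetical stand-ins. The point is that jax.lax.scan stacks only what the scan body returns, so dropping the state-sized proposal field from the per-step output means it is never accumulated across draws.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StepInfo(NamedTuple):
    # Hypothetical per-step diagnostics. `proposal` is as large as the state,
    # so stacking it for every draw would cost num_samples * num_vars memory.
    proposal: jax.Array
    acceptance_rate: jax.Array
    is_divergent: jax.Array


def logdensity(x):
    # Standard normal target, used only for illustration.
    return -0.5 * jnp.sum(x**2)


def rwm_step(key, position, scale=0.5):
    # Toy random-walk Metropolis kernel standing in for a blackjax kernel.
    key_prop, key_accept = jax.random.split(key)
    proposal = position + scale * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity(proposal) - logdensity(position)
    accept_prob = jnp.minimum(1.0, jnp.exp(log_ratio))
    accepted = jax.random.uniform(key_accept) < accept_prob
    new_position = jnp.where(accepted, proposal, position)
    return new_position, StepInfo(proposal, accept_prob, jnp.asarray(False))


def sample(key, init_position, num_samples):
    def body(position, step_key):
        position, info = rwm_step(step_key, position)
        # Emit only the stats we keep. The full `info` (including `proposal`)
        # is not returned, so scan never stacks it across draws.
        return position, (position, info.acceptance_rate, info.is_divergent)

    keys = jax.random.split(key, num_samples)
    _, (draws, acceptance_rate, is_divergent) = jax.lax.scan(body, init_position, keys)
    return draws, acceptance_rate, is_divergent


draws, acceptance_rate, is_divergent = sample(jax.random.PRNGKey(0), jnp.zeros(3), 200)
```

With this layout the stored output is roughly one state-sized array per draw plus a few scalars, rather than several state-sized arrays per draw.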

Changes the "vmap" mode of _postprocess_samples to donate the input device arrays. In my rough tests on my own models, this keeps the additional memory usage constant. This should make the "scan" mode unnecessary; however, I left it in so as not to break anything.
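A minimal sketch of buffer donation combined with nested vmap, assuming a (chains, draws) layout. The transform untransform and the other names are hypothetical and not pymc's actual _postprocess_samples. jax.jit's donate_argnums tells XLA that it may reuse the input buffer for an output of matching shape and dtype, so the transformed draws need not coexist with a second full-size copy. Some backends may ignore donation and emit a warning instead.

```python
import jax
import jax.numpy as jnp


def untransform(log_sigma):
    # Hypothetical per-draw transform: map an unconstrained log-scale draw to sigma > 0.
    return jnp.exp(log_sigma)


# vmap over chains, then over draws, for a (chains, draws) layout.
batched_untransform = jax.vmap(jax.vmap(untransform))

# donate_argnums=0 lets XLA reuse the input buffer for the output
# (same shape and dtype here) instead of allocating a second full-size array.
postprocess = jax.jit(batched_untransform, donate_argnums=0)

raw_log_sigma = jnp.zeros((4, 1000))  # 4 chains x 1000 draws
sigma = postprocess(raw_log_sigma)
# Do not use raw_log_sigma after this call: if its buffer was donated, it is now invalid.
```

The trade-off is that callers must treat the raw samples as consumed once postprocessing runs.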

Related Issue

- Should be resolved by using "vmap" mode: https://github.com/pymc-devs/pymc/issues/6744
- Should be unnecessary given the reduction in "vmap" memory usage: https://github.com/pymc-devs/pymc/pull/7116

📚 Documentation preview 📚: https://pymc--7311.org.readthedocs.build/en/7311/

welcome[bot] commented 1 month ago

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

andrewdipper commented 1 month ago

Looks like the test failures are due to an older version of jax combined with a recent blackjax PR that fixed an argument deprecation in jnp.clip: https://github.com/blackjax-devs/blackjax/pull/664 https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html
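For context, our understanding of the linked jnp.clip docs is that newer JAX releases deprecate the a_min/a_max keywords in favor of min/max. Code that passes the new keywords therefore fails on an older JAX. One version-agnostic pattern is to pass the bounds positionally, because the positional order is the same in both signatures. The helper below is hypothetical and only illustrates the call.

```python
import jax.numpy as jnp


def clip_step_size(step_size, lower, upper):
    # Bounds passed positionally: older JAX names them a_min/a_max,
    # newer JAX names them min/max, but the positional order is unchanged.
    return jnp.clip(step_size, lower, upper)


clipped = clip_step_size(jnp.array([0.001, 0.5, 5.0]), 0.01, 1.0)
```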

andrewdipper commented 1 month ago

Also added a bugfix for pm.sample not respecting compute_convergence_checks with the numpyro/blackjax samplers.

twiecki commented 1 month ago

Thanks @andrewdipper!

twiecki commented 1 month ago

Probably need to wait for https://github.com/pymc-devs/pymc/pull/7317 to merge.

ricardoV94 commented 1 month ago

> This should make the "scan" mode unnecessary. However, I left it in to not break anything.

Let's start deprecating it! Can you add a FutureWarning, shown when the user manually specifies this argument value, saying that it will be removed in the future?
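A minimal sketch of the requested deprecation path. The helper check_postprocessing_vectorize is hypothetical. We assume the option choosing between "vmap" and "scan" is named postprocessing_vectorize; treat that name as an assumption. The warning fires only when "scan" is explicitly requested, so default calls stay silent.

```python
import warnings


def check_postprocessing_vectorize(postprocessing_vectorize="vmap"):
    # Hypothetical validation helper; the option name is assumed, not confirmed.
    if postprocessing_vectorize not in ("vmap", "scan"):
        raise ValueError(
            f'postprocessing_vectorize must be "vmap" or "scan", got {postprocessing_vectorize!r}'
        )
    if postprocessing_vectorize == "scan":
        # Warn only when the user explicitly asks for the mode slated for removal.
        warnings.warn(
            'postprocessing_vectorize="scan" is deprecated and will be removed in a '
            'future release. Use the default "vmap" instead.',
            FutureWarning,
            stacklevel=2,
        )
    return postprocessing_vectorize
```

FutureWarning is shown to end users by default, whereas DeprecationWarning is hidden outside __main__ by default. That makes FutureWarning the usual choice for user-facing argument deprecations.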

ricardoV94 commented 1 month ago

Rebasing on main should fix the test failures and let us confirm that nothing got broken.

andrewdipper commented 1 month ago

I added the fix for reducing memory usage in the blackjax window_adaptation. However, it depends on https://github.com/blackjax-devs/blackjax/pull/674, which was merged earlier today, so blackjax would have to be up to date. I'm not sure how that's best handled.

ricardoV94 commented 1 month ago

Unfortunately we can't test newer versions of blackjax because of https://github.com/conda-forge/jaxlib-feedstock/issues/249

andrewdipper commented 1 month ago

I made some more changes for my own experiments to offload on lighter hardware as follows:

I had planned to put it up later as a proposal if I liked it, but since this change is stalled, I figured it might be worth doing it all together if you think the modifications are worthwhile. It's been helpful for me so far; let me know.