Open andrewdipper opened 1 month ago
:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.
Looks like the test failures are due to an older version of jax and a recent blackjax PR that fixes the deprecation of an argument to jnp.clip: https://github.com/blackjax-devs/blackjax/pull/664 https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html
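For context on the jnp.clip deprecation: newer jax releases accept the bounds as min/max and deprecate a_min/a_max, while older releases accept only a_min/a_max, so a keyword call written for one version breaks or warns on the other. A minimal sketch (the helper name clip_unit_interval is hypothetical) of a call that works with both signatures, by passing the bounds positionally:

```python
import jax.numpy as jnp


def clip_unit_interval(x):
    # Hypothetical helper. Passing the bounds positionally matches both the
    # older signature clip(a, a_min, a_max) and the newer clip(arr, /, min, max),
    # so it neither triggers the deprecation warning on new jax nor fails on old jax.
    return jnp.clip(x, 0.0, 1.0)
```

For example, clip_unit_interval(jnp.array([-1.0, 0.5, 2.0])) gives [0.0, 0.5, 1.0] on either version.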
Also added a bug fix for pm.sample not respecting compute_convergence_checks with the numpyro/blackjax samplers.
Thanks @andrewdipper!
Probably need to wait for https://github.com/pymc-devs/pymc/pull/7317 to merge.
This should make the "scan" mode unnecessary. However, I left it in to not break anything.
Let's start deprecating it! Can you add a FutureWarning, emitted when the user specifies this argument/value manually, saying that it will be removed in the future?
Rebasing onto main should fix the test failures and let us confirm that nothing was broken.
I added the fix that reduces the memory usage of the blackjax window_adaptation. However, it depends on https://github.com/blackjax-devs/blackjax/pull/674, which was merged just earlier today, so blackjax would have to be up to date. I'm not sure how best to handle that.
Unfortunately we can't test newer versions of blackjax because of https://github.com/conda-forge/jaxlib-feedstock/issues/249
I made some more changes for my own experiments with offloading to lighter hardware, as follows:
I had planned to put them up later as a proposal if I liked them, but since this change is stalled, I figured it might be worth doing everything together if you think the modifications are worthwhile. They've been helpful for me so far; let me know.
- Change blackjax sampling to retain only the relevant sampling info (reduces sampling memory requirements)
- Change _postprocess_samples to reuse the input arrays (reduces postprocessing memory requirements)

Description
By default, the current pymc blackjax sampler accumulates all the info provided by blackjax, only to delete it afterwards. This results in memory usage several times higher than expected (some of the info is num_samples * num_vars in size). The change stores only what is used, so memory scales similarly to the numpyro jax sampler. Note that blackjax.window_adaptation also has excessive memory usage, but that needs to be fixed in blackjax: https://github.com/blackjax-devs/blackjax/issues/667. As such, if tune
is not set sufficiently small, memory usage will still be excessive.

This PR also changes the "vmap" mode of _postprocess_samples to donate the input device arrays, which results (in my rough tests/models) in constant additional memory usage. This should make the "scan" mode unnecessary. However, I left it in to not break anything.

Related Issue
- This should resolve https://github.com/pymc-devs/pymc/issues/6744 by using "vmap" mode.
- This should make https://github.com/pymc-devs/pymc/pull/7116 unnecessary, given the reduction in "vmap" memory usage.
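The first change in the description, storing only the sampling info that is used, matters because whatever the step function returns from jax.lax.scan is stacked over every draw. For scale, with hypothetical sizes of 4 chains × 1,000 draws × 100,000 variables in float64, a single position-sized field takes 4 × 1,000 × 100,000 × 8 bytes = 3.2 GB. A minimal sketch of the idea, assuming a toy random-walk kernel (toy_kernel, ToyInfo and sample_keep_only_used are hypothetical stand-ins, not blackjax's NUTS kernel or PyMC's code):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyInfo(NamedTuple):
    # Hypothetical per-step diagnostics; `momentum` is as large as the position.
    momentum: jax.Array
    acceptance_rate: jax.Array
    is_divergent: jax.Array


def toy_kernel(key, position):
    """Hypothetical stand-in for one sampler step (always-accepted random walk)."""
    momentum = jax.random.normal(key, position.shape)
    new_position = position + 0.1 * momentum
    info = ToyInfo(
        momentum=momentum,
        acceptance_rate=jnp.asarray(1.0),
        is_divergent=jnp.asarray(False),
    )
    return new_position, info


def sample_keep_only_used(key, initial_position, num_samples):
    """Run the chain with lax.scan, stacking only the fields needed afterwards."""

    def one_step(position, step_key):
        new_position, info = toy_kernel(step_key, position)
        # Emit the draw plus two scalar diagnostics. The full `info`, including
        # the position-sized momentum, is never stacked across steps.
        return new_position, (new_position, info.acceptance_rate, info.is_divergent)

    keys = jax.random.split(key, num_samples)
    _, (positions, acceptance, divergent) = jax.lax.scan(
        one_step, initial_position, keys
    )
    return positions, acceptance, divergent
```

Returning (state, info) from one_step instead would stack the momentum too, adding another num_samples × num_vars array that is later discarded.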
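The second change in the description, donating the input device arrays in "vmap" mode, relies on jax.jit's donate_argnums: XLA may then write the output into the input's buffer instead of holding the raw draws and the transformed draws in memory at the same time. A minimal sketch, assuming a hypothetical per-draw transform (transform_draw and postprocess_vmap are illustrative names, not PyMC's _postprocess_samples):

```python
from functools import partial

import jax
import jax.numpy as jnp


def transform_draw(draw):
    """Hypothetical postprocessing of one draw, e.g. mapping an unconstrained
    value back to a positive one."""
    return jnp.exp(draw)


# donate_argnums=0 lets XLA reuse the buffer of `raw_draws` for the result.
# On a backend that cannot donate, jax emits a warning and copies as usual.
@partial(jax.jit, donate_argnums=0)
def postprocess_vmap(raw_draws):
    # Vectorize over the chain axis and the draw axis in a single call.
    return jax.vmap(jax.vmap(transform_draw))(raw_draws)
```

Donation only pays off when an output matches a donated input's shape and dtype, as here, and the caller must not touch the donated array after the call.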
📚 Documentation preview 📚: https://pymc--7311.org.readthedocs.build/en/7311/