pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.09k stars 227 forks

Memory reduction fixes for MCMC sampler #1802

Closed andrewdipper closed 2 weeks ago

andrewdipper commented 1 month ago

Here are some proposed changes for reducing the GPU memory footprint of MCMC sampling. I was hitting OOM earlier than expected, so I tracked some of it down.

Below are some rough numbers for peak memory usage / runtime with and without the changes for two models (split by //) just to give an initial view.

baseline / no progress bar: 3592MB / 63sec
baseline / with progress bar: 5126MB / 109sec // 14412MB / 515sec
new / no progress bar: 2052MB / 63sec
new / with progress bar: 2052MB / 70sec // 5182MB / 320sec
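To put the rough numbers above in relative terms, here is a small sketch that computes the percentage saved in peak memory and runtime for each configuration. The labels "model 1" and "model 2" are placeholders for the two models separated by //, and attributing the single "no progress bar" figure to the first model is an assumption, not something stated above.

```python
# Figures copied from the comment above: (peak memory in MB, runtime in seconds).
# "model 1" / "model 2" are placeholder labels for the two models split by "//".
# Assumption: the lone "no progress bar" measurement belongs to the first model.
baseline = {
    ("model 1", "no progress bar"): (3592, 63),
    ("model 1", "with progress bar"): (5126, 109),
    ("model 2", "with progress bar"): (14412, 515),
}
new = {
    ("model 1", "no progress bar"): (2052, 63),
    ("model 1", "with progress bar"): (2052, 70),
    ("model 2", "with progress bar"): (5182, 320),
}


def percent_saved(before, after):
    """Relative reduction from `before` to `after`, in percent."""
    return 100.0 * (before - after) / before


savings = {}
for key, (mem_before, time_before) in baseline.items():
    mem_after, time_after = new[key]
    savings[key] = (
        percent_saved(mem_before, mem_after),
        percent_saved(time_before, time_after),
    )
    model, mode = key
    mem_pct, time_pct = savings[key]
    print(f"{model}, {mode}: memory -{mem_pct:.1f}%, runtime -{time_pct:.1f}%")
```

Read this way, the progress-bar runs see the largest drop in peak memory, while the run without a progress bar keeps the same runtime and only its memory shrinks.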

Let me know if you think any of the changes would be useful, or if any modifications are needed.

andrewdipper commented 1 month ago

Is there any way to get the output / insight into the failing test?

FAILED test/test_examples.py::test_cpu[stochastic_volatility.py --num-samples 100 --num-warmup 100] - subprocess.CalledProcessError: Command '['/opt/hostedtoolcache/Python/3.9.19/x64/bin/python', '/home/runner/work/numpyro/numpyro/examples/stochastic_volatility.py', '--num-samples', '100', '--num-warmup', '100']' returned non-zero exit status 1.
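On seeing the output of a failing subprocess: a CalledProcessError message only reports the exit status, so the traceback of the child script is not visible in it. One generic way to surface it when reproducing locally is to run the same command with the output captured. The sketch below uses a stand-in child process that raises an error; the helper name run_and_report is hypothetical and this is not a description of how the numpyro test suite invokes the examples.

```python
import subprocess
import sys


def run_and_report(cmd):
    """Run `cmd` and return (exit code, stdout, stderr) instead of raising."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # For a Python child process, stderr usually holds the traceback.
        print(f"command failed with exit status {result.returncode}")
        print(result.stderr)
    return result.returncode, result.stdout, result.stderr


# Stand-in for the failing example script: a child process that raises.
code, out, err = run_and_report(
    [sys.executable, "-c", "raise RuntimeError('stand-in failure')"]
)
```

To try this against the real case, the command list would be the interpreter, the path to the example script, and the flags shown in the error message. Whether the CI environment exposes the same output is a separate question.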

For the latest update I've run the tests both locally and on a Kaggle T4x2 session, and they passed. The prodlda test sometimes has issues fetching the dataset, but I've had no problems with the stochastic_volatility test.

andrewdipper commented 3 weeks ago

I accidentally updated this with a merge instead of rebasing; let me know if that's an issue. Anyhow, it's back up to date.

andrewdipper commented 3 weeks ago

For sure, makes sense

fehiepsi commented 2 weeks ago

Happy to merge! 💯