Open issue, opened by emanuele 1 year ago.
:tada: Welcome to PyMC! :tada: We're really excited to have your input into the project! :sparkling_heart:
If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.
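I can reproduce this on my (4x2)-core CPU. I think it boils down to BLAS/LAPACK multithreading: by default, BLAS saturates all the cores, so when several chains sample in parallel processes, their BLAS threads end up competing for the same cores.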
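As a rough back-of-the-envelope check, assume each chain samples in its own process and each process's BLAS library starts as many threads as there are cores (the default behaviour described above). The total thread count then grows as chains × threads per chain. The helper names below are hypothetical, for illustration only.

```python
import os


def total_threads(chains: int, blas_threads_per_chain: int) -> int:
    """Threads competing for the CPU if every chain runs its own BLAS thread pool."""
    return chains * blas_threads_per_chain


def is_oversubscribed(chains: int, blas_threads_per_chain: int, cores: int) -> bool:
    """True when more compute threads are requested than there are (logical) cores."""
    return total_threads(chains, blas_threads_per_chain) > cores


if __name__ == "__main__":
    cores = os.cpu_count() or 1  # logical cores; cpu_count() can return None
    chains = 4
    # Assumed default: each chain's BLAS pool sizes itself to all cores.
    print("default:", total_threads(chains, cores), "threads on", cores, "cores")
    # With MKL_NUM_THREADS=1 / OPENBLAS_NUM_THREADS=1.
    print("limited:", total_threads(chains, 1), "threads on", cores, "cores")
```

On a machine with 8 logical cores and 4 chains, the default works out to 32 threads on 8 cores, while limiting BLAS to 1 thread per process gives 4.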
Setting the MKL threads to 1 fixes it for me:
%env MKL_NUM_THREADS=1
%env OPENBLAS_NUM_THREADS=1
Alternatively, setting the number of threads to 2 and reducing the number of chains to 2 or 3 also works.
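Outside Jupyter, the same flags can be set from Python. This is a minimal sketch, assuming the BLAS backend reads these variables when it is first loaded, so they must be set before numpy, pytensor or pymc are imported. OMP_NUM_THREADS is an extra variable not mentioned above, included on the assumption that the backend may also use OpenMP. The helper name is hypothetical.

```python
import os

# Environment variables commonly read by MKL, OpenBLAS and OpenMP at load time.
BLAS_THREAD_VARS = ("MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS")


def set_blas_threads(n_threads: int = 1) -> dict:
    """Set the thread-count variables for this process (child processes inherit them)."""
    if n_threads < 1:
        raise ValueError("n_threads must be at least 1")
    for var in BLAS_THREAD_VARS:
        os.environ[var] = str(n_threads)
    return {var: os.environ[var] for var in BLAS_THREAD_VARS}


# Must run before the first import of numpy / pytensor / pymc.
set_blas_threads(1)
# import pymc as pm  # import only after the variables are set
```

Passing 2 instead of 1 corresponds to the "2 threads, fewer chains" variant.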
More context for this can be found in: https://discourse.pymc.io/t/regarding-the-use-of-multiple-cores/4249/3 https://discourse.pymc.io/t/nuts-uses-all-cores/909/10
Maybe it's useful to add a warning message at the top of that notebook?
Fixing pymc-devs/pymc#6717 may also help with part of this.
Good points, @ricardoV94. But why does this happen only on Linux? And why only with this specific problem?
Anecdotally, this is the first time I've seen multiprocessing/multithreading be a problem on Linux and not on Windows or macOS :D
My guess is that neither of those is actually triggering multithreading. For instance, I think OpenMP (which could be another indirect source of multithreading) is not installed by default on macOS. I can't be sure without a machine to try them on.
IIRC, the way Python creates worker processes also differs across those three operating systems, which could also matter.
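For reference, the default multiprocessing start method does differ by OS: on Python 3.9 (the version used later in this thread) it is fork on Linux and spawn on Windows and macOS (macOS switched to spawn in Python 3.8; newer Python releases may change the Linux default). The snippet below only reports what the current interpreter uses; whether the start method actually explains the Linux-only slowdown is not established here. pm.sample also accepts an mp_ctx argument, which could be used to experiment with a different context.

```python
import multiprocessing as mp
import sys


def describe_start_methods():
    """Return (platform, default start method, available start methods)."""
    return sys.platform, mp.get_start_method(), mp.get_all_start_methods()


if __name__ == "__main__":
    platform, default, available = describe_start_methods()
    print(f"platform={platform} default={default} available={available}")
```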
Another user afflicted with this: https://discourse.pymc.io/t/bayesian-var-multivariate-slow-performance/12463/2?u=ricardov94
I think we should set the flags explicitly in the PyMC example and add a warning message explaining why we are doing this. Not everyone needs it, but many users do. We can mention that users can try turning it off to see whether it still samples fast.
Updating the flags sounds great
What is the suggested way to update the flags only on Linux? A simple cell at the beginning of the notebook like
%env MKL_NUM_THREADS=1
%env OPENBLAS_NUM_THREADS=1
is not conditional on the operating system. Is checking sys.platform the way to go, or does PyMC have a more structured way?
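One possible approach, sketched under the assumption that a plain sys.platform check is acceptable for a notebook (this is not a PyMC API, and the helper name is hypothetical): set the variables only on Linux, print a short note explaining why, and run the cell before numpy or pymc is imported.

```python
import os
import sys

BLAS_THREAD_VARS = ("MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")


def limit_blas_threads_on_linux(n_threads=1, platform=None, environ=None):
    """Set BLAS thread limits only on Linux; return True if anything was set."""
    platform = sys.platform if platform is None else platform
    environ = os.environ if environ is None else environ
    if not platform.startswith("linux"):
        return False
    for var in BLAS_THREAD_VARS:
        environ[var] = str(n_threads)
    print(
        f"Note: limiting BLAS to {n_threads} thread(s) per process to avoid "
        "oversubscription when chains sample in parallel; remove this cell to "
        "check whether your machine samples fast without it."
    )
    return True


limit_blas_threads_on_linux()  # run before importing numpy / pymc
```

The platform and environ parameters exist only to make the function easy to test without touching the real environment.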
I think those flags are fine in most OSes?
It would also be worth testing on the latest PyMC release (5.6.0), as we managed to remove some unnecessary BLAS operations in MvNormal models. The problem may not exist anymore. We would need someone who had the problem before to test it.
I just tested the code above on Linux with PyMC v5.6.0, without setting the flags, and unfortunately the problem persists, though the sampling time may blow up less dramatically: sampling reaches 10% after 3 minutes and all cores are fully loaded (mostly red bars in htop, i.e. kernel time, which is not what we want here). After setting the flags as above, reaching 10% takes less than 30 seconds.
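The htop observation (mostly kernel time) can also be quantified with the standard library. A minimal sketch with a hypothetical helper: os.times() reports user and system CPU seconds, and its children_* fields only include child processes that have exited and been waited for (they are always 0 on Windows). A system time that is large relative to user time would be consistent with threads contending for cores rather than computing.

```python
import os
import time


def cpu_time_breakdown(fn, *args, **kwargs):
    """Run fn and return (result, wall_s, user_s, system_s).

    CPU times include this process plus any child processes that
    finished and were waited for while fn ran.
    """
    t0, w0 = os.times(), time.perf_counter()
    result = fn(*args, **kwargs)
    wall = time.perf_counter() - w0
    t1 = os.times()
    user = (t1.user - t0.user) + (t1.children_user - t0.children_user)
    system = (t1.system - t0.system) + (t1.children_system - t0.children_system)
    return result, wall, user, system


if __name__ == "__main__":
    _, wall, user, system = cpu_time_breakdown(sum, range(10_000_000))
    print(f"wall={wall:.2f}s user={user:.2f}s system={system:.2f}s")
```

With PyMC, one would wrap the sampling call inside the model context, e.g. cpu_time_breakdown(pm.sample); this assumes the worker processes have been joined by the time pm.sample returns and has not been tested here.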
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
gives:
Last updated: Wed Jul 12 2023

Python implementation: CPython
Python version       : 3.9.7
IPython version      : 8.14.0

pytensor: 2.12.3
aeppl   : not installed
xarray  : 2023.6.0

pymc  : 5.6.0
numpy : 1.25.1
pandas: 2.0.3

Watermark: 2.4.3
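If watermark is not installed, a similar version report can be produced with the standard library alone. A sketch using importlib.metadata (Python 3.8+); the helper name and the package list are illustrative.

```python
import platform
from importlib import metadata


def version_report(packages):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, version in version_report(["pymc", "pytensor", "numpy", "xarray"]).items():
        print(f"{name:10}: {version}")
```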
Regarding "I think those flags are fine in most OSes?": I really have no idea :)
Yes, I think they work on other OSes, and they won't hurt if they don't.
Describe the issue:
When running the "Bayesian Vector Autoregressive Models" notebook from the PyMC example gallery on x86_64 (tested on Ubuntu 20.04 with a fresh install of PyMC following the official installation instructions), the sampling time explodes to tens of hours instead of the expected few minutes. I can reproduce the issue on multiple machines.
By contrast, on arm64 (M1 MacBook) the issue does not occur, and sampling takes a few minutes as expected.
Below is a minimal example, extracted from the notebook above, that consistently shows the issue on x86_64.
The critical line is line 113, where pm.sample() is called with default values and the NUTS sampler is auto-assigned. By manually specifying pm.sample(cores=1, ...) or by using a non-default NUTS sampler (nuts_sampler="blackjax" or nuts_sampler="numpyro"), the issue disappears on x86_64. Related discussion on Discord here.
Screenshot of the issue:
Reproducible code example:
Error message:
PyMC version information:
Fresh install of PyMC following the official installation instructions.
Context for the issue: