trendelkampschroer opened 1 month ago
Thank you for the bug report. I had seen something possibly related recently, but didn't manage to find an example in a smaller model. This example should make it much easier to find the problem.
Right now you can work around the issue by freezing the pymc model:
from pymc.model.transform.optimization import freeze_dims_and_data

# kwargs: whatever keyword arguments you normally pass to compile_pymc_model
trace = nutpie.sample(nutpie.compile_pymc_model(freeze_dims_and_data(model), **kwargs))
Well, that was easier than I thought, and won't be hard to fix. The problem is that the argument name for the point in parameter space is "x", and if any shared variable (like the data) is also called "x", this will give an argument name collision.
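For illustration, a model shaped roughly like this (made-up names, not the reporter's exact code) should trigger the collision when compiled with the jax backend, because the shared data variable is called "x":

import numpy as np
import pymc as pm
import nutpie

rng = np.random.default_rng(0)
with pm.Model() as model:
    # shared data variable named "x" -- the same name the jax backend uses
    # internally for the point in parameter space
    x = pm.Data("x", rng.normal(size=(100, 3)))
    beta = pm.Normal("beta", shape=3)
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=(beta * x).sum(axis=-1), sigma=sigma,
              observed=rng.normal(size=100))

# fails with an argument name collision (before the fix)
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")

Renaming the data variable, or freezing the model as shown above, avoids it.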
@aseyboldt : Thanks a lot for the very helpful reply. I renamed 'x' -> 'X' and now things are working.
This is really good to know, but it should probably either be fixed by ensuring that a unique name is used for the point in parameter space, or by forbidding 'x' as a name in the model (which would be a bit cumbersome, since predictors are often denoted by 'x').
On the upside, using JAX gives a very nice speedup on my machine (Apple M1) :-).
Yes, definitely needs a fix, I'll push one soon.
Out of curiosity (I don't have an Apple machine), could you do me a small favor and run this with jax and numba, and tell me what the compile time and the runtime are in each case?
import time

import arviz as az
import nutpie

# generate_data and make_model are the helpers from the minimal example above

# jax backend
frame = generate_data(num_samples=10_000)
model = make_model(frame)
kwargs = dict(backend="jax", gradient_backend="jax")
t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time for {kwargs=}: {time.time() - t0:.3f}s")
t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[
    :, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]
]
print(summary)

# numba backend
frame = generate_data(num_samples=10_000)
model = make_model(frame)
kwargs = dict(backend="numba", gradient_backend="pytensor")
t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time for {kwargs=}: {time.time() - t0:.3f}s")
t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[
    :, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]
]
print(summary)
compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 1.016s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=2.094s.
compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.835s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=0.564s.
I hope that helps. I have another follow-up question: while I observe a great speed-up when using the JAX backend on my Apple M1 machine, sampling with the JAX backend is significantly slower than with Numba/PyTensor when running on a Google Cloud VM with many more cores (32) and more memory. This is for a hierarchical linear regression with thousands of groups and a couple of predictors.
On the VM, sampling with the "jax" backend is about 30% slower than with the "numba" backend. Specifically, with the "numba" backend I see a handful (4-8) of thread/CPU bars in htop at 100%, while with the JAX backend all 32 bars show some occupancy below 50%.
If you have any ideas or insights into what could cause this, and how to ensure the best performance, I'd be glad for any suggestions.
Thanks for the numbers :-)
First, I think it is important to distinguish compile time from sampling time. The numbers you just gave me show that the numba backend samples faster on the Mac as well; only the compile time is much larger. If the model gets bigger, the compile time will play less of a role, because it doesn't depend much on the data size.
I think what you observe with the jax backend is an issue with how it currently works: the chains run in different threads that are controlled by the rust code. With the numba backend the python interpreter is only used to generate the logp function, so all sampling happens without any involvement of python. Doing the same with jax is currently much harder (I hope this will change in the not too distant future though): jax compiles the logp function, but I can't easily access this compiled function from rust, so instead I have to call a python function that then calls the compiled jax function. While a bit silly, that wouldn't be too bad if python didn't have the GIL. But the GIL ensures that only one thread (i.e. one chain) can use the python interpreter at the same time, so each logp evaluation has to acquire the GIL, go through the python interpreter to call the compiled jax function, and release the GIL again.
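Very roughly, the per-draw pattern looks like this (a simplified sketch with made-up names, not nutpie's actual rust/python code):

import threading

# a lock standing in for the GIL in this sketch
gil = threading.Lock()

def evaluate_logp(wrapped_jax_logp, position):
    """What each sampler thread (one per chain) roughly does per gradient evaluation."""
    # Entering the python interpreter requires the GIL. The heavy numeric work
    # inside the jitted jax function can release it again, but the per-call
    # python overhead cannot -- with a cheap logp that overhead dominates, so
    # the chains mostly end up waiting for each other right here.
    with gil:
        value, grad = wrapped_jax_logp(position)
    return value, grad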
If the computation of the logp function takes a long time and there aren't that many threads, then most of the time only one or even no thread will hold the GIL, because each thread spends most of its time in the actual logp evaluation, and all is good. But if the logp evaluation is relatively quick, then more than one thread will try to acquire the GIL at the same time, and the threads sit around waiting, i.e. "low occupancy".
There are two things that might make this situation better in the future:
In the meantime: if the cores of your machine aren't being used well, you can at least try to limit the number of threads that run at the same time by setting the cores argument of sample to something smaller. This can reduce the lock contention and give you a modest speedup. It won't really fix the problem though...
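For example, reusing the compiled model from the snippets above (the numbers are purely illustrative):

trace = nutpie.sample(compiled, chains=8, cores=4)  # at most 4 chains run at the same time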
If you are willing to go to some extra lengths: you can start multiple separate processes that sample your model (with different seeds!) and then combine the traces. This is much more annoying, but it should completely avoid the lock contention. In that case you can run into other issues, however, for instance if each process tries to use all available cores on the machine. Fixing that would then require threadpoolctl and/or taskset, or some jax environment variable flags.
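For example, threadpoolctl can cap the number of BLAS threads inside each worker process (a generic sketch, not specific to the model here):

import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.randn(2000, 2000)
with threadpool_limits(limits=1, user_api="blas"):
    # BLAS calls inside this block are restricted to a single thread
    b = a @ a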
I hope that helps to clear it up a bit :-)
Thanks @aseyboldt, this is really helpful. Do you know of a minimal example for the "start multiple separate processes" approach? I have seen https://discourse.pymc.io/t/harnessing-multiple-cores-to-speed-up-fits-with-small-number-of-chains/7669, where the idea is to concatenate multiple smaller chains to more efficiently harness the CPUs on a machine.
I'd e.g. try to use joblib for that, but I am not sure how much that interferes with the PyMC and nutpie internals. If you have any pointers I'd be very glad to look into it.
Btw: with num_samples = 100_000 the numbers look like this on an Apple M1:
compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 0.864s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=5.842s.
compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.874s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=19.923s.
So JAX is a lot faster for sampling - which also matches my observation for a hierarchical linear model.
For sampling in separate processes:
# At the very start...
import os
os.environ["JOBLIB_START_METHOD"] = "forkserver"

import arviz
import numpy as np
import nutpie
from joblib import parallel_config, Parallel, delayed

# make_model and frame are the helper/data from the snippets above

def run_chain(data, idx, seed):
    model = make_model(data)
    # derive an independent per-chain seed from the shared seed
    seeds = np.random.SeedSequence(seed)
    seed = np.random.default_rng(seeds.spawn(idx + 1)[-1]).integers(2 ** 63)
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    trace = nutpie.sample(compiled, seed=seed, chains=1, progress_bar=False)
    return trace.assign_coords(chain=[idx])

with parallel_config(n_jobs=10, prefer="processes"):
    traces = Parallel()(delayed(run_chain)(frame, i, 123) for i in range(10))

trace = arviz.concat(traces, dim="chain")
This comes with quite a bit of overhead (mostly constant though), so probably not worth it for smaller models.
Funnily enough, I see big differences between
# Option 1
mu = (beta * x).sum(axis=-1)
# Option 2
mu = x @ beta
And jax and numba react quite differently. Maybe an issue with the blas config? What blas implementation are you using? (On conda-forge you can choose it as explained here: https://conda-forge.org/docs/maintainer/knowledge_base/#switching-blas-implementation.) I think on M1 accelerate is usually the fastest.
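For reference, numpy can also report which BLAS it is linked against:

import numpy as np
np.show_config()  # prints the BLAS/LAPACK build configuration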
Thanks a lot, once again, for the very helpful suggestions. I will benchmark the two versions of the "dot product" to see whether I observe different performance.
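For example, something like this could work (use_matmul is a hypothetical flag added to make_model, not part of the code above, reusing the frame from earlier):

import time
import nutpie

for use_matmul in (False, True):
    model = make_model(frame, use_matmul=use_matmul)  # hypothetical flag picking option 1 or 2
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    t0 = time.time()
    nutpie.sample(compiled, progress_bar=False)
    print(f"{use_matmul=}: sampling took {time.time() - t0:.3f}s")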
Regarding BLAS:
On the Apple M1 I have
blas 2.124 openblas conda-forge
blas-devel 3.9.0 24_osxarm64_openblas conda-forge
libblas 3.9.0 24_osxarm64_openblas conda-forge
libcblas 3.9.0 24_osxarm64_openblas conda-forge
liblapack 3.9.0 24_osxarm64_openblas conda-forge
liblapacke 3.9.0 24_osxarm64_openblas conda-forge
libopenblas 0.3.27 openmp_h517c56d_1 conda-forge
openblas 0.3.27 openmp_h560b219_1 conda-forge
and on the VM
blas 2.120 mkl conda-forge
blas-devel 3.9.0 20_linux64_mkl conda-forge
libblas 3.9.0 20_linux64_mkl conda-forge
libcblas 3.9.0 20_linux64_mkl conda-forge
I can try to use the Accelerate BLAS. But right now I am more curious about speeding things up on the VM.
Minimal example
Error message
Sampling with backend="numba" and gradient_backend="pytensor" runs successfully.
Version