Closed twiecki closed 10 months ago
Looks like this model is pretty much a benchmark of matrix-multiply speed in the different backends. For me all three samplers take about 15s with MKL. What BLAS do you have installed?
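(For anyone following along: a quick way to check which BLAS NumPy is linked against, not part of the PR, is `numpy.show_config()` — a minimal sketch:)

```python
import io
from contextlib import redirect_stdout

import numpy as np

# numpy.show_config() prints the build/link configuration, including
# which BLAS/LAPACK implementation NumPy was built against
# (e.g. OpenBLAS, MKL, or Accelerate).
buf = io.StringIO()
with redirect_stdout(buf):
    np.show_config()
config = buf.getvalue()
print(config)
```

On newer NumPy versions `np.show_config(mode="dicts")` returns the same information as a dict instead of printing it.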
I also get warnings from the Numba backend:
/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'A', False, aligned=True), Array(float64, 2, 'C', False, aligned=True))
return inner(x)
/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))
Maybe the ordered transform introduces some non-contiguous arrays?
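(For reference, the warning fires when `np.dot` gets an operand that is neither C- nor F-contiguous in the expected layout. A minimal plain-NumPy sketch, with hypothetical shapes, of how a transposed view loses C-contiguity and how `np.ascontiguousarray` restores it:)

```python
import numpy as np

a = np.ones((4, 3))
b = a.T  # a transposed *view* of a: not C-contiguous
assert not b.flags["C_CONTIGUOUS"]

# Copying into C order gives np.dot (and Numba) a contiguous buffer again.
c = np.ascontiguousarray(b)
assert c.flags["C_CONTIGUOUS"]

# The product is the same either way; only the memory layout differs.
x = np.ones((4, 2))
assert np.allclose(b @ x, c @ x)
```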
@twiecki Could you maybe check what you get for sampling time alone and compile time alone for Numba? I.e.:

```python
import nutpie

%time compiled = nutpie.compile_pymc_model(PPCA)
%time nutpie.sample(compiled)
```
(for me compilation is ~11s and sampling ~4s)
We should also use a fixed seed in the notebook; otherwise the data will be different each time we execute it. For me, neither of the samplers ends up with a converged posterior, which makes comparing the times pretty pointless, but that might just be because of the seed I used...
I added a seed.
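(For completeness, seeding both the data generation and the sampler is what makes the timings comparable across runs. A NumPy-only sketch with a hypothetical seed value:)

```python
import numpy as np

SEED = 42  # hypothetical value; any fixed integer works

# Use a seeded Generator for the synthetic data so every notebook
# run sees the same dataset.
rng = np.random.default_rng(SEED)
data_a = rng.normal(size=(100, 5))

# Re-creating the generator with the same seed reproduces the data exactly.
rng2 = np.random.default_rng(SEED)
data_b = rng2.normal(size=(100, 5))
assert np.array_equal(data_a, data_b)
```

The same seed can also be passed to the sampler (e.g. `pm.sample(..., random_seed=SEED)`) so the chains themselves are reproducible.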
Compilation time is rather low for this model.
Wall time: 33.4s for compilation + sampling vs 47s for just sampling.
Cause of Numba slowness (as "debugged" with @aseyboldt just now): OpenBLAS. Installing Accelerate via `micromamba install "libblas=*=*accelerate"` got nutpie down to JAX-level speeds (minus compilation time).
@OriolAbril I have implemented the requested changes.
Update previous JAX sampling NB
The previous NB was very outdated; I changed the example to PPCA and updated it to use the `nuts_sampler` kwarg.
@aseyboldt nutpie is kinda slow on this example, not sure why.