pymc-devs / pymc-examples

Examples of PyMC models, including a library of Jupyter notebooks.
https://www.pymc.io/projects/examples/en/latest/
MIT License

External nuts sampler #560

Closed twiecki closed 10 months ago

twiecki commented 1 year ago

Update previous JAX sampling NB

The previous NB was very outdated. I changed the example to PPCA and updated it to use the nuts_sampler kwarg.

@aseyboldt nutpie is kinda slow on this example, not sure why.

aseyboldt commented 1 year ago

Looks like this model is pretty much a benchmark of matrix-multiply speed in the different backends. For me all three samplers take about 15s with MKL. Which BLAS do you have installed?
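To check whether the installed BLAS is the bottleneck independently of any sampler, one option is to time a plain matrix multiply. The following is only a rough sketch: `time_matmul` and its default sizes are hypothetical names and values chosen for illustration, not something taken from the notebook.

```python
import time

import numpy as np


def time_matmul(n=500, repeats=5, seed=0):
    """Return the best wall time (seconds) of an n x n float64 matmul."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n, n))
    b = rng.standard_normal((n, n))
    a @ b  # warm-up call so one-time setup cost is not measured
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        a @ b
        best = min(best, time.perf_counter() - start)
    return best


print(f"best of 5: {time_matmul():.4f} s")
# Prints the BLAS/LAPACK libraries this NumPy build is linked against.
np.show_config()
```

Running this in environments with different BLAS builds gives a quick comparison of raw matmul speed. Note that `np.show_config()` reports NumPy's build configuration; other backends may link BLAS differently.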

I also get warnings from the numba backend:

/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'A', False, aligned=True), Array(float64, 2, 'C', False, aligned=True))
  return inner(x)
/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))

Maybe the ordered transform introduces some non-contiguous arrays?
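For context, the 'A' layout in the warning means Numba could not assume the array is C- or Fortran-contiguous. A small NumPy-only sketch of how a strided view loses contiguity, and how a contiguous copy restores it (the arrays here are made up for illustration and are unrelated to the PPCA model):

```python
import numpy as np

x = np.arange(12, dtype=np.float64).reshape(3, 4)

# Taking every second column yields a strided view that is
# neither C- nor Fortran-contiguous.
col_slice = x[:, ::2]
print(col_slice.flags["C_CONTIGUOUS"], col_slice.flags["F_CONTIGUOUS"])

# An explicit copy gives a C-contiguous array with the same values.
fixed = np.ascontiguousarray(col_slice)
print(fixed.flags["C_CONTIGUOUS"])
```

Whether the ordered transform actually produces such views is an open question in this thread; the sketch only shows what the warning is describing.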

aseyboldt commented 1 year ago

@twiecki Could you maybe check what you get for sampling time alone and compile time alone for numba?

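I.e.:

import nutpie

# Compile once, outside the timed call (PPCA is the model from the notebook)
compiled = nutpie.compile_pymc_model(PPCA)
# IPython magic: times the sampling step only
%time nutpie.sample(compiled)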

(for me compilation is ~11s and sampling ~4s)

We should also use a fixed seed in the notebook, otherwise the data will be different each time we execute it. For me none of the samplers ends up with a converged posterior, which makes comparing the times pretty pointless, but that might just be because of the seed I used...
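A minimal sketch of what fixing the seed for the simulated data could look like, using NumPy's `default_rng`. The `simulate_data` helper, its shape arguments, and the seed value are hypothetical and not taken from the notebook:

```python
import numpy as np


def simulate_data(seed, n=100, d=5):
    """Draw an (n, d) array of standard-normal data from a seeded generator."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(n, d))


a = simulate_data(42)
b = simulate_data(42)
# Same seed, same data: every notebook run sees an identical dataset.
print(np.array_equal(a, b))
```

The sampling step can be made reproducible in the same spirit by passing `random_seed` to `pm.sample`.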

twiecki commented 1 year ago

I added a seed.

Compilation time is rather low for this model.

Wall time: 33.4 s for compilation + sampling vs 47s for just sampling.

twiecki commented 1 year ago

Cause of Numba slowness (as "debugged" with @aseyboldt just now): OpenBLAS. Installing Accelerate via micromamba install "libblas=*=*accelerate" got nutpie down to JAX-level speeds (minus compilation time).

twiecki commented 10 months ago

@OriolAbril I have implemented the requested changes.