rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Bump jax version #61

Closed jeremiecoullon closed 3 years ago

jeremiecoullon commented 3 years ago

This fixes the two bugs caused by the older version of JAX (#36 and #58).

I then unpinned the old JAX version. When I install everything from scratch and run the test suite, all tests pass.

Note: I also removed the static_argnums in integrators.py that marked step_size (a float) as static. This was also causing a "non-hashable type" error for some reason. In any case, I don't think letting floats be traced instead of static adds a performance penalty (based on playing around with JAX generally; I didn't benchmark this particular case).
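To illustrate the trade-off mentioned above, here is a minimal sketch (not mcx's actual integrator; `leapfrog_step` is a made-up stand-in) showing that a float step size can simply be traced like any other argument, while marking the argument static requires every value passed there to be hashable:

```python
import jax
import jax.numpy as jnp

@jax.jit
def leapfrog_step(q, step_size):
    # step_size is traced like any other input: no static_argnums needed,
    # and changing its value does not trigger recompilation.
    return q + step_size * jnp.ones_like(q)

q = jnp.zeros(3)
print(leapfrog_step(q, 0.1))

# With static_argnums, the same call fails if the step size arrives as an
# array, because static arguments must be hashable:
strict = jax.jit(lambda q, s: q + s, static_argnums=(1,))
try:
    strict(q, jnp.asarray(0.1))
except (TypeError, ValueError) as e:
    print("non-hashable static argument:", e)
```

A plain Python float *is* hashable, so the error in integrators.py likely appeared when the step size reached the jitted function as an array or traced value; tracing it sidesteps the whole issue.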

rlouf commented 3 years ago

Since I put an emphasis on performance, I should at some point have benchmarks that run in CI. Something like: https://github.com/marketplace/actions/continuous-benchmark

jeremiecoullon commented 3 years ago

I did some quick performance tests (on the example in the readme):

• the "non-hashable type" and kernel_factory fixes made no difference in performance (running on the old JAX version)
• the newer JAX version was faster for both the compiled and non-compiled example.

Details of test:

I ran the 1D logistic regression from the readme, but with 2 chains and 3K samples:

Old JAX version:

Non-compiled:
• warmup: 550 samples/sec
• sampling: 3800 samples/sec

Compiled:
• warmup: 1.9 sec
• sampling: 1.7 sec

New JAX version (0.2.6):

Non-compiled:
• warmup: 720 samples/sec
• sampling: 5950 samples/sec

Compiled:
• warmup: 1.5 sec
• sampling: 1.3 sec

I ran these timings a few times and the variance was small, so the newer JAX version definitely sped up the sampler in this example. I don't know whether it will make a difference for more complicated models, though!
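For reference, a sketch of how such compiled-vs-steady-state timings can be taken fairly in JAX (the transition kernel below is a made-up stand-in, not mcx's API). Because JAX dispatches work asynchronously, calling `block_until_ready()` before stopping the clock is essential:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def kernel(state, _):
    # stand-in transition: one cheap update per "sample"
    return state * 0.5 + 1.0, state

def run_chain(init, n_samples):
    # scan the kernel n_samples times, collecting the visited states
    _, samples = jax.lax.scan(kernel, init, None, length=n_samples)
    return samples

init = jnp.zeros(2)  # 2 chains

# The first call includes compilation; time it separately.
t0 = time.perf_counter()
run_chain(init, 3_000).block_until_ready()
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
run_chain(init, 3_000).block_until_ready()
steady_state = time.perf_counter() - t0
print(f"first call (incl. compile): {first_call:.3f}s, cached: {steady_state:.3f}s")
```

Averaging several steady-state runs, as done above, keeps the one-off compilation cost from skewing the samples/sec numbers.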

rlouf commented 3 years ago

Nice! Merging.