rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

HMC warmup #29

Closed rlouf closed 3 years ago

rlouf commented 4 years ago

This PR adds the "Stan warmup" to the HMC program. The warmup consists of a specific sequence of step size and mass matrix adaptations. We run the warmup for each chain separately, as discussed in #21.
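For readers unfamiliar with that sequence, here is a rough sketch of how such a schedule could be laid out. It assumes the defaults described in Stan's documentation (an initial fast buffer of 75 steps, a final fast buffer of 50 steps, and slow windows that start at 25 steps and double in length). The function name `stan_like_schedule` and its parameters are hypothetical and are not taken from mcx, and the rescaling Stan applies to very short warmups is left out.

```python
def stan_like_schedule(num_steps=1000, initial_buffer=75, final_buffer=50, first_window=25):
    """Return (stage, start, end) tuples, with `end` exclusive.

    Hypothetical helper: "fast" stages adapt only the step size, "slow"
    stages also collect samples to re-estimate the mass matrix at their end.
    """
    slow_end = num_steps - final_buffer
    if initial_buffer + first_window > slow_end:
        raise ValueError("num_steps is too small for these buffer sizes")

    schedule = [("fast", 0, initial_buffer)]
    start, size = initial_buffer, first_window
    while start < slow_end:
        end = start + size
        # If the following window (twice as long) would not fit,
        # stretch the current window to the end of the slow phase.
        if end + 2 * size > slow_end:
            end = slow_end
        schedule.append(("slow", start, end))
        start, size = end, 2 * size
    schedule.append(("fast", slow_end, num_steps))
    return schedule


schedule = stan_like_schedule(1000)
slow_window_lengths = [end - start for stage, start, end in schedule if stage == "slow"]
```

With 1,000 warmup steps, this sketch produces slow windows of 25, 50, 100, 200 and 500 steps between the two fast buffers.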

rlouf commented 4 years ago

There is one issue with the warmup: it is very slow (2 minutes for 500 steps) due to JAX's compilation overhead. In the current implementation we use closures to embed parameters inside the metrics, proposals, integrators, etc. Every time a parameter changes during warmup we need to redefine and recompile these functions, which eats up all the computation time.

For each warmup step, I measure a total running time on the order of tens of seconds, while the average running time of each iteration within the step is on the order of a tenth of a second.
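To illustrate the overhead described above, here is a minimal sketch, not the mcx code. It compares a parameter baked into a jitted closure with the same parameter passed as an argument. The names `make_closure_step` and `argument_step` and the toy update rule are made up for the example. The counters are incremented by a Python side effect, which only runs when JAX traces the function, so they count traces rather than calls.

```python
import jax
import jax.numpy as jnp

# Python side effects inside a jitted function only run while JAX traces it,
# so these counters record how many times each variant is traced.
trace_count = {"closure": 0, "argument": 0}


def make_closure_step(step_size):
    """Hypothetical update with `step_size` embedded through a closure."""

    @jax.jit
    def step(position):
        trace_count["closure"] += 1
        return position - step_size * position

    return step


@jax.jit
def argument_step(position, step_size):
    """Same toy update, with `step_size` passed as a traced argument."""
    trace_count["argument"] += 1
    return position - step_size * position


position = jnp.ones(3)
for step_size in (0.1, 0.2, 0.3):
    # A new function object is created each time, so JAX traces it again
    # (and typically compiles it again).
    make_closure_step(step_size)(position)
    # Same function, same argument shapes and dtypes: traced once, then reused.
    argument_step(position, step_size)
```

After the loop, the closure variant has been traced three times and the argument variant once. On a real HMC kernel each extra trace also means another compilation, which is consistent with the per-step timings reported above.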

rlouf commented 3 years ago

I reorganized the code for the Stan warmup to make it more modular and give it a more functional style. Perhaps unsurprisingly, this overhaul dramatically reduced the time it takes to run the warmup. I now have a fixed JIT compilation time of 10s and a running time of a couple of seconds for 1,000 steps & 4 chains. It scales sublinearly with the number of chains: the warmup takes less than 20s in total for a few hundred chains and 1,000 steps.
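As an illustration of what a functional layout can look like, here is a sketch under simplifying assumptions, not the code in this PR. The adapted parameter is carried in the loop state of `jax.lax.scan`, and `jax.vmap` runs every chain through one compiled program. To stay short it uses a random-walk Metropolis kernel with a Robbins-Monro step size update in place of HMC with Stan's adaptation. The function names and the 0.4 target acceptance rate are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp


def logpdf(x):
    # Toy target: standard normal.
    return -0.5 * jnp.sum(x ** 2)


def warmup_step(state, inputs):
    """One adaptation step; the step size lives in the carried state."""
    position, log_step_size = state
    rng_key, iteration = inputs
    key_proposal, key_accept = jax.random.split(rng_key)

    noise = jax.random.normal(key_proposal, position.shape)
    proposal = position + jnp.exp(log_step_size) * noise
    accept_prob = jnp.minimum(1.0, jnp.exp(logpdf(proposal) - logpdf(position)))
    do_accept = jax.random.uniform(key_accept) < accept_prob
    position = jnp.where(do_accept, proposal, position)

    # Robbins-Monro update towards an (arbitrary) 0.4 acceptance rate.
    log_step_size = log_step_size + (accept_prob - 0.4) / jnp.sqrt(iteration + 1.0)
    return (position, log_step_size), accept_prob


def warmup_one_chain(rng_key, initial_position, num_steps=1000):
    keys = jax.random.split(rng_key, num_steps)
    iterations = jnp.arange(num_steps, dtype=jnp.float32)
    initial_state = (initial_position, jnp.zeros(()))
    (position, log_step_size), _ = jax.lax.scan(
        warmup_step, initial_state, (keys, iterations)
    )
    return position, jnp.exp(log_step_size)


# One compilation covers all chains and all warmup steps.
warmup_all_chains = jax.jit(jax.vmap(warmup_one_chain))

num_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
initial_positions = jnp.zeros((num_chains, 2))
positions, step_sizes = warmup_all_chains(chain_keys, initial_positions)
```

Because the step size is part of the state rather than a value captured by a closure, changing it never triggers a new trace, and each chain still adapts its own step size independently.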

Next steps:

We can then consider cleaning the code, improving the test suite, documenting and automating the doc generation before a first release.

Questions