Closed: rlouf closed this PR 3 years ago.
There is one issue with the warmup: it is very slow (2 minutes for 500 steps) because of JAX's compilation overhead. In the current implementation, we use closures to embed parameters inside the metrics, proposals, integrators, etc. Every time a parameter changes during warmup, we need to redefine and recompile these functions, which accounts for nearly all of the computation time.
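A small sketch of why closures trigger recompilation, assuming a hypothetical toy leapfrog step on a standard normal target (not the PR's integrator). When the step size is captured in a closure, each new value produces a new function object, so `jax.jit` traces and compiles it again; when the step size is passed as an argument, the compiled function is reused as long as shapes and dtypes stay the same. The `TRACES` counter relies on the fact that Python side effects inside a jitted function only run while JAX is tracing it.

```python
import jax
import jax.numpy as jnp

# Counts how many times each version is traced (and therefore compiled).
TRACES = {"closure": 0, "argument": 0}

def _leapfrog(q, p, step_size):
    # One leapfrog step for a hypothetical standard normal target: grad U(q) = q.
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p

def make_step(step_size):
    # Closure style: step_size is baked in as a Python constant, so every new
    # value yields a new function that jax.jit must trace and compile.
    @jax.jit
    def step(q, p):
        TRACES["closure"] += 1
        return _leapfrog(q, p, step_size)
    return step

@jax.jit
def step_with_argument(q, p, step_size):
    # Functional style: step_size is a traced argument, so new values reuse
    # the function compiled on the first call.
    TRACES["argument"] += 1
    return _leapfrog(q, p, step_size)

q, p = jnp.ones(3), jnp.zeros(3)
for step_size in [0.1, 0.2, 0.4, 0.8]:
    make_step(step_size)(q, p)
    step_with_argument(q, p, step_size)

print(TRACES)
```

With four step size values, the closure version is traced four times while the argument version is traced once.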
For each warmup step, I measure a total running time on the order of tens of seconds, while the average running time of each iteration within the step is on the order of a tenth of a second.
I reorganized the code for the Stan warmup to make it more modular and give it a more functional style. Perhaps unsurprisingly, this overhaul dramatically reduced the time it takes to run the warmup. I now have a fixed JIT compilation time of 10 s and a running time of a couple of seconds for 1,000 steps and 4 chains. The running time scales sublinearly with the number of chains: the warmup takes less than 20 s in total for a few hundred chains and 1,000 steps.
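A minimal sketch of what a functional-style warmup can look like, assuming nothing about the PR's actual code: the adaptation parameters (here only a log step size) are carried in the state threaded through `jax.lax.scan`, so the step function is traced once for the whole warmup, and `jax.vmap` runs the chains in parallel. The random-walk Metropolis kernel, the toy Gaussian target, the 0.5 target acceptance rate and the simple Robbins-Monro update are all hypothetical stand-ins; in particular, this is not Stan's dual averaging.

```python
import jax
import jax.numpy as jnp

TARGET_ACCEPTANCE = 0.5  # hypothetical target for this toy kernel

def logdensity(x):
    # Hypothetical standard normal target.
    return -0.5 * jnp.sum(x ** 2)

def warmup_step(state, key):
    # The step size lives in the state, not in a closure.
    position, log_step_size = state
    key_proposal, key_accept = jax.random.split(key)
    step_size = jnp.exp(log_step_size)
    proposal = position + step_size * jax.random.normal(key_proposal, position.shape)
    accept_prob = jnp.minimum(1.0, jnp.exp(logdensity(proposal) - logdensity(position)))
    accept = jax.random.uniform(key_accept) < accept_prob
    position = jnp.where(accept, proposal, position)
    # Simplified Robbins-Monro update of the log step size (a stand-in only).
    log_step_size = log_step_size + 0.05 * (accept_prob - TARGET_ACCEPTANCE)
    return (position, log_step_size), accept_prob

def run_warmup(key, initial_position, num_steps=1000):
    # One chain: scan threads the state through num_steps warmup steps.
    keys = jax.random.split(key, num_steps)
    init_state = (initial_position, jnp.zeros(()))  # log(1.0) == 0
    (position, log_step_size), _ = jax.lax.scan(warmup_step, init_state, keys)
    return position, jnp.exp(log_step_size)

# vmap batches over chains; jit compiles the whole multi-chain warmup once.
run_chains = jax.jit(jax.vmap(run_warmup))

num_chains, dim = 4, 2
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
final_positions, step_sizes = run_chains(chain_keys, jnp.zeros((num_chains, dim)))
```

Because compilation happens once per input shape, adding chains only grows the batch dimension rather than triggering new compilations.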
Next steps:
We can then consider cleaning up the code, improving the test suite, writing documentation, and automating documentation generation before a first release.
Questions
Does find_reasonable_step_size use the kernel with a single step?
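For reference, the heuristic described in the NUTS paper (Hoffman and Gelman, Algorithm 4) does take a single leapfrog step per trial: it doubles the step size while the acceptance ratio of that one step stays above 1/2, or halves it while the ratio stays below 1/2. The sketch below follows that description on a hypothetical 1-D standard normal target with a unit mass matrix; it is not the PR's implementation, the momentum is passed in explicitly rather than sampled, and the `max_iterations` guard is an addition of this sketch.

```python
import math

def logdensity(q):
    # Hypothetical 1-D standard normal target.
    return -0.5 * q * q

def grad_logdensity(q):
    return -q

def leapfrog(q, p, step_size):
    # A single leapfrog step with a unit mass matrix.
    p = p + 0.5 * step_size * grad_logdensity(q)
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_logdensity(q)
    return q, p

def log_joint(q, p):
    # Log of the unnormalized joint density of position and momentum.
    return logdensity(q) - 0.5 * p * p

def find_reasonable_step_size(q, p, step_size=1.0, max_iterations=100):
    def log_accept_ratio(eps):
        # Each trial takes exactly one leapfrog step from (q, p).
        q_new, p_new = leapfrog(q, p, eps)
        return log_joint(q_new, p_new) - log_joint(q, p)

    # direction = +1: keep doubling while the ratio stays above 1/2;
    # direction = -1: keep halving while it stays below 1/2.
    direction = 1 if log_accept_ratio(step_size) > math.log(0.5) else -1
    for _ in range(max_iterations):  # guard added for this sketch only
        if direction * log_accept_ratio(step_size) <= -direction * math.log(2.0):
            break
        step_size *= 2.0 ** direction
    return step_size

print(find_reasonable_step_size(1.0, 0.5))
```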
This PR adds the "Stan warmup" to the HMC program. The warmup consists of a specific sequence of step size and mass matrix adaptations. We run the warmup for each chain separately, as discussed in #21.
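For reference, Stan documents this sequence as windowed adaptation: an initial fast interval where only the step size is adapted, a series of slow windows that double in size and in which the mass matrix is also estimated (step size adaptation restarts after each one), and a final fast interval that adapts the step size alone. The sketch below computes those intervals using Stan's default sizes (75, 25 and 50 steps); the function name and the half-open `(start, end)` tuples are hypothetical choices, and the different schedule Stan uses for very short warmups is not reproduced.

```python
def stan_warmup_schedule(num_steps, initial_buffer=75, final_buffer=50, first_window=25):
    """Return (kind, start, end) intervals, with end excluded."""
    if num_steps < initial_buffer + first_window + final_buffer:
        raise ValueError("warmup too short for this sketch's schedule")
    slow_end = num_steps - final_buffer
    # Fast interval: step size adaptation only.
    schedule = [("fast", 0, initial_buffer)]
    start, size = initial_buffer, first_window
    while start < slow_end:
        end = start + size
        # If the next (doubled) window would not fit before the final
        # buffer, stretch this window to reach the final buffer.
        if end + 2 * size > slow_end:
            end = slow_end
        # Slow window: mass matrix estimated from the samples in this window.
        schedule.append(("slow", start, end))
        start, size = end, 2 * size
    # Final fast interval: step size adaptation with the last mass matrix.
    schedule.append(("fast", slow_end, num_steps))
    return schedule

for interval in stan_warmup_schedule(1000):
    print(interval)
```

For 1,000 warmup steps this yields slow windows of 25, 50, 100, 200 and 500 steps between the 75-step and 50-step fast intervals.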