rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0
325 stars · 17 forks

Warmup #54

Closed by rlouf 3 years ago

rlouf commented 3 years ago

Draft

The current design for the warmup is not optimal and I would like to start from something more principled. The warmup, as it is commonly understood, has two purposes: adapting the evaluator's parameters and reaching the typical set. Since parameters are bound to change even during the adaptation, we decompose an evaluator into a set of parameters and a factory that generates a kernel given the parameter values. For instance, a NUTS evaluator is the combination of a specific value of the tuple (step_size, mass_matrix) and a kernel factory that returns the evaluator's transition kernel when passed the parameter tuple.
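This decomposition can be sketched as follows. The names and the scalar stand-in for the mass matrix are illustrative assumptions, not MCX's actual API; the kernel body is a toy rule, not real HMC/NUTS dynamics:

```python
from typing import Callable, NamedTuple


class HMCParameters(NamedTuple):
    # Hypothetical parameter container; a real mass matrix would be an array.
    step_size: float
    mass_matrix: float


def kernel_factory(params: HMCParameters) -> Callable[[float], float]:
    """Bind a transition kernel to a specific parameter value.

    The returned kernel is a toy stand-in: it only illustrates that the
    kernel is fully determined by the (step_size, mass_matrix) tuple.
    """
    def kernel(state: float) -> float:
        return state + params.step_size * params.mass_matrix
    return kernel


# A NUTS evaluator is then a (parameters, factory) pair:
params = HMCParameters(step_size=0.1, mass_matrix=1.0)
kernel = kernel_factory(params)
new_state = kernel(0.0)
```

Because the kernel is regenerated from the factory whenever the parameters change, adaptation never has to mutate a kernel in place.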

An adaptation algorithm, then, is nothing more than a function that takes a set of parameters, a chain state, and a factory, and returns a new set of parameters along with a new chain state:

adapt(Parameters, chain_state, factory) -> Parameters', chain_state'

This way adaptation algorithms can be easily combined:

DualAveragingStepAdaptation((step_size, mass_matrix), state, factory) -> (new_step_size, mass_matrix), state'
MassMatrixAdaptation((new_step_size, mass_matrix), state', factory) -> (new_step_size, new_mass_matrix), state''

Note that we decided to start the second adaptation step from the final position of the first step, but did not have to do so.
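The chaining above can be sketched end-to-end. The update rules and names below are toy placeholders (real dual averaging and mass-matrix estimation are more involved); the point is that both steps share the adapt signature, so parameters and chain state thread cleanly from one to the next:

```python
def make_kernel(params):
    # Toy factory: the kernel nudges the state by the step size.
    step_size, _ = params
    return lambda state: state + step_size


def step_size_adaptation(params, state, factory):
    # Toy stand-in for dual averaging: advance the chain, halve the step size.
    state = factory(params)(state)
    step_size, mass_matrix = params
    return (step_size / 2, mass_matrix), state


def mass_matrix_adaptation(params, state, factory):
    # Toy stand-in: advance the chain, derive the "mass matrix" from the state.
    state = factory(params)(state)
    step_size, _ = params
    return (step_size, abs(state) + 1.0), state


# Chain the two adaptation steps, threading parameters and chain state:
params, state = step_size_adaptation((1.0, 1.0), 0.0, make_kernel)
params, state = mass_matrix_adaptation(params, state, make_kernel)
```

Starting the second step from a fresh chain state instead would only require passing a different `state` argument, which is the flexibility noted above.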

Interestingly, in this framework one can use any factory that is compatible with the parameters being adapted; it is entirely possible to adapt the step size and mass matrix with NUTS, then use the resulting parameters with HMC. While this may seem silly, it looks like the authors of the empirical HMC paper use NUTS for the window adaptation and then HMC to adapt the number of integration steps.

Implementation

Each adaptation step requires the sampling machinery. We suggest a new adapt function/class that is very close in implementation to sample, except that it carries the values of the parameters as well as the chain state. Since the implementation of sample is modular we can re-use some of its components.
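A minimal sketch of such an adapt driver, assuming toy ingredients so it runs end-to-end (the names `adapt`, `make_kernel`, and `halve_step_size` are illustrative, not MCX's API):

```python
def adapt(params, chain_state, factory, adaptation_step, num_steps):
    """Like a sampling loop, but it threads the parameter values
    through the loop along with the chain state."""
    for _ in range(num_steps):
        params, chain_state = adaptation_step(params, chain_state, factory)
    return params, chain_state


# Toy ingredients so the driver can be exercised:
def make_kernel(params):
    step_size, _ = params
    return lambda state: state + step_size


def halve_step_size(params, state, factory):
    # Placeholder adaptation rule with the adapt(...) signature.
    state = factory(params)(state)
    step_size, mass_matrix = params
    return (step_size / 2, mass_matrix), state


final_params, final_state = adapt((1.0, 1.0), 0.0, make_kernel, halve_step_size, 3)
```

In a JAX implementation this loop would naturally become a `lax.scan`, just as in sample, which is where the re-use comes from.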

Adaptation algorithms are implemented as kernels that update both the chain and their internal state. A final function takes the internal state and the parameter tuple and returns an updated parameter tuple.
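The kernel/final split can be sketched like this, under assumed names and with a deterministic toy kernel; here the internal state is a running-sum accumulator and the sample variance stands in for a (scalar) mass matrix estimate:

```python
def variance_adaptation_kernel(chain_state, adapt_state, kernel):
    """One adaptation step: advance the chain with the current kernel
    and accumulate running sums (the algorithm's internal state)."""
    chain_state = kernel(chain_state)
    n, s, s2 = adapt_state
    return chain_state, (n + 1, s + chain_state, s2 + chain_state ** 2)


def final(adapt_state, params):
    """Map the internal state to an updated parameter tuple: the
    sample variance becomes the new (scalar) mass matrix."""
    n, s, s2 = adapt_state
    variance = s2 / n - (s / n) ** 2
    step_size, _ = params
    return (step_size, variance)


# Run a few steps with a deterministic toy kernel, then finalize:
state, internal = 0.0, (0, 0.0, 0.0)
for _ in range(3):
    state, internal = variance_adaptation_kernel(state, internal, lambda x: x + 1.0)
new_params = final(internal, (0.1, 1.0))
```

Keeping the internal state separate from the parameter tuple is what lets several adaptation algorithms run back to back without stepping on each other.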

rlouf commented 3 years ago

Moved to BlackJAX