rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

reshape momentum inside the integrator? #7

Closed by rlouf 4 years ago

rlouf commented 4 years ago

There is a tension inside the HMC algorithm between the momentum generator, which generates 1D arrays, and the log-probability function, which takes tuples of arrays. Since the integrator step involves computations that mix both, we either need to work with flattened variables, or reshape the momentum, using jax.flatten_util.ravel_pytree to obtain an unraveling function. Here are the tradeoffs:

Flattening variables

This can be easily achieved by composing the logpdf with the unraveling function, so that the composed function takes the variables as a single concatenated 1D array.

This leads to the following extra operations:

  1. Unraveling in the flattened logpdf during trajectory integration: N_samples * num_stages_integrator * num_integration_steps;
  2. Unraveling in the flattened logpdf when computing the logprob before the acceptance step: N_samples;
  3. N_samples unraveling steps to re-create the trace.

So approximately, since the first term dominates:

N_samples * num_stages_integrator * num_integration_steps
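The flattening option can be sketched as follows. This is only an illustration, not the code in mcx: the `logpdf` below is a made-up stand-in for a model's log-probability function, and `flat_logpdf` is a hypothetical name. The only library call relied on is `jax.flatten_util.ravel_pytree`, which returns the concatenated 1D array together with the unraveling function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


# Hypothetical log-probability function that takes a tuple of arrays
# (a stand-in for the logpdf compiled from a model).
def logpdf(position):
    mu, log_sigma = position
    return -0.5 * jnp.sum(mu ** 2) - 0.5 * jnp.sum(log_sigma ** 2)


position = (jnp.ones(3), jnp.zeros(2))

# ravel_pytree gives the concatenated 1D array and a function that
# rebuilds the original tuple of arrays from it.
flat_position, unravel_fn = ravel_pytree(position)


# Compose the logpdf with the unraveling function: the result takes a 1D array.
def flat_logpdf(flat_position):
    return logpdf(unravel_fn(flat_position))


# The gradient is now a 1D array too, with the same shape as a flat momentum.
flat_grad = jax.grad(flat_logpdf)(flat_position)
```

Every call to `flat_logpdf` (and to its gradient) goes through `unravel_fn`, which is where the unraveling counted in item 1 comes from.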

Pros

Cons

Unraveling momentum during inference

This can be done at several stages: either in the momentum generator (as NumPyro does) or in the integrator. Doing it in the momentum generator seems awkward: the reshaping becomes disconnected from the place where it is actually needed, and will look confusing to someone encountering the code for the first time.

It would lead to the following extra operations:

  1. Unraveling in the momentum generation: N_iter
  2. Raveling in the kinetic energy: N_samples * num_stages_integrator * num_integration_steps
  3. Raveling/unraveling with jax.tree_multimap: same order of magnitude.
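A rough sketch of what unraveling the momentum inside the integrator could look like is given below. It assumes a diagonal inverse mass matrix stored as a 1D array and a velocity Verlet (leapfrog) step; `logpdf`, `kinetic_energy` and `leapfrog_step` are hypothetical names chosen for the illustration, not the mcx implementation. It uses `jax.tree_util.tree_map`, which accepts several trees, in place of `jax.tree_multimap`, since recent JAX releases no longer ship the latter.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


# Hypothetical log-probability function over a tuple of arrays.
def logpdf(position):
    x, y = position
    return -0.5 * jnp.sum(x ** 2) - 0.5 * jnp.sum(y ** 2)


def kinetic_energy(momentum, inverse_mass_matrix):
    # Ravel the pytree momentum to apply the (assumed diagonal) inverse mass matrix.
    flat_momentum, _ = ravel_pytree(momentum)
    return 0.5 * jnp.dot(flat_momentum, inverse_mass_matrix * flat_momentum)


def leapfrog_step(position, flat_momentum, step_size, inverse_mass_matrix):
    # Unravel the 1D momentum so that it has the same structure as the position.
    _, unravel_fn = ravel_pytree(position)
    momentum = unravel_fn(flat_momentum)

    tree_map = jax.tree_util.tree_map
    logpdf_grad = jax.grad(logpdf)

    # Half step on the momentum, full step on the position, half step on the momentum.
    momentum = tree_map(
        lambda p, g: p + 0.5 * step_size * g, momentum, logpdf_grad(position)
    )
    velocity = jax.grad(kinetic_energy)(momentum, inverse_mass_matrix)
    position = tree_map(lambda q, v: q + step_size * v, position, velocity)
    momentum = tree_map(
        lambda p, g: p + 0.5 * step_size * g, momentum, logpdf_grad(position)
    )
    return position, momentum


position = (jnp.ones(3), jnp.ones(2))
flat_momentum = jnp.array([0.5, -0.5, 0.2, 0.1, -0.3])  # as produced by a 1D generator
inverse_mass_matrix = jnp.ones(5)

new_position, new_momentum = leapfrog_step(
    position, flat_momentum, 0.01, inverse_mass_matrix
)
```

Here the log-probability function is called on the position as-is, and the only place where the momentum is raveled again is the kinetic energy.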

Pros

Cons

rlouf commented 4 years ago

I settled on the second option for now. The momentum is an array because of the existence of the mass matrix; this may be a quirk of the Euclidean metric: RHMC's momentum generator and kinetic energy are functions of the position, for instance. We should thus not carry this specificity around in the code.

Note that this is a potential source of performance improvement, although I suspect the raveling/unraveling takes substantially less time than evaluating the logpdf and its gradient.