blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Add low-rank-modified metric #684

Open aseyboldt opened 4 months ago

aseyboldt commented 4 months ago

Add a low-rank-modified metric

This adds a new type of metric for HMC-style samplers. Instead of allowing only a diagonal or a full mass matrix, we combine the diagonal mass matrix with a small number of additional eigenvectors and eigenvalues of the remaining correlation.

This leads to the following form for the mass matrix:

$$
\begin{aligned}
P &= V(\Sigma^{-1} - I)V^T + I \\
M &= D^{-\frac{1}{2}} P D^{-\frac{1}{2}}
\end{aligned}
$$

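Here $V \in \mathbb{R}^{n \times k}$ has orthonormal columns and $\Sigma$ is a $k \times k$ diagonal matrix with positive entries. The columns of $V$ are eigenvectors of $P$ with eigenvalues $\Sigma^{-1}$, so $P^{-1} = V(\Sigma - I)V^T + I$ has exactly the eigenvalues specified in $\Sigma$. All remaining eigenvalues are one. With $k = 0$ we get $P = I$ and $M = D^{-1}$, i.e. the usual diagonal mass matrix.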
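To make the definitions concrete, here is a minimal JAX sketch that builds $P$ and $M$ densely from $D$ (passed as its diagonal `d`), `V`, and $\Sigma$ (passed as its diagonal `sigma`). It assumes $V$ has orthonormal columns and that all entries of `d` and `sigma` are positive. The helper `low_rank_mass_matrix` is hypothetical, not part of BlackJAX or this PR, and the dense construction is only useful for checking the algebra, since it gives up the $O(nk)$ cost.

```python
import jax
import jax.numpy as jnp

# Use float64 so the numerical checks are tight.
jax.config.update("jax_enable_x64", True)


def low_rank_mass_matrix(d, V, sigma):
    """Hypothetical helper: dense P = V (Sigma^{-1} - I) V^T + I and M = D^{-1/2} P D^{-1/2}.

    d:     (n,) positive diagonal of D
    V:     (n, k) matrix with orthonormal columns (the eigenvectors)
    sigma: (k,) positive diagonal of Sigma
    """
    n = d.shape[0]
    # V * c scales column j of V by c[j], i.e. V @ diag(c) without forming diag(c).
    P = (V * (1.0 / sigma - 1.0)) @ V.T + jnp.eye(n)
    d_inv_sqrt = 1.0 / jnp.sqrt(d)
    # Elementwise: M[i, j] = P[i, j] / sqrt(d[i] * d[j]).
    M = d_inv_sqrt[:, None] * P * d_inv_sqrt[None, :]
    return P, M


n, k = 6, 2
key_d, key_v = jax.random.split(jax.random.PRNGKey(0))
d = jnp.exp(jax.random.normal(key_d, (n,)))
# Reduced QR gives an (n, k) matrix with orthonormal columns.
V, _ = jnp.linalg.qr(jax.random.normal(key_v, (n, k)))
sigma = jnp.array([4.0, 0.25])

P, M = low_rank_mass_matrix(d, V, sigma)

# Eigenvalues of P in ascending order: 1/sigma on span(V), one elsewhere.
eig_P = jnp.linalg.eigvalsh(P)

# Closed-form inverse, valid because V^T V = I:
# M^{-1} = D^{1/2} (V (Sigma - I) V^T + I) D^{1/2}.
P_inv = (V * (sigma - 1.0)) @ V.T + jnp.eye(n)
M_inv = jnp.sqrt(d)[:, None] * P_inv * jnp.sqrt(d)[None, :]

# With k = 0 the metric reduces to the diagonal case M = D^{-1}.
_, M_diag = low_rank_mass_matrix(d, jnp.zeros((n, 0)), jnp.zeros((0,)))
```

Here `eig_P` contains `0.25` and `4.0` (the reciprocals of `sigma`) plus four ones, and `M @ M_inv` is the identity up to rounding.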

All operations needed by HMC/NUTS can be done in $O(nk)$, so when the number of eigenvalues $k$ is small, the cost is much closer to that of a diagonal mass matrix than to that of a dense one, whose matrix-vector products cost $O(n^2)$.
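The $O(nk)$ claim can be illustrated with a minimal JAX sketch of the operations an HMC/NUTS step typically needs: multiplying by $M^{-1}$ (for the velocity and the kinetic energy) and drawing momentum $p \sim \mathcal{N}(0, M)$. It assumes $V$ has orthonormal columns, so that $P^{-1} = V(\Sigma - I)V^T + I$ and the symmetric square root is $P^{1/2} = V(\Sigma^{-1/2} - I)V^T + I$. All function names are hypothetical and do not reflect BlackJAX's metric API or this PR's implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical helpers: each touches V only through V @ x or V.T @ x plus
# elementwise products, so every call costs O(n k).


def inv_mass_matvec(d, V, sigma, p):
    """M^{-1} p = D^{1/2} (V (Sigma - I) V^T + I) D^{1/2} p."""
    y = jnp.sqrt(d) * p
    y = y + V @ ((sigma - 1.0) * (V.T @ y))
    return jnp.sqrt(d) * y


def mass_sqrt_matvec(d, V, sigma, xi):
    """D^{-1/2} P^{1/2} xi, with P^{1/2} = V (Sigma^{-1/2} - I) V^T + I."""
    y = xi + V @ ((1.0 / jnp.sqrt(sigma) - 1.0) * (V.T @ xi))
    return y / jnp.sqrt(d)


def sample_momentum(key, d, V, sigma):
    """Draw p ~ N(0, M) by transforming standard normal noise."""
    xi = jax.random.normal(key, d.shape)
    return mass_sqrt_matvec(d, V, sigma, xi)


def kinetic_energy(d, V, sigma, p):
    """K(p) = 0.5 * p^T M^{-1} p."""
    return 0.5 * jnp.dot(p, inv_mass_matvec(d, V, sigma, p))


n, k = 6, 2
key_d, key_v, key_p = jax.random.split(jax.random.PRNGKey(1), 3)
d = jnp.exp(jax.random.normal(key_d, (n,)))
V, _ = jnp.linalg.qr(jax.random.normal(key_v, (n, k)))  # orthonormal columns
sigma = jnp.array([4.0, 0.25])

# Dense reference M built from the definition, for checking only (O(n^2)).
P = (V * (1.0 / sigma - 1.0)) @ V.T + jnp.eye(n)
M = P / jnp.sqrt(d)[:, None] / jnp.sqrt(d)[None, :]

p = sample_momentum(key_p, d, V, sigma)
K_fast = kinetic_energy(d, V, sigma, p)
K_dense = 0.5 * p @ jnp.linalg.solve(M, p)

# Apply the square-root map to each unit vector to get L = D^{-1/2} P^{1/2};
# L @ L.T should reproduce M, which is why sample_momentum has covariance M.
L = jax.vmap(lambda e: mass_sqrt_matvec(d, V, sigma, e), in_axes=1, out_axes=1)(jnp.eye(n))
```

No $n \times n$ matrix or factorization is needed by the helpers themselves; the dense `M` in the sketch exists only to confirm that `K_fast` matches `K_dense` and that `L @ L.T` equals `M`.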

For now, this PR only contributes the metric; the corresponding adaptation is planned for a later PR.

Related issue: #683
