blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Preconditioned mclmc #673

Closed: reubenharry closed this 4 months ago

reubenharry commented 4 months ago

This PR exists in response to issue #616.

It adds a "mass matrix" (std_mat) to the mclmc algorithm and updates the tuning scheme. It also adds the function with_isokinetic_maruyama, which automates the stochastic momentum update. It seemed appropriate to place this function (and therefore partial_momentum_update, on which it depends) in the file with the integrators.
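For readers unfamiliar with the stochastic momentum update, here is a minimal sketch of the idea in plain Python. It is not the code in this PR: the names partial_momentum_refresh, with_maruyama and free_drift are hypothetical, the noise scale nu is an assumed form of the usual MCLMC partial refresh, and the mass matrix (std_mat) is left out entirely.

```python
import math
import random


def partial_momentum_refresh(momentum, rng, step_size, L):
    """Mix Gaussian noise into a unit momentum vector, then renormalise.

    Assumed noise scale: nu = sqrt((exp(2 * step_size / L) - 1) / dim).
    """
    dim = len(momentum)
    nu = math.sqrt((math.exp(2.0 * step_size / L) - 1.0) / dim)
    noisy = [u + nu * rng.gauss(0.0, 1.0) for u in momentum]
    norm = math.sqrt(sum(x * x for x in noisy))
    # Renormalising keeps the momentum on the unit sphere (isokinetic constraint).
    return [x / norm for x in noisy]


def with_maruyama(deterministic_step):
    """Wrap a deterministic step so every call ends with a momentum refresh."""

    def stochastic_step(position, momentum, rng, step_size, L):
        position, momentum = deterministic_step(position, momentum, step_size)
        momentum = partial_momentum_refresh(momentum, rng, step_size, L)
        return position, momentum

    return stochastic_step


def free_drift(position, momentum, step_size):
    # Hypothetical stand-in for a real integrator: drift along the momentum.
    new_position = [x + step_size * u for x, u in zip(position, momentum)]
    return new_position, momentum


rng = random.Random(0)
step = with_maruyama(free_drift)
position, momentum = step(
    [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], rng, step_size=0.1, L=1.0
)
momentum_norm = math.sqrt(sum(u * u for u in momentum))
```

The wrapper pattern is the point here: the deterministic integrator stays unaware of the noise, and the refresh is applied once per step, which is why it is natural for both functions to live next to the integrators.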

This branch depends on the branch of #672, so that PR should be merged first.

Note: I had some seemingly arbitrary test failures involving the Yoshida integrator, which was tested with tol=1e-6. Since I didn't change any code that seemed likely to affect the Euclidean Yoshida integrator, I loosened the tolerance to tol=1e-4, and the tests passed. Still, it is worth noting.


reubenharry commented 4 months ago

@junpenglao This is the next PR in the series, followed by #681 and then #675.

reubenharry commented 4 months ago

This is now ready for review.