blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Diffusive Gibbs Sampler #744

Open jcopo opened 2 days ago

jcopo commented 2 days ago

Presentation of the new sampler

DiGS is an auxiliary variable MCMC method where the auxiliary variable $\tilde{x}$ is a noisy version of the original variable $x$. DiGS enhances mixing and helps escape local modes by alternately sampling from the distributions $p(\tilde{x}|x)$, which introduces noise via Gaussian convolution, and $p(x|\tilde{x})$, which denoises the sample back to the original space using a score-based update (e.g. a Langevin diffusion). https://arxiv.org/abs/2402.03008
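To make this concrete, here is a minimal sketch of one DiGS sweep in JAX. It is not the paper's exact algorithm: the values of $\alpha$ and $\sigma$, the number of denoising steps, and the plain unadjusted Langevin denoiser below are simplifying assumptions (the paper Metropolis-adjusts the denoising update and uses a schedule of noise levels), and the function names are mine.

```python
import jax
import jax.numpy as jnp


def digs_sweep(rng_key, x, logdensity_fn, alpha=0.9, sigma=0.5,
               n_denoise=10, step_size=1e-2):
    """One illustrative DiGS sweep: noise the state, then denoise it
    with a few Langevin steps on the conditional p(x | x_tilde)."""
    key_noise, key_denoise = jax.random.split(rng_key)

    # Noising step: x_tilde ~ p(x_tilde | x) = N(alpha * x, sigma^2 I).
    x_tilde = alpha * x + sigma * jax.random.normal(key_noise, x.shape)

    # Conditional p(x | x_tilde) \propto p(x) N(x_tilde; alpha * x, sigma^2 I).
    def cond_logdensity(z):
        return logdensity_fn(z) - 0.5 * jnp.sum((x_tilde - alpha * z) ** 2) / sigma**2

    grad_fn = jax.grad(cond_logdensity)

    # Denoising step: a few unadjusted Langevin updates on p(x | x_tilde)
    # (the actual method uses a Metropolis-adjusted update).
    def langevin_step(z, key):
        noise = jax.random.normal(key, z.shape)
        return z + step_size * grad_fn(z) + jnp.sqrt(2.0 * step_size) * noise, None

    keys = jax.random.split(key_denoise, n_denoise)
    x_new, _ = jax.lax.scan(langevin_step, x, keys)
    return x_new
```

Alternating such sweeps gives the Gibbs-like noise/denoise scheme described above; in the actual method, the MH correction on the denoising step (the part that gets tricky in high dimension) is what keeps the chain exact.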

I had very good results with it in small to medium dimensions; it really helps escape local modes. A very powerful use is as a proposal in an SMC-based procedure, where it helps move samples back into depleted zones. In high dimensions, the acceptance ratio of the MH step becomes the tricky part.

If you think it is a sensible addition to Blackjax I'll be happy to contribute.

How does it compare to other algorithms in blackjax?

The number of denoising steps is flexible, so it can be made computationally efficient. The algorithm is conceptually simple but applicable to a wide class of problems.

Where does it fit in blackjax

As an MCMC kernel per se or as an SMC proposal.

Are you willing to open a PR?

Yes - I have a version that I used for my research and would be happy to contribute it to Blackjax.

AdrienCorenflos commented 2 days ago

Can you explain how this algorithm is different from the auxiliary perspective of MALA in https://rss.onlinelibrary.wiley.com/doi/full/10.1111/rssb.12269?

Edit: Gaussian convolution and then Langevin "denoising" is exactly aMALA in my book, so where's the difference?
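For reference, the auxiliary construction I have in mind is (my notation, with $\delta$ the scale of the Gaussian convolution; the exact scaling conventions in the paper may differ):

$$
p(x, u) = p(x)\,\mathcal{N}(u \mid x, \delta^2 I), \qquad p(x \mid u) \propto p(x)\,\exp\!\left(-\frac{\lVert u - x \rVert^2}{2\delta^2}\right),
$$

and one sweep alternates sampling $u \sim \mathcal{N}(x, \delta^2 I)$ with a Langevin/MALA move on $p(x \mid u)$.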

jcopo commented 2 days ago

Given your comment, maybe I should underline that I have no personal stake in the paper. I was not familiar with the reference you provided, but the method does look similar.

At first glance aMALA doesn't seem to have the multilevel noise schedule of Diffusive Gibbs, and the initialization of the denoising step (Eq. 14 of DiGS) seems to be different. Are the contraction idea and the corresponding variance of Eq. 10 also in aMALA? I had a look at your implementation in marginal_latent_gaussian.py but it wasn't obvious to me.

AdrienCorenflos commented 1 day ago

The marginal latent Gaussian is the counterpart for Gaussian priors, but I really think it's related: the auxiliary target is the same one, and then it looks like it's just a bunch of MALA steps with increasing step size. This is not unprecedented in the literature (although people typically take step sizes coming from Chebyshev polynomials, for which there is theory).

Mind you, I am not against having underperforming or academically not-so-novel samplers in the library; I'm mostly thinking we may want to think carefully about the components and implement those, rather than the special instance that DiGS offers.

jcopo commented 1 day ago

it's just a bunch of MALA steps with increasing step size

I don't see how this is true: the schedule modifies the proximal version of the score function (Eq. 12), not directly the step size.

Putting aside academic novelty and which paper gets credited, I think this is a simple yet effective sampler. I'm happy to think about the components and how to implement them. But it's not clear to me which algorithm DiGS should be a special case of. If you have references or ideas I'd be interested in having a look.

AdrienCorenflos commented 1 day ago

But Eq. (12) is immediately the gradient of the conditional proximal density, exactly as would happen in auxiliary MALA with a different choice of decomposition in terms of $\alpha, \sigma$; I'm not sure what you mean.
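To spell out what I mean (my reconstruction in the notation from the top of the thread, where $\tilde{x} \mid x \sim \mathcal{N}(\alpha x, \sigma^2 I)$, not a quote of Eq. (12)): the conditional score is

$$
\nabla_x \log p(x \mid \tilde{x}) = \nabla_x \log p(x) + \frac{\alpha}{\sigma^2}\,(\tilde{x} - \alpha x),
$$

which is exactly the drift a MALA step on $p(x \mid \tilde{x})$ would use; moving $\alpha$ and $\sigma$ along the schedule reweights the two terms rather than producing a structurally different update.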

From what I understand (I may well be missing something), the algorithm to sample from $p(x, u) = p(x)\,p(u \mid x)$ is the following: given the current state $X^*$,

  1. Sample an auxiliary variable $U \sim p(u \mid X^*)$
  2. Form the conditional $p(x \mid U) \propto p(x, U)$
  3. Jitter $X^*$ (or not, depending on the MH accept)
  4. Apply MALA a bunch of times for $p(x \mid U)$ with increasing step sizes, which corresponds to a warm-up with a non-adaptive schedule

So, in some sense, what I can see here is that maybe we want to disconnect the step size and the scale in our MALA algo to allow for different balancing in auxiliary schemes, but that's kind of it? Also I'm really not sure the choice of balancing they have is the best one. All in all, I'd support a small refactoring to allow for more flexible parameterization of MALA (or maybe we already have this) and some Gaussian proximal auxiliary utility, then add the DiGS sampler as an example, not as a core library one; a rough sketch of what I mean by building from components is below.
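Something along these lines (the helper name and the exact `blackjax.mala` call pattern are assumptions on my side, so treat this as pseudocode against the current API rather than a design proposal):

```python
import jax
import jax.numpy as jnp
import blackjax


def auxiliary_sweep(rng_key, x, logdensity_fn, alpha=0.9, sigma=0.5,
                    step_sizes=(1e-3, 3e-3, 1e-2)):
    """One sweep of the auxiliary scheme sketched above, built from MALA steps."""
    key_aux, key_mala = jax.random.split(rng_key)

    # Steps 1-2: sample the auxiliary variable and form the conditional target.
    u = alpha * x + sigma * jax.random.normal(key_aux, x.shape)

    def cond_logdensity(z):
        return logdensity_fn(z) - 0.5 * jnp.sum((u - alpha * z) ** 2) / sigma**2

    # Steps 3-4: initialize at the current state (no jitter here, for simplicity)
    # and run a short MALA warm-up on p(x | u) with a fixed increasing schedule.
    keys = jax.random.split(key_mala, len(step_sizes))
    for key, step_size in zip(keys, step_sizes):
        mala = blackjax.mala(cond_logdensity, step_size)
        state = mala.init(x)
        state, _ = mala.step(key, state)
        x = state.position
    return x
```

If something like this goes through without friction using the existing kernels, then the DiGS-specific part really is just the schedule and the choice of $(\alpha, \sigma)$, which fits nicely as an example.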

jcopo commented 1 day ago

OK yes, I get you. It wasn't clear to me which part of the algorithm you were referring to.

Also I'm really not sure the choice of balancing they have is the best one.

I agree, but the idea of dilation/contraction of the space with $\alpha$ is interesting, especially in a multimodal setting.

Is there already something in place for dealing with auxiliary-variable samplers? On top of what you propose, I think this could be a nice addition.

AdrienCorenflos commented 1 day ago

No, I don't think there is yet, but it's a relatively good and easy addition, though we may want to be a bit careful about the design.

So, my suggestion today is probably to implement the algorithm in the sampling book using components from the main library and see if you are blocked by anything. If you are, it means we need to refactor. If you are not, let's see how DiGS picks up before putting it in the core library?

@junpenglao, any comment on that?
