blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Merge dynamic_hmc and hmc #617

Closed · reubenharry closed this 10 months ago

reubenharry commented 10 months ago

Current behavior

Currently, hmc and dynamic_hmc are separate algorithms. dynamic_hmc differs from hmc in that it draws the length of each proposal trajectory from a distribution, whereas hmc uses a fixed length.

Desired behavior

Clearly, hmc is a special case of dynamic_hmc. The code overlap is already reasonably substantial, and it is about to double, because we will also want dynamic and static versions of MH-MCHMC.

It would therefore be nice to write a version of static hmc that simply calls dynamic_hmc with a distribution over lengths that is a delta at a given length.
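
For concreteness, here is a minimal sketch of such a wrapper. It assumes dynamic_hmc exposes a keyword for the function that generates the number of integration steps; the keyword name integration_steps_fn and the wrapper itself are illustrative, not the final API:

import blackjax

# Hypothetical wrapper (a sketch, not the final API): static hmc is just
# dynamic_hmc whose "distribution" over lengths is a delta at a fixed value.
def static_hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps):
    return blackjax.dynamic_hmc(
        logdensity_fn,
        step_size,
        inverse_mass_matrix,
        # Delta distribution: every proposal uses the same trajectory length.
        integration_steps_fn=lambda _: num_integration_steps,
    )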

Additional enhancement

There is also an alternative way to draw the lengths, outlined in this paper: https://arxiv.org/abs/2110.11576. It would be nice to include this as an option in dynamic_hmc, which just amounts to providing a new distribution to draw the lengths from:

import jax.numpy as jnp


def halton(t, max_bits=10):
    """For t = 0., 1., 2., ... return the t-th element of the (base-2) Halton
    sequence: 0.5, 0.25, 0.75, ...

    Taken from: https://github.com/tensorflow/probability/blob/main/discussion/snaper_hmc/SNAPER-HMC.ipynb
    """
    float_index = jnp.asarray(t)
    bit_masks = 2 ** jnp.arange(max_bits, dtype=float_index.dtype)
    # Dot product of the binary digits of (t + 1) with the scaled inverse masks
    # gives the radical inverse of t + 1 in base 2.
    return jnp.einsum("i,i->", jnp.mod((float_index + 1) // bit_masks, 2), 0.5 / bit_masks)


def rescale(mu):
    """Return s such that round(U(0, 1) * s + 0.5) has expected value mu."""
    k = jnp.floor(2 * mu - 1)
    x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu)
    return k + x


def trajectory_length(t, mu):
    """Trajectory length (number of steps) at iteration t, with expected value mu."""
    s = rescale(mu)
    return jnp.rint(0.5 + halton(t) * s)
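
For illustration, a small usage sketch (mu = 10 is an arbitrary choice):

import jax.numpy as jnp

# First few proposal lengths from the Halton scheme, with expected value mu.
mu = 10.0
lengths = jnp.array([trajectory_length(t, mu) for t in jnp.arange(5.0)])
# Each length is a positive integer (stored as a float); by construction of
# rescale, their long-run average is approximately mu.
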
junpenglao commented 10 months ago

+1. I think the easiest is to test whether passing a lambda _: 10 to dynamic_hmc works as intended, and especially whether the speed is the same using the same rng_key under CPU and GPU (it should be, but it is good to check). Then static_hmc would basically just call dynamic_hmc underneath.
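
A sketch of that check on a toy standard-normal target (the integration_steps_fn keyword is assumed from the current dynamic_hmc API; everything else here is illustrative):

import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
step_size, inverse_mass_matrix = 1e-1, jnp.ones(2)

fixed = blackjax.hmc(
    logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps=10
)
dynamic = blackjax.dynamic_hmc(
    logdensity_fn, step_size, inverse_mass_matrix,
    integration_steps_fn=lambda _: 10,  # the lambda _: 10 from above
)
# Time jax.jit(fixed.step) and jax.jit(dynamic.step) with the same rng_key on
# CPU and GPU; matching timings confirm the constant-length path adds no cost.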

As for the Halton sequence, we actually already have an implementation in https://github.com/blackjax-devs/blackjax/blob/540db419f0ccddb0368443049492b6c0448fe273/blackjax/adaptation/chees_adaptation.py#L451. We probably just need to refactor it out into util.py.

junpenglao commented 10 months ago

Actually, taking a look at the implementation in hmc.py, dynamic_hmc calls hmc_base, so the code repetition is not too bad. Given that the dynamic_hmc design requires 2 additional functions (one for advancing the rng_key and one for generating the number of integration steps), I don't think refactoring static_hmc to call dynamic_hmc with a delta function would reduce the code complexity or improve code clarity all that much.

reubenharry commented 10 months ago

Yeah, that's fair. I think it would be a little clearer, because then the difference between static and dynamic would be very apparent, but I don't think it's urgent or even necessary.