Closed reubenharry closed 10 months ago
+1. I think the easiest thing is to test whether passing a lambda _: 10
to dynamic_hmc works as intended, and especially that the speed is the same using the same rng_key on CPU and GPU (it should be, but it is good to check). Then static_hmc would basically just call dynamic_hmc underneath.
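A minimal sketch of the idea, using illustrative names rather than the actual blackjax API: a kernel parameterized by a function that returns the number of integration steps per proposal reduces to static HMC when that function ignores its argument.

```python
# Sketch only: a dynamic-HMC-style kernel parameterized by an
# `integration_steps_fn`. Passing a constant function (`lambda _: 10`)
# recovers static behavior. The names here are hypothetical, not the
# real blackjax signature.

def make_kernel(integration_steps_fn):
    def kernel(step_count_arg):
        num_steps = integration_steps_fn(step_count_arg)
        # ... run the integrator for `num_steps` leapfrog steps,
        # then accept/reject; elided in this sketch ...
        return num_steps
    return kernel

static_kernel = make_kernel(lambda _: 10)  # delta "distribution": always 10
# Every call uses exactly 10 integration steps, regardless of the argument:
assert all(static_kernel(i) == 10 for i in range(100))
```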
As for the Halton sequence, we actually already have an implementation in https://github.com/blackjax-devs/blackjax/blob/540db419f0ccddb0368443049492b6c0448fe273/blackjax/adaptation/chees_adaptation.py#L451. It probably just needs refactoring out into util.py.
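For reference, the core of a Halton sequence is just the radical inverse; a minimal plain-Python version of the base-2 (van der Corput) case looks like the following. This is a sketch for illustration, not the blackjax implementation linked above.

```python
# Minimal base-2 Halton (van der Corput) sequence: the kind of
# low-discrepancy sequence used to jitter trajectory lengths.

def halton(index, base=2):
    """Radical inverse of `index` in the given base; values lie in [0, 1)."""
    result, fraction = 0.0, 1.0
    while index > 0:
        fraction /= base
        index, digit = divmod(index, base)
        result += digit * fraction
    return result

# The first few base-2 terms fill the unit interval evenly:
print([halton(i) for i in range(1, 5)])  # [0.5, 0.25, 0.75, 0.125]
```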
Actually, taking a look at the implementation in hmc.py, dynamic_hmc calls hmc_base, so the code repetition is not too bad. Given the dynamic_hmc design, which requires 2 additional functions (one for advancing the rng_key and one for generating the number of integration steps), I don't think refactoring static_hmc to call dynamic_hmc with a delta function would really reduce code complexity or improve code clarity that much.
Yeah, that's fair. I think it would be a little clearer, because then the difference between static and dynamic would be very apparent, but I don't think it's urgent or even necessary.
Current behavior
Currently, hmc and dynamic_hmc are separate algorithms: dynamic_hmc draws the length of each proposal from a distribution, while hmc uses a fixed length.
Desired behavior
Clearly, hmc is a special case of dynamic_hmc. The code overlap is reasonably substantial, and is about to double, because we will also want dynamic and static versions of MH-MCHMC.
It would therefore be nice if we wrote a version of the static hmc that is simply the dynamic hmc called with a distribution over lengths that is a delta on some given length.
Additional enhancement
There is also an alternative way to draw the lengths, outlined in this paper: https://arxiv.org/abs/2110.11576. It would be nice to include this as an option in dynamic_hmc, which just amounts to providing a new distribution to draw lengths from:
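As an illustration of how little is needed to swap in a different length distribution, here is a sketch where the number of integration steps is drawn uniformly from {1, ..., max_steps}. The function name and the uniform choice are assumptions for the example; whether this matches the scheme in the paper above is not claimed here.

```python
# Illustrative only: supplying a different distribution over trajectory
# lengths just means supplying a different `integration_steps_fn`
# (hypothetical name). Here, a uniform draw over {1, ..., max_steps}.

import random

def uniform_integration_steps_fn(max_steps, seed=0):
    rng = random.Random(seed)
    def integration_steps_fn(_):
        return rng.randint(1, max_steps)  # fresh length for each proposal
    return integration_steps_fn

draw = uniform_integration_steps_fn(10)
lengths = [draw(i) for i in range(1000)]
assert all(1 <= n <= 10 for n in lengths)
```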