blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Convert explicit looping to bit twiddling for iterative_uturn calculations #696

andrewdipper closed this 3 months ago

andrewdipper commented 3 months ago

Removes the explicit loop used to calculate which U-turn checks to perform and replaces it with bitwise calculations.
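For contrast, here is a minimal sketch of what loop-based counting can look like in JAX. It assumes, as the bitwise expression in this PR implies, that the loop's job is to count the trailing one bits of the step index n. The name trailing_ones_loop is hypothetical, and this is not BlackJAX's original code. The number of iterations depends on n, so each call runs a lax.while_loop whose length varies with the data.

```python
import jax
import jax.numpy as jnp


def trailing_ones_loop(n):
    """Count the trailing 1 bits of a non-negative integer with an explicit loop.

    Hypothetical sketch of a loop-based approach, not BlackJAX's actual code.
    Only valid for n >= 0: for negative n the arithmetic shift never clears
    the sign bits, so the loop would not terminate.
    """
    n = jnp.asarray(n, dtype=jnp.int32)

    def cond(state):
        value, _ = state
        # Keep looping while the lowest remaining bit is still 1.
        return (value & 1) == 1

    def body(state):
        value, count = state
        # Shift off the lowest bit and count it.
        return value >> 1, count + 1

    _, count = jax.lax.while_loop(cond, body, (n, jnp.int32(0)))
    return count


print(int(trailing_ones_loop(7)))   # 0b111  -> 3
print(int(trailing_ones_loop(13)))  # 0b1101 -> 1
print(int(trailing_ones_loop(6)))   # 0b110  -> 0
```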

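I don't see much performance difference on a CPU backend, but it gives a ~30% speedup on GPU for some of the models I'm working with. This will clearly depend on the model and the number of steps, though.

How jnp.bitwise_count((~n & (n + 1)) - 1) works: (~n & (n + 1)) isolates the bit that changes from 0 to 1 when adding 1, i.e. the first zero above the trailing run of ones in n. Subtracting 1 clears that bit and sets every bit below it, which are exactly the original trailing ones. jnp.bitwise_count then counts those bits, so the expression returns the number of trailing ones in n. For example, n = 13 (0b1101) gives ~n & (n + 1) = 0b0010, and 0b0010 - 1 = 0b0001, so the count is 1.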
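Below is a minimal JAX sketch of the bitwise expression, checked against a plain-Python reference over the first 64 step indices. The function names are hypothetical and not part of BlackJAX. jnp.bitwise_count needs a reasonably recent JAX release. Because the expression is purely elementwise, it runs on a whole array of indices at once and has no data-dependent loop.

```python
import jax.numpy as jnp


def trailing_ones_bitwise(n):
    """Count the trailing 1 bits of n with no loop (hypothetical helper name)."""
    n = jnp.asarray(n, dtype=jnp.int32)
    lowest_zero = ~n & (n + 1)       # isolate the bit that flips 0 -> 1 when adding 1
    trailing_mask = lowest_zero - 1  # clear that bit and set every bit below it
    return jnp.bitwise_count(trailing_mask)  # number of set bits = trailing ones of n


def trailing_ones_reference(n: int) -> int:
    """Plain-Python reference: length of the trailing run of '1' in bin(n)."""
    bits = bin(n)[2:]
    return len(bits) - len(bits.rstrip("1"))


ns = jnp.arange(64, dtype=jnp.int32)
counts = trailing_ones_bitwise(ns)  # elementwise over all 64 indices at once
assert all(int(c) == trailing_ones_reference(int(k)) for k, c in zip(ns, counts))
print(counts[:8])  # [0 1 0 2 0 1 0 3]
```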