pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Convert explicit looping to bit twiddling for nuts u-turn calculations #1818

Closed andrewdipper closed 2 weeks ago

andrewdipper commented 2 weeks ago

Removes the explicit loop that determines which U-turn checks to perform and replaces it with bitwise calculations.

I didn't see much performance difference with a CPU backend, but it gives a ~20% boost on GPU for some models I'm working with. This will clearly depend on the model and the number of steps, though.

For jnp.bitwise_count((~n & (n + 1)) - 1): (~n & (n + 1)) isolates the bit that changes from 0 to 1 when adding 1, which is the first zero above the trailing run of ones. Subtracting 1 clears that bit and sets all the bits below it, i.e. the original trailing ones, so the bit count is the number of trailing ones in n.
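To make the bit trick above concrete, here is a minimal sketch in plain Python rather than JAX. It is not the code from this PR: the function names are hypothetical, and bin(x).count("1") stands in for jnp.bitwise_count. It compares the bitwise expression against an explicit loop that counts trailing one bits.

```python
def trailing_ones_loop(n: int) -> int:
    """Count trailing 1 bits of a non-negative integer with an explicit loop."""
    count = 0
    while n & 1:
        count += 1
        n >>= 1
    return count


def trailing_ones_bitwise(n: int) -> int:
    """Same count via the expression (~n & (n + 1)) - 1.

    Hypothetical helper for illustration; bin(...).count("1") plays the
    role of a population count such as jnp.bitwise_count.
    """
    lowest_zero = ~n & (n + 1)   # single set bit at the lowest 0 bit of n
    mask = lowest_zero - 1       # ones exactly where n had its trailing ones
    return bin(mask).count("1")  # population count of the mask


def agree_up_to(limit: int) -> bool:
    """Check that both versions agree for all n in [0, limit)."""
    return all(
        trailing_ones_loop(n) == trailing_ones_bitwise(n) for n in range(limit)
    )
```

For example, n = 0b0111 gives lowest_zero = 0b1000 and mask = 0b0111, so the count is 3, while n = 0b0110 gives lowest_zero = 0b0001 and mask = 0, so the count is 0.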

As a side note, it looks like the raveling / unraveling bogs things down a bit too, but the option for block inverse mass matrices makes that harder to resolve.

andrewdipper commented 2 weeks ago

Unless I'm mistaken, I don't think any of the failures are related to the change, and the tests pass locally. But I haven't seen similar failures on other PR runs.

fehiepsi commented 2 weeks ago

@andrewdipper The issues are fixed upstream. Could you sync with the master?

andrewdipper commented 2 weeks ago

Ah, sweet. Thanks!