Replaces the explicit loops that calculate which U-turn checks to perform with bitwise calculations.
I don't see much performance difference with a CPU backend, but it gives a ~30% boost on GPU for some models I'm working with. This will clearly depend on the model and the number of steps, though.

For `jnp.bitwise_count((~n & (n + 1)) - 1)`: `(~n & (n + 1))` isolates the bit that changes from 0 to 1 when adding 1, i.e. the lowest zero bit, which sits just above the trailing run of ones. The `- 1` then clears that bit and sets every bit below it, which are exactly the original trailing ones, so the bit count equals the number of trailing one bits in `n`.
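For reviewers who want to check the identity by hand, here is a minimal sketch in plain Python. It is not the code in this PR: the function names are hypothetical, `bin(...).count("1")` stands in for `jnp.bitwise_count`, and the loop version is only an assumed shape for the kind of explicit loop being replaced.

```python
def trailing_ones_loop(n: int) -> int:
    """Count the trailing 1 bits of n with an explicit loop (assumed baseline)."""
    count = 0
    while n & 1:
        n >>= 1
        count += 1
    return count


def trailing_ones_bitwise(n: int) -> int:
    """Same count without a loop, mirroring the expression discussed above."""
    lowest_zero = ~n & (n + 1)   # isolates the lowest 0 bit of n
    mask = lowest_zero - 1       # clears that bit and sets every bit below it
    return bin(mask).count("1")  # popcount; stand-in for jnp.bitwise_count


# Compare both versions on a few non-negative inputs.
for n in (6, 7, 13):
    print(n, bin(n), trailing_ones_loop(n), trailing_ones_bitwise(n))
```

Both functions should agree for every non-negative `n`. The bitwise form has no data-dependent loop, which is plausibly why it helps more on GPU than on CPU.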