andrewdipper closed this 2 weeks ago.
Unless I'm mistaken, none of the failures are related to the change; the tests also pass locally. But I haven't seen similar failures on other PR runs.
@andrewdipper The issues are fixed upstream. Could you sync with the master branch?
Ah, sweet. Thanks!
Removes the explicit loops used to work out which U-turn checks to perform and replaces them with bitwise calculations.
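For illustration, here is a minimal sketch of the kind of rewrite described. It assumes the loop being replaced counts the trailing one-bits of an integer leaf index `n` with a `lax.while_loop`. The loop version is a hypothetical reconstruction, not the code removed in this PR, and the function names `trailing_ones_loop` and `trailing_ones_bitwise` are made up for this example. The check confirms that both versions agree on a range of inputs:

```python
import jax
import jax.numpy as jnp
from jax import lax


def trailing_ones_loop(n):
    # Hypothetical loop version: shift n right while its lowest bit is 1,
    # counting how many shifts were needed.
    def cond(state):
        m, _ = state
        return (m & 1) != 0

    def body(state):
        m, count = state
        return m >> 1, count + 1

    _, count = lax.while_loop(cond, body, (n, jnp.int32(0)))
    return count


def trailing_ones_bitwise(n):
    # Bitwise version: (~n & (n + 1)) isolates the lowest zero bit of n,
    # subtracting 1 turns it into a mask of the trailing ones, and
    # jnp.bitwise_count counts the set bits in that mask.
    return jnp.bitwise_count((~n & (n + 1)) - 1)


ns = jnp.arange(0, 64, dtype=jnp.int32)
loop_counts = jax.vmap(trailing_ones_loop)(ns)
bitwise_counts = jax.jit(jax.vmap(trailing_ones_bitwise))(ns)
print(bool(jnp.all(loop_counts == bitwise_counts)))  # True
```

The bitwise form has no data-dependent trip count, so it compiles to a handful of elementwise ops. This is one plausible reason it helps more on GPU than on CPU, though that is an assumption rather than something measured here.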
I didn't see much of a performance difference with a CPU backend, but it gives a ~20% speedup on GPU for some of the models I'm working with. This will clearly depend on the model and the number of steps, though.

For `jnp.bitwise_count((~n & (n + 1)) - 1)`:

- `(~n & (n + 1))` isolates the bit that changes from 0 to 1 when adding 1. This is the lowest zero bit of `n`, sitting just above its trailing run of ones.
- `- 1` clears that bit and sets all the bits below it, which are exactly the original trailing ones of `n`. `jnp.bitwise_count` then returns how many trailing ones there are.

As a side note, it looks like the raveling / unraveling bogs things down a bit too, but the option for block inverse mass matrices makes that harder to resolve.
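To make the `jnp.bitwise_count((~n & (n + 1)) - 1)` breakdown concrete, the sketch below prints each intermediate value for a few sample `n`. The helper name `explain_trailing_ones` and the sample inputs are hypothetical, chosen only for this walkthrough:

```python
import jax.numpy as jnp


def explain_trailing_ones(n):
    # Return each intermediate value of jnp.bitwise_count((~n & (n + 1)) - 1).
    n = jnp.int32(n)
    lowest_zero = ~n & (n + 1)  # the bit that flips 0 -> 1 when adding 1
    mask = lowest_zero - 1      # clears that bit and sets every bit below it
    count = jnp.bitwise_count(mask)
    return int(lowest_zero), int(mask), int(count)


for n in (6, 7, 11, 13):
    lowest_zero, mask, count = explain_trailing_ones(n)
    print(f"n={n:04b}  ~n & (n+1)={lowest_zero:04b}  mask={mask:04b}  count={count}")

# n=0110  ~n & (n+1)=0001  mask=0000  count=0
# n=0111  ~n & (n+1)=1000  mask=0111  count=3
# n=1011  ~n & (n+1)=0100  mask=0011  count=2
# n=1101  ~n & (n+1)=0010  mask=0001  count=1
```

In each row the mask is exactly the trailing run of ones in `n`, so the final count is the length of that run.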