probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Smoother code could be simplified by using the jax.lax.scan reverse=True argument #364

Closed: edeno closed this issue 6 days ago

edeno commented 1 week ago

Several smoother algorithms reverse their inputs and then pass them through jax.lax.scan. As I understand it, this creates additional array copies in memory, so passing reverse=True to jax.lax.scan instead would be an improvement. I suspect the current approach dates from before the reverse=True argument existed.
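A minimal sketch of the two patterns being compared, using a toy backward recursion (a running sum from the end of the sequence) in place of a real smoother step. The step function and data are made up for illustration. Note that the manual pattern has to flip the outputs back as well as the inputs, whereas reverse=True iterates from the last element but stacks outputs at their original indices. Whether the explicit flips actually materialize extra copies depends on how XLA compiles the program; this sketch only demonstrates that the two patterns give the same result.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Toy backward recursion: accumulate a running sum from the end.
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.arange(1.0, 6.0)  # [1., 2., 3., 4., 5.]
init = jnp.zeros((), dtype=xs.dtype)

# Pattern 1: flip the inputs, scan forward, then flip the outputs back.
_, ys_flipped = jax.lax.scan(step, init, xs[::-1])
ys_manual = ys_flipped[::-1]

# Pattern 2: let scan iterate from the last element; outputs keep original order.
_, ys_reverse = jax.lax.scan(step, init, xs, reverse=True)

print(ys_manual)   # [15. 14. 12.  9.  5.]
print(ys_reverse)  # [15. 14. 12.  9.  5.]
```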

Happy to add a PR if this is indeed an issue.

slinderman commented 1 week ago

Good catch, @edeno! I just checked and it looks like we only used the reverse kwarg in the parallel inference code. If you want to submit a PR to use reverse in the other code paths, that would be amazing!
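As an illustration of what a reverse=True backward pass in a smoother can look like, here is a sketch of a scalar Kalman filter followed by a Rauch-Tung-Striebel (RTS) smoother. It assumes a 1-D linear Gaussian model x_t = a x_{t-1} + N(0, q), y_t = x_t + N(0, r) with prior N(m0, P0) on x_0. The functions kalman_filter and rts_smoother are hypothetical helpers written for this example, not dynamax's API or its actual implementation. The backward scan runs over the filtered moments for steps 1..T-1 with reverse=True, starting from the final filtered moments, so no input or output flipping is needed.

```python
import jax
import jax.numpy as jnp


def kalman_filter(ys, a, q, r, m0, P0):
    """Forward pass: filtered means/variances for a scalar linear Gaussian model."""
    def step(carry, y):
        m, P = carry
        # Predict x_t from x_{t-1}.
        m_pred = a * m
        P_pred = a * a * P + q
        # Update with observation y_t.
        K = P_pred / (P_pred + r)
        m_new = m_pred + K * (y - m_pred)
        P_new = (1.0 - K) * P_pred
        return (m_new, P_new), (m_new, P_new)

    init = (jnp.asarray(m0, dtype=ys.dtype), jnp.asarray(P0, dtype=ys.dtype))
    _, (ms, Ps) = jax.lax.scan(step, init, ys)
    return ms, Ps


def rts_smoother(ms, Ps, a, q):
    """Backward pass using scan(reverse=True) instead of flipping arrays."""
    def step(carry, filtered):
        m_next_s, P_next_s = carry  # smoothed moments at t + 1
        m_f, P_f = filtered         # filtered moments at t
        P_pred = a * a * P_f + q    # predicted variance for t + 1
        G = a * P_f / P_pred        # smoother gain
        m_s = m_f + G * (m_next_s - a * m_f)
        P_s = P_f + G * G * (P_next_s - P_pred)
        return (m_s, P_s), (m_s, P_s)

    # The last smoothed estimate equals the last filtered estimate.
    init = (ms[-1], Ps[-1])
    _, (ms_s, Ps_s) = jax.lax.scan(step, init, (ms[:-1], Ps[:-1]), reverse=True)
    # Outputs come back in original time order; append the final step.
    return jnp.concatenate([ms_s, ms[-1:]]), jnp.concatenate([Ps_s, Ps[-1:]])


ys = jnp.array([0.5, 1.2, 0.8, 1.9, 1.4])
ms, Ps = kalman_filter(ys, a=0.9, q=0.1, r=0.5, m0=0.0, P0=1.0)
ms_s, Ps_s = rts_smoother(ms, Ps, a=0.9, q=0.1)
```

The same idea carries over to the matrix-valued case: only the step function changes, while the reverse=True scan structure stays the same.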