probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Use reverse=True keyword argument in lax.scan for smoothers #365

Closed: edeno closed this pull request 6 days ago.

edeno commented 1 week ago

Fixes issue https://github.com/probml/dynamax/issues/364 by using the reverse=True keyword argument of the lax.scan function.
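For anyone unfamiliar with the keyword, here is a minimal sketch of what reverse=True does, using a toy running-sum step function (the step function and inputs are illustrative assumptions, not code from dynamax). Scanning with reverse=True gives the same outputs as flipping the inputs, scanning forward, and flipping the outputs back, but without the explicit flips.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Toy step function: add x to the running total and emit the new total.
    total = carry + x
    return total, total


xs = jnp.arange(1.0, 6.0)  # [1., 2., 3., 4., 5.]
init = jnp.zeros(())       # float32 scalar carry

# reverse=True visits xs from the last element to the first, but writes each
# output back at the index of the input it came from.
_, ys_reverse = lax.scan(step, init, xs, reverse=True)

# Equivalent formulation with explicit flips of the inputs and outputs.
_, ys_flipped = lax.scan(step, init, xs[::-1])
ys_flipped = ys_flipped[::-1]

print(ys_reverse)                                  # suffix sums: 15, 14, 12, 9, 5
print(bool(jnp.allclose(ys_reverse, ys_flipped)))  # True
```

Because the outputs stay aligned with the original time index, a backward pass such as a smoother can consume filtered_probs[t] and emit smoothed_probs[t] at the same position.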

Tests pass locally except for the unscented Kalman filter inference tests, but as far as I can tell these were also failing with the original code.

I also had to pin numpy < 2.0 because tensorflow_probability was failing (note that this is also a problem in the docs).

There may be room for further improvement by eliminating the unnecessary memory copies created by array slicing and stacking, but I think doing so would hurt the readability of the code and add some computational overhead. For example, the stacking operation (jnp.vstack([smoothed_probs, filtered_probs[-1]])) and the slice filtered_probs[:-1] both create copies. This isn't really a problem unless you have a large number of states:

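    from jax import lax
    import jax.numpy as jnp

    # Run the HMM smoother backward in time. With reverse=True, lax.scan
    # iterates from the last element of xs to the first, but still returns
    # the stacked outputs in the original (forward) time order.
    _, smoothed_probs = lax.scan(
        _step,
        filtered_probs[-1],  # the carry starts at the final filtered distribution
        (jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]),
        reverse=True,
    )

    # Append the final timestep (its smoothed distribution equals the
    # filtered one) and return
    smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])
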
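
For readers who want to run this pattern end to end, here is a minimal self-contained sketch; it is not the dynamax implementation. The hmm_filter and hmm_smoother functions, their _step bodies, the toy numbers, and the convention transition_matrix[i, j] = p(z_{t+1}=j | z_t=i) are all assumptions made for illustration. The backward update is the standard discrete-state HMM smoothing recursion, wrapped in the same reverse=True scan and vstack as the snippet above.

```python
import jax.numpy as jnp
from jax import lax


def hmm_filter(initial_probs, transition_matrix, likelihoods):
    """Forward pass. Assumes transition_matrix[i, j] = p(z_{t+1}=j | z_t=i)
    and likelihoods[t, k] = p(y_t | z_t=k)."""

    def _step(predicted, lik):
        # Condition the one-step prediction on the current observation.
        filtered = predicted * lik
        filtered = filtered / filtered.sum()
        # Propagate through the dynamics to predict the next state.
        return filtered @ transition_matrix, (filtered, predicted)

    _, (filtered_probs, predicted_probs) = lax.scan(_step, initial_probs, likelihoods)
    return filtered_probs, predicted_probs


def hmm_smoother(transition_matrix, filtered_probs, predicted_probs):
    """Backward pass with lax.scan(..., reverse=True), mirroring the snippet above."""
    num_timesteps = filtered_probs.shape[0]

    def _step(smoothed_next, args):
        _t, filtered, predicted_next = args  # _t is unused; kept to mirror the xs tuple
        # p(z_t | all y) is proportional to
        #   p(z_t | y_{0:t}) * sum_j A[z_t, j] * p(z_{t+1}=j | all y) / p(z_{t+1}=j | y_{0:t})
        relative = smoothed_next / predicted_next
        smoothed = filtered * (transition_matrix @ relative)
        smoothed = smoothed / smoothed.sum()
        return smoothed, smoothed

    _, smoothed_probs = lax.scan(
        _step,
        filtered_probs[-1],
        (jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]),
        reverse=True,
    )
    # The last smoothed distribution equals the last filtered one.
    return jnp.vstack([smoothed_probs, filtered_probs[-1]])


# Toy two-state example with made-up numbers.
transition_matrix = jnp.array([[0.9, 0.1], [0.2, 0.8]])
initial_probs = jnp.array([0.5, 0.5])
likelihoods = jnp.array([[0.7, 0.2], [0.6, 0.3], [0.1, 0.8], [0.2, 0.9]])

filtered_probs, predicted_probs = hmm_filter(initial_probs, transition_matrix, likelihoods)
smoothed_probs = hmm_smoother(transition_matrix, filtered_probs, predicted_probs)
print(smoothed_probs.shape)  # (4, 2)
```

Because reverse=True writes each output at its original index, smoothed_probs[t] lines up with filtered_probs[t] without any flipping, and the vstack only needs to append the final timestep.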
slinderman commented 6 days ago

We can track the test failure here: https://github.com/probml/dynamax/issues/367