probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

LGSSM with diagonal noise #332

Closed by calebweinreb 1 year ago

calebweinreb commented 1 year ago

Changes

This PR addresses issue https://github.com/probml/dynamax/issues/331 by allowing the emissions covariance to be stored as a static or time-varying 1D array containing just the diagonal entries and optimizing certain computations when this is the case. Changes include:

Benchmark

The optimizations seem to speed things up once the emission dimension exceeds roughly 50.

Benchmarking code: https://gist.github.com/calebweinreb/35bcb53f25d83d0f002b731ccca3b91a
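
For context on why a diagonal emissions covariance can pay off at larger emission dimensions, here is a minimal JAX sketch. It is not the code from this PR, and the function names kalman_gain_dense and kalman_gain_diag are hypothetical. It assumes a latent dimension N, an emission dimension D, an invertible predicted state covariance P, and a 1D array R_diag holding the diagonal of the emissions covariance R. By the matrix inversion lemma, the Kalman gain can then be computed with an N x N solve instead of a D x D one.

```python
import jax.numpy as jnp
from jax import random


def kalman_gain_dense(P, H, R_diag):
    """Kalman gain K = P H^T S^{-1} via a dense (D x D) solve.

    P: (N, N) predicted state covariance, H: (D, N) emission matrix,
    R_diag: (D,) diagonal of the emissions covariance.
    """
    S = H @ P @ H.T + jnp.diag(R_diag)  # innovation covariance, (D, D)
    # S and P are symmetric, so K^T = S^{-1} H P.
    return jnp.linalg.solve(S, H @ P).T


def kalman_gain_diag(P, H, R_diag):
    """Same gain, written as K = (P^{-1} + H^T R^{-1} H)^{-1} H^T R^{-1}.

    Only (N x N) systems are solved; R^{-1} is an elementwise division
    because R is diagonal. Assumes P is invertible.
    """
    HtRinv = H.T / R_diag                # H^T R^{-1}, shape (N, D)
    A = jnp.linalg.inv(P) + HtRinv @ H   # (N, N)
    return jnp.linalg.solve(A, HtRinv)


if __name__ == "__main__":
    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
    N, D = 3, 8
    A = random.normal(k1, (N, N))
    P = 0.1 * A @ A.T + jnp.eye(N)       # symmetric positive definite
    H = random.normal(k2, (D, N))
    R_diag = random.uniform(k3, (D,), minval=0.5, maxval=1.5)
    diff = jnp.abs(kalman_gain_dense(P, H, R_diag) - kalman_gain_diag(P, H, R_diag))
    print("max abs difference:", diff.max())
```

The two functions should agree up to floating-point error. The second one avoids forming or factorizing any D x D matrix, which is where savings would be expected when D is much larger than N.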

slinderman commented 1 year ago

Thanks @calebweinreb! This is a really nice feature. I requested a few minor changes above.

calebweinreb commented 1 year ago

The requested changes are implemented here https://github.com/probml/dynamax/commit/9c202d853f247125cfc4196376b4fe70517128cb

I'm a little confused about the failing tests, since it seems to be a dependency issue... In any case, pytest dynamax passes when I run it locally.

slinderman commented 1 year ago

I think the tests were failing because they were still using Python 3.7. I just updated our tests to use 3.10.6 (current Colab version) and they seem to run. Can you merge the latest changes from main and push again? That should trigger another run of the tests.

calebweinreb commented 1 year ago

Hey! Looks like all the tests pass now.

bantin commented 9 months ago

Thanks for all your work on this @calebweinreb. Sorry to revive an old thread -- is there an easy way of fitting an LDS with diagonal emissions covariance? As best I can tell, the EM code in models.py assumes a full covariance.
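
As a rough illustration of what a diagonal-covariance M-step for the emissions could look like, here is a generic JAX sketch. It is not dynamax's EM code, and m_step_emissions_diag and its arguments are hypothetical names. It assumes emissions y_t = H x_t + noise with no bias or input terms, and that smoothed posterior means Ex and covariances Vx are already available from the E-step. Under those assumptions the update for H matches the full-covariance case, and the diagonal estimate is the diagonal of the full-covariance estimate.

```python
import jax.numpy as jnp


def m_step_emissions_diag(y, Ex, Vx):
    """M-step for emission weights H and a diagonal emissions covariance.

    y:  (T, D) observed emissions
    Ex: (T, N) smoothed posterior means E[x_t]
    Vx: (T, N, N) smoothed posterior covariances Cov[x_t]
    Returns H with shape (D, N) and R_diag with shape (D,).
    """
    T = y.shape[0]
    # Expected sufficient statistics under the smoothed posterior.
    Exx = jnp.einsum("ti,tj->ij", Ex, Ex) + Vx.sum(axis=0)  # sum_t E[x x^T], (N, N)
    Eyx = jnp.einsum("ti,tj->ij", y, Ex)                    # sum_t y E[x]^T, (D, N)
    # H = Eyx Exx^{-1}; Exx is symmetric, so solve and transpose.
    H = jnp.linalg.solve(Exx, Eyx.T).T
    # Diagonal of (1/T) * (sum_t y y^T - H Eyx^T), without forming a (D, D) matrix.
    R_diag = (jnp.sum(y ** 2, axis=0) - jnp.sum(H * Eyx, axis=1)) / T
    return H, R_diag
```

Only the D diagonal entries are computed and stored, so the result could be passed around as a 1D array rather than a full D x D matrix.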