probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Refactor of LGSSM parallel inference #324

Closed. calebweinreb closed this 1 year ago.

calebweinreb commented 1 year ago

This PR slightly refactors the LGSSM parallel inference code and adds a few new features. The changes include:

The code passes the current tests plus a new sampling test (courtesy of @ezhang94):

from dynamax.linear_gaussian_ssm.parallel_inference_test import (
    TestTimeVaryingParallelLGSSMSmoother,
    TestTimeVaryingParallelLGSSMSampler,
    TestParallelLGSSMSmoother)

# Instantiate each test class once and call its test methods on that instance,
# instead of passing the class itself as `self`.
# Alternatively, run every test in the module with:
#   pytest dynamax/linear_gaussian_ssm/parallel_inference_test.py
sampler_tests = TestTimeVaryingParallelLGSSMSampler()
sampler_tests.test_sampled_covariances()
sampler_tests.test_sampled_means()

tv_smoother_tests = TestTimeVaryingParallelLGSSMSmoother()
tv_smoother_tests.test_smoothed_means()
tv_smoother_tests.test_smoothed_covariances()
tv_smoother_tests.test_filtered_means()
tv_smoother_tests.test_filtered_covariances()
tv_smoother_tests.test_marginal_loglik()

smoother_tests = TestParallelLGSSMSmoother()
smoother_tests.test_smoothed_means()
smoother_tests.test_smoothed_covariances()
smoother_tests.test_filtered_means()
smoother_tests.test_filtered_covariances()
smoother_tests.test_marginal_loglik()
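Tests named like `test_sampled_means` / `test_sampled_covariances` presumably check Monte Carlo moments: draw many samples and compare their empirical mean and covariance to a reference Gaussian. The sketch below shows only that idea. The helper `check_sample_moments`, the sample count, and the tolerance are hypothetical and are not the code in `parallel_inference_test.py`; here the samples come straight from the target Gaussian, so the check should pass.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def check_sample_moments(samples, mean, cov, atol):
    """Hypothetical helper: compare the empirical moments of `samples`
    (shape [num_samples, dim]) to a target mean and covariance."""
    emp_mean = samples.mean(axis=0)
    # rowvar=False: each row is one sample, each column one dimension
    emp_cov = jnp.cov(samples, rowvar=False)
    return (bool(jnp.allclose(emp_mean, mean, atol=atol))
            and bool(jnp.allclose(emp_cov, cov, atol=atol)))


key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])
# Draw 100k samples from the target itself, so the moments should match
samples = jax.random.multivariate_normal(key, mean, cov, shape=(100_000,))
print(check_sample_moments(samples, mean, cov, atol=0.05))
```

In a real sampler test, `samples` would come from the sampler under test and `mean`/`cov` from the corresponding smoother output; the tolerance has to be loose enough for Monte Carlo error at the chosen sample size.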

Currently the bias terms in these tests are 0. That should probably change, because with zero biases the tests cannot confirm that the biases are being handled correctly. I could just introduce a random bias, but I was also wondering whether it would make more sense to import the test models that are generated in inference_test.py. More generally, it seems awkward that dynamax currently has three different ways to generate simulated LGSSMs for testing: one for serial LGSSMs, one for parallel LGSSMs, and one for EKF testing.
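To illustrate why zero biases make a weak test, the sketch below compares a correct Kalman filter against one that silently drops the biases, first with zero biases and then with random ones. This is plain JAX, not dynamax's API: `random_params`, `kalman_filter`, the `use_biases` flag, and the parameter names (`b` for the dynamics bias, `d` for the emission bias) are all hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def random_params(key, state_dim=2, emission_dim=3, zero_biases=False):
    """Hypothetical helper: a small, stable LGSSM with random H and biases."""
    k_H, k_b, k_d = jax.random.split(key, 3)
    b = jax.random.normal(k_b, (state_dim,))
    d = jax.random.normal(k_d, (emission_dim,))
    if zero_biases:
        b, d = jnp.zeros_like(b), jnp.zeros_like(d)
    return dict(m0=jnp.zeros(state_dim), S0=jnp.eye(state_dim),
                F=0.9 * jnp.eye(state_dim), b=b, Q=0.1 * jnp.eye(state_dim),
                H=jax.random.normal(k_H, (emission_dim, state_dim)), d=d,
                R=0.5 * jnp.eye(emission_dim))


def kalman_filter(p, ys, use_biases=True):
    """Kalman filter for z_t = F z_{t-1} + b + w_t, y_t = H z_t + d + v_t.
    use_biases=False mimics a bug that silently drops both biases."""
    b = p["b"] if use_biases else jnp.zeros_like(p["b"])
    d = p["d"] if use_biases else jnp.zeros_like(p["d"])
    F, Q, H, R = p["F"], p["Q"], p["H"], p["R"]

    def step(carry, y):
        m_pred, P_pred = carry
        S = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(S, H @ P_pred).T  # Kalman gain P H^T S^-1
        m = m_pred + K @ (y - H @ m_pred - d)  # emission bias enters here
        P = P_pred - K @ S @ K.T
        return (F @ m + b, F @ P @ F.T + Q), (m, P)  # dynamics bias enters here

    _, (means, covs) = jax.lax.scan(step, (p["m0"], p["S0"]), ys)
    return means, covs


key_params, key_obs = jax.random.split(jax.random.PRNGKey(0))
ys = jax.random.normal(key_obs, (50, 3))
for zero_biases in (True, False):
    p = random_params(key_params, zero_biases=zero_biases)
    agree = jnp.allclose(kalman_filter(p, ys)[0],
                         kalman_filter(p, ys, use_biases=False)[0])
    print(f"zero_biases={zero_biases}: buggy filter agrees = {bool(agree)}")
```

With zero biases the buggy filter produces exactly the same means as the correct one, so the bug goes unnoticed; with random biases the means diverge. The filtered covariances never depend on the biases, so covariance checks alone would miss this kind of bug regardless of the bias values.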

slinderman commented 1 year ago

Thanks @calebweinreb! I agree that unifying the test model generation would be a useful contribution! Let's do that in a separate issue/PR though.

slinderman commented 1 year ago

Also, in case it seems like some of these changes came out of thin air, Caleb and I discussed adding the *Message classes offline. I think these nicely complement the parallel HMM inference code now.
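For context on what a message-based formulation looks like, here is a minimal sketch of parallel Kalman filtering with `jax.lax.associative_scan`, following the element-and-operator formulation of Särkkä and García-Fernández (2021), "Temporal Parallelization of Bayesian Smoothers". `FilterMessage` and its field names are hypothetical stand-ins; dynamax's actual *Message classes may be organized differently. Dynamics and emission biases are included, and a sequential filter serves only as a reference.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class FilterMessage(NamedTuple):
    """Hypothetical scan element (A, b, C, J, eta); field names follow
    Sarkka & Garcia-Fernandez (2021), not necessarily dynamax."""
    A: jnp.ndarray
    b: jnp.ndarray
    C: jnp.ndarray
    J: jnp.ndarray
    eta: jnp.ndarray


def initial_messages(m0, S0, F, bias, Q, H, d, R, ys):
    """One message per step for z_1 ~ N(m0, S0),
    z_t = F z_{t-1} + bias + w_t (cov Q), y_t = H z_t + d + v_t (cov R)."""
    D = m0.shape[0]
    # t = 1: condition the prior on y_1; there is no z_0, so A, J, eta are zero.
    S1 = H @ S0 @ H.T + R
    K1 = jnp.linalg.solve(S1, H @ S0).T
    first = FilterMessage(jnp.zeros((D, D)), m0 + K1 @ (ys[0] - H @ m0 - d),
                          S0 - K1 @ S1 @ K1.T, jnp.zeros((D, D)), jnp.zeros(D))
    # t > 1: p(z_t | z_{t-1}, y_t) = N(A z_{t-1} + b, C), plus the information
    # form (J, eta) of the likelihood p(y_t | z_{t-1}).
    S = H @ Q @ H.T + R
    K = jnp.linalg.solve(S, H @ Q).T
    HF = H @ F
    IKH = jnp.eye(D) - K @ H

    def generic(y):
        resid = y - H @ bias - d  # both biases enter through this residual
        return FilterMessage(IKH @ F, bias + K @ resid, IKH @ Q,
                             HF.T @ jnp.linalg.solve(S, HF),
                             HF.T @ jnp.linalg.solve(S, resid))

    rest = jax.vmap(generic)(ys[1:])
    # Prepend the first message along the time axis.
    return jax.tree_util.tree_map(
        lambda x0, xs: jnp.concatenate([x0[None], xs]), first, rest)


@jax.vmap  # associative_scan calls the operator on batches of elements
def combine(m1, m2):
    """Associative operator; m1 is the earlier message, m2 the later one."""
    I = jnp.eye(m1.A.shape[0])
    T1 = jnp.linalg.solve((I + m1.C @ m2.J).T, m2.A.T).T  # A2 (I + C1 J2)^-1
    T2 = jnp.linalg.solve((I + m2.J @ m1.C).T, m1.A).T    # A1^T (I + J2 C1)^-1
    return FilterMessage(
        A=T1 @ m1.A,
        b=T1 @ (m1.b + m1.C @ m2.eta) + m2.b,
        C=T1 @ m1.C @ m2.A.T + m2.C,
        J=T2 @ m2.J @ m1.A + m1.J,
        eta=T2 @ (m2.eta - m2.J @ m1.b) + m1.eta)


def parallel_filter(m0, S0, F, bias, Q, H, d, R, ys):
    msgs = initial_messages(m0, S0, F, bias, Q, H, d, R, ys)
    out = jax.lax.associative_scan(combine, msgs)
    return out.b, out.C  # filtered means and covariances


def sequential_filter(m0, S0, F, bias, Q, H, d, R, ys):
    """Standard Kalman filter, used only as a reference."""
    def step(carry, y):
        m_pred, P_pred = carry
        Sy = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(Sy, H @ P_pred).T
        m = m_pred + K @ (y - H @ m_pred - d)
        P = P_pred - K @ Sy @ K.T
        return (F @ m + bias, F @ P @ F.T + Q), (m, P)
    _, (ms, Ps) = jax.lax.scan(step, (m0, S0), ys)
    return ms, Ps


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
D, E, T = 2, 3, 100
params = dict(m0=jnp.zeros(D), S0=jnp.eye(D), F=0.9 * jnp.eye(D),
              bias=jax.random.normal(k1, (D,)), Q=0.1 * jnp.eye(D),
              H=jax.random.normal(k2, (E, D)), d=jax.random.normal(k3, (E,)),
              R=0.5 * jnp.eye(E))
ys = jax.random.normal(k4, (T, E))
par_means, par_covs = parallel_filter(ys=ys, **params)
seq_means, seq_covs = sequential_filter(ys=ys, **params)
print(jnp.allclose(par_means, seq_means, atol=1e-6),
      jnp.allclose(par_covs, seq_covs, atol=1e-6))
```

After the scan, each prefix-combined message has A = 0, and its b and C are the filtered mean and covariance at that step. The associative scan has O(log T) depth on parallel hardware, at the cost of more total work than the O(T) sequential scan.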