lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

Factorial HMM #9

Closed: ahwillia closed this 2 years ago

ahwillia commented 2 years ago

Here is a prototype for a factorial HMM. @slinderman, I will probably need your help specifying the posterior distribution in a way that TensorFlow Probability will accept.
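For context, a factorial HMM runs several a priori independent Markov chains in parallel, and it is equivalent to one large HMM whose state is the tuple of all chain states. The minimal sketch below (not the code from this PR; the function name joint_transition_matrix is hypothetical) builds that flat transition matrix with a Kronecker product. It also shows why the flat form scales poorly: its size is the product of the per-chain state counts.

```python
from functools import reduce

import jax.numpy as jnp


def joint_transition_matrix(transition_matrices):
    """Hypothetical helper: combine per-chain transition matrices into the
    transition matrix of the equivalent flat HMM.

    The joint state (z_1, ..., z_K) is flattened in row-major order, which is
    the index ordering that jnp.kron produces.
    """
    return reduce(jnp.kron, transition_matrices)


A1 = jnp.array([[0.9, 0.1],
                [0.2, 0.8]])
A2 = jnp.array([[0.7, 0.2, 0.1],
                [0.1, 0.8, 0.1],
                [0.3, 0.3, 0.4]])
A = joint_transition_matrix((A1, A2))
print(A.shape)  # (6, 6)
```

Entry A[i1 * 3 + i2, j1 * 3 + j2] equals A1[i1, j1] * A2[i2, j2], and each row still sums to one.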

Note: I think the M-step for the factorial transitions would be greatly simplified if we could pass expected_transitions to the M-step function instead of (dataset, posteriors).
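A minimal sketch of what such an M-step could look like, assuming expected_transitions is a tuple with one (S_k, S_k) matrix of expected transition counts per chain, already summed over time. The maximum-likelihood update for each chain's transition matrix is then a row normalization of its counts. The function names and the pseudo_count smoothing argument are hypothetical, not part of ssm-jax.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def transition_m_step(expected_counts, pseudo_count=0.0):
    """Hypothetical single-chain M-step for the transition matrix.

    expected_counts: (S, S) array of expected i -> j transition counts,
    already summed over time. pseudo_count is an illustrative smoothing knob;
    with pseudo_count=0, a state with no expected outgoing transitions
    yields a row of NaNs.
    """
    counts = expected_counts + pseudo_count
    return counts / counts.sum(axis=1, keepdims=True)


def factorial_transition_m_step(expected_transitions, pseudo_count=0.0):
    # Independent update per chain; works for a tuple of any length.
    return tree_map(lambda c: transition_m_step(c, pseudo_count), expected_transitions)


counts = (
    jnp.array([[8.0, 2.0], [1.0, 9.0]]),
    jnp.array([[3.0, 1.0, 0.0], [0.0, 4.0, 4.0], [2.0, 2.0, 4.0]]),
)
new_transition_matrices = factorial_transition_m_step(counts)
# new_transition_matrices[0] -> [[0.8, 0.2], [0.1, 0.9]]
```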

slinderman commented 2 years ago

This looks great so far! Happy to pair program tomorrow morning if you're available.

It could be tough to get TFP to play nicely, since expected_transitions here is a tuple of unknown length. TFP Distributions want to know the number of parameters in advance so they can slice and broadcast appropriately. That said, this approach is really clever, so I'm partial toward making it work even if it means FactorialHMMPosterior can't be a Distribution.
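One way to sidestep the TFP constraint, sketched under assumptions: keep the posterior as a plain JAX pytree instead of a TFP Distribution. The class FactorialPosteriorSketch, its fields, and the expected_num_transitions function are hypothetical; the point is that a registered pytree can hold a tuple with any number of chains and still pass through jax.jit.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class FactorialPosteriorSketch:
    """Hypothetical container for factorial HMM posterior statistics.

    Not a TFP Distribution: it simply stores one entry per chain, so the
    number of chains does not have to be declared up front.
    """

    def __init__(self, marginal_loglik, expected_states, expected_transitions):
        self.marginal_loglik = marginal_loglik            # scalar
        self.expected_states = expected_states            # tuple of (T, S_k) arrays
        self.expected_transitions = expected_transitions  # tuple of (T-1, S_k, S_k) arrays

    def tree_flatten(self):
        # All fields are array data; there is no static auxiliary data.
        children = (self.marginal_loglik, self.expected_states, self.expected_transitions)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def expected_num_transitions(posterior):
    # Total expected number of transitions for each chain.
    return tuple(et.sum() for et in posterior.expected_transitions)


T = 5
post = FactorialPosteriorSketch(
    marginal_loglik=jnp.array(0.0),
    expected_states=(jnp.full((T, 2), 1 / 2), jnp.full((T, 3), 1 / 3)),
    expected_transitions=(jnp.full((T - 1, 2, 2), 1 / 4), jnp.full((T - 1, 3, 3), 1 / 9)),
)
totals = expected_num_transitions(post)  # approximately (4.0, 4.0)
```

Changing the number of chains changes the pytree structure, so jit would retrace and recompile for each distinct number of chains.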

I think you can sum the expected transitions more easily by doing tree_map(partial(np.sum, axis=0), expected_transitions)
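To make the suggestion concrete, here is a sketch assuming np refers to jax.numpy (written as jnp below) and that each chain's expected transitions have shape (T - 1, S_k, S_k). tree_map applies the sum to every array in the tuple, whatever its length, collapsing the time axis to give one (S_k, S_k) count matrix per chain. The toy arrays are made up for illustration.

```python
from functools import partial

import jax.numpy as jnp
from jax.tree_util import tree_map

T = 100
# Toy per-chain expected transitions, shape (T - 1, S_k, S_k); values are made up.
expected_transitions = (
    jnp.full((T - 1, 2, 2), 0.25),
    jnp.full((T - 1, 3, 3), 1.0 / 9.0),
)

# Sum over the time axis of every array in the tuple, whatever its length.
summed = tree_map(partial(jnp.sum, axis=0), expected_transitions)
print([s.shape for s in summed])  # [(2, 2), (3, 3)]
```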

github-actions[bot] commented 2 years ago

Unit Test Results

1 file, 1 suite, 2m 10s; 26 tests: 26 passed, 0 skipped, 0 failed

Results for commit 1bae4404.
