@mo-osman wrote a parallel HMM sampler that we should incorporate into jax_moseq. Assigning this to @kaijfox once the lgssm sampler is done. The todos here are:
1) Confirm that the algorithm is correct and that it affords a speed-up on GPU
2) Create an issue/PR for including it in dynamax
3) Have jax-moseq call the parallel version (toggled by a flag, similar to the parallel lgssm sampler)
4) Confirm that it works and gives sensible outputs in keypoint-moseq on our tutorial dataset.