dattalab / jax-moseq

Restructured modeling code #1

Closed: mo-osman closed this pull request 1 year ago

mo-osman commented 1 year ago

Before merging, please review this pull request carefully to make sure the code is sensibly structured and mathematically sound, and feel free to edit it as needed.

This update refactors the modeling code into separate ARHMM, SLDS, and keypoint SLDS modules. Each module has separate files for initialization, resampling, and log-likelihood computation, and each reuses the code of the model(s) it builds on (the SLDS builds on the ARHMM, and the keypoint SLDS builds on both).
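To illustrate what "each relies on the code for the model(s) it builds on" can mean for log-likelihood computation, here is a minimal sketch using a toy scalar model. It is not the jax-moseq code. All function names are hypothetical, the dynamics are assumed to be scalar AR(1), the emissions are assumed to be scalar linear-Gaussian, and everything is conditioned on the discrete states z, so the Markov-chain and initial-state terms are omitted. The point it shows is that the SLDS log-likelihood adds an emission term to the ARHMM term instead of re-implementing the dynamics.

```python
import math


def gaussian_logpdf(x, mean, var):
    """Log density of a univariate normal N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


# --- hypothetical "arhmm" layer: scalar AR(1) dynamics on a latent x ---
def arhmm_log_likelihood(x, z, ar_coefs, ar_noise_vars):
    """Sum of log p(x_t | x_{t-1}, z_t) for t >= 1, given discrete states z."""
    total = 0.0
    for t in range(1, len(x)):
        k = z[t]  # active discrete state selects the AR parameters
        total += gaussian_logpdf(x[t], ar_coefs[k] * x[t - 1], ar_noise_vars[k])
    return total


# --- hypothetical "slds" layer: reuses the ARHMM term, adds emissions ---
def slds_obs_log_likelihood(y, x, emission_coef, obs_noise_var):
    """Sum of log p(y_t | x_t) for a scalar linear-Gaussian emission."""
    return sum(
        gaussian_logpdf(y_t, emission_coef * x_t, obs_noise_var)
        for y_t, x_t in zip(y, x)
    )


def slds_log_likelihood(y, x, z, ar_coefs, ar_noise_vars, emission_coef, obs_noise_var):
    """SLDS log-likelihood given z: ARHMM dynamics term plus emission term."""
    return arhmm_log_likelihood(x, z, ar_coefs, ar_noise_vars) + slds_obs_log_likelihood(
        y, x, emission_coef, obs_noise_var
    )
```

A keypoint SLDS layer would, under the same assumptions, call `slds_log_likelihood` (or its pieces) and add whatever terms are specific to keypoint observations.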

The following small changes to model initialization are breaking with respect to the keypoint-moseq code:

(1) The initialization method now takes a data dictionary containing Y, mask, and (optionally) conf, rather than a separate keyword argument for each of these values.
(2) The latent_dimension keyword argument has been replaced by the latent_dim field of ar_hypparams.

A small complementary commit will be made to keypoint-moseq to reflect these changes.
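For code that still uses the old calling convention, the two breaking changes can be bridged with a small conversion step. The sketch below is a hypothetical helper (`to_new_init_args` is not part of jax-moseq or keypoint-moseq). It only repackages arguments: it bundles Y, mask, and the optional conf into a `data` dictionary and moves `latent_dimension` into the `latent_dim` field of a copy of `ar_hypparams`. The values are treated as opaque, so no array shapes are assumed.

```python
from typing import Any, Dict, Optional, Tuple


def to_new_init_args(
    Y: Any,
    mask: Any,
    conf: Optional[Any] = None,
    *,
    latent_dimension: int,
    ar_hypparams: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: map old-style initialization arguments onto the
    new-style data dictionary and updated ar_hypparams."""
    # Change (1): Y, mask and (optionally) conf go into one data dictionary.
    data: Dict[str, Any] = {"Y": Y, "mask": mask}
    if conf is not None:
        data["conf"] = conf

    # Change (2): latent_dimension becomes ar_hypparams["latent_dim"].
    # Work on a copy so the caller's dictionary is left unchanged.
    new_ar_hypparams = dict(ar_hypparams)
    new_ar_hypparams["latent_dim"] = latent_dimension
    return data, new_ar_hypparams
```

The returned `data` and `ar_hypparams` would then be passed to the refactored initialization method in place of the old keyword arguments. Copying `ar_hypparams` rather than mutating it avoids surprising callers who reuse the same hyperparameter dictionary elsewhere.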