probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Support parallel inference in LGSSM with time-varying dynamics #303

Closed: ezhang94 closed this 1 year ago

ezhang94 commented 1 year ago
slinderman commented 1 year ago

Summarizing our in-person conversation:

slinderman commented 1 year ago

@murphyk @gileshd: This PR changes the LGSSM models to support sampling with time-varying parameters. However, the fitting code still only supports static parameters. I'm a little nervous that this could lead to confusion if someone constructs a model with time-varying parameters, then calls fit and finds that those parameters have been replaced with static ones.

My proposal to @ezhang94 was that we should reintroduce an lgssm_joint_sample function into dynamax.lgssm.inference. That function would draw samples of states and emissions from their joint distribution. I know we removed such sample functions early on since they are not technically "inference" code, but it seems like such a low-level function could be useful. Thoughts?
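A minimal JAX sketch of what drawing from the joint distribution involves when the parameters vary over time. The function name `lgssm_joint_sample_sketch`, its argument list, and the convention of passing raw arrays with a leading time axis are assumptions for illustration, not the signature added in this PR; inputs/controls are omitted. The states must be sampled sequentially (here with `jax.lax.scan`), while the emissions are conditionally independent given the states and can be drawn in parallel with `jax.vmap`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def lgssm_joint_sample_sketch(key, m0, S0, F, b, Q, H, d, R):
    """Draw one sample (states, emissions) from the joint distribution
    p(z_{1:T}, y_{1:T}) of a linear Gaussian SSM with time-varying parameters.

    Hypothetical signature, not dynamax's API. Shapes (T steps, D states, E emissions):
      m0: (D,), S0: (D, D)                           initial state z_1
      F: (T-1, D, D), b: (T-1, D), Q: (T-1, D, D)    dynamics for t = 2..T
      H: (T, E, D), d: (T, E), R: (T, E, E)          emissions for t = 1..T
    """
    key_init, key_dyn, key_emit = jr.split(key, 3)

    # z_1 ~ N(m0, S0)
    z1 = jr.multivariate_normal(key_init, m0, S0)

    # z_t ~ N(F_t z_{t-1} + b_t, Q_t): scan over the time-indexed dynamics.
    def dynamics_step(z_prev, step_inputs):
        k, F_t, b_t, Q_t = step_inputs
        z_t = jr.multivariate_normal(k, F_t @ z_prev + b_t, Q_t)
        return z_t, z_t

    dyn_keys = jr.split(key_dyn, F.shape[0])
    _, z_rest = jax.lax.scan(dynamics_step, z1, (dyn_keys, F, b, Q))
    states = jnp.concatenate([z1[None, :], z_rest], axis=0)  # (T, D)

    # y_t ~ N(H_t z_t + d_t, R_t): conditionally independent given z_t,
    # so all emissions can be drawn in parallel with vmap.
    def emit(k, z_t, H_t, d_t, R_t):
        return jr.multivariate_normal(k, H_t @ z_t + d_t, R_t)

    emit_keys = jr.split(key_emit, states.shape[0])
    emissions = jax.vmap(emit)(emit_keys, states, H, d, R)  # (T, E)
    return states, emissions
```

A static model is the special case where every slice along the time axis is identical, e.g. `jnp.broadcast_to(F_static, (T - 1, D, D))`, so a single sampler of this shape can serve both cases.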

ezhang94 commented 1 year ago

@slinderman I have reverted models.py to its original format and added an lgssm_joint_sample function in inference.py as suggested.

gileshd commented 1 year ago

Sorry for being slow to respond to this.

This looks great. I agree that lgssm.inference doesn't feel like the spiritual home for the lgssm_joint_sample function, but it's probably better to have it slightly out of place than to have users confused when their time-varying parameters get silently replaced with static ones.

One potential small concern I have is that "joint_sample" might not be clear enough as a name. No succinct alternatives come to mind; lgssm_sample_states_and_emissions is a bit of a mouthful, but perhaps that's not such a concern? At the very least, I think it would help to state explicitly in the docstring that we are sampling from the joint distribution over states and emissions (not that there is really another natural joint distribution we would be talking about, but it might help someone who is new to these models).
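For reference, a sketch of the joint distribution such a docstring could spell out, written in notation assumed here rather than taken from the dynamax source: initial mean $m_0$ and covariance $S_0$; dynamics weights $F_t$, bias $b_t$ and noise covariance $Q_t$; emission weights $H_t$, bias $d_t$ and noise covariance $R_t$. Inputs are omitted, and the time-invariant model is the special case where the subscripts $t$ are dropped.

$$
p(z_{1:T}, y_{1:T}) = \mathcal{N}(z_1 \mid m_0, S_0) \prod_{t=2}^{T} \mathcal{N}(z_t \mid F_t z_{t-1} + b_t, Q_t) \prod_{t=1}^{T} \mathcal{N}(y_t \mid H_t z_t + d_t, R_t)
$$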