ezhang94 closed this pull request 1 year ago.
Summarizing our in-person conversation: `_get_params` at the top of the file and open an issue to make a `get_params` (or suitably renamed) util function.
@murphyk @gileshd: This PR changes the LGSSM models to support sampling with time-varying parameters. However, the fitting code still only supports static parameters. I'm a little nervous that this could lead to confusion if someone constructed a model with time-varying parameters, then called `fit` and found that those parameters had been replaced with static ones.
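For reference, a `get_params`-style util could look roughly like the sketch below. It is a hypothetical helper, not the dynamax implementation: it assumes a parameter is either static with `dim` dimensions or time-varying with one extra leading time axis, and it uses NumPy rather than JAX so it runs standalone.

```python
import numpy as np

def get_params(x, dim, t):
    """Return the value of parameter `x` that applies at time step `t`.

    Hypothetical helper: `x` is either static with `dim` dimensions
    (e.g. dim=2 for a transition matrix) or time-varying with an extra
    leading time axis, i.e. shape (T, ...).
    """
    x = np.asarray(x)
    if x.ndim == dim + 1:
        # Time-varying: select the slice for step t.
        return x[t]
    if x.ndim == dim:
        # Static: the same value is used at every step.
        return x
    raise ValueError(f"expected {dim} or {dim + 1} dims, got {x.ndim}")

# Usage: a static transition matrix versus a stack of per-step matrices.
F_static = np.eye(2)
F_varying = np.stack([np.eye(2) * (t + 1) for t in range(3)])
F1_static = get_params(F_static, 2, 1)    # identity at every step
F1_varying = get_params(F_varying, 2, 1)  # 2 * identity at step 1
```

A real version in dynamax would also need to handle other cases (for example callables or JAX tracing), so treat the dimension check here as an assumption.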
My proposal to @ezhang94 was that we should re-introduce an `lgssm_joint_sample` function into `dynamax.lgssm.inference`. That function would draw samples of states and emissions from their joint distribution. I know we removed such sampling functions early on since they are not technically "inference" code, but it seems like such a low-level function could be useful. Thoughts?
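To make the proposal concrete, here is a minimal NumPy sketch of what joint sampling with possibly time-varying parameters involves. The name `lgssm_joint_sample_sketch`, its signature, and the time-indexing convention are assumptions for illustration only (inputs and bias terms are omitted); the actual `lgssm_joint_sample` in dynamax takes its own arguments and uses JAX.

```python
import numpy as np

def _param_at(x, dim, t):
    # Hypothetical helper: slice a time-varying stack, otherwise reuse the static value.
    x = np.asarray(x)
    return x[t] if x.ndim == dim + 1 else x

def lgssm_joint_sample_sketch(rng, m0, S0, F, Q, H, R, num_timesteps):
    """Sample states z_{0:T-1} and emissions y_{0:T-1} jointly (ancestral sampling).

    Assumed model, with 0-based t:
        z_0 ~ N(m0, S0)
        z_t ~ N(F_t z_{t-1}, Q_t)  for t >= 1  (F[0], Q[0] unused if time-varying)
        y_t ~ N(H_t z_t, R_t)      for t >= 0
    Each of F, Q, H, R is either static or stacked along a leading time axis.
    """
    states, emissions = [], []
    z = rng.multivariate_normal(m0, S0)
    for t in range(num_timesteps):
        if t > 0:
            # Transition into step t uses the dynamics parameters indexed at t.
            z = rng.multivariate_normal(_param_at(F, 2, t) @ z, _param_at(Q, 2, t))
        # Emission at step t uses the emission parameters indexed at t.
        y = rng.multivariate_normal(_param_at(H, 2, t) @ z, _param_at(R, 2, t))
        states.append(z)
        emissions.append(y)
    return np.stack(states), np.stack(emissions)

# Usage: time-varying dynamics with static noise and emission parameters.
rng = np.random.default_rng(0)
T, D, E = 5, 2, 1
F = np.stack([np.eye(D) * 0.9 ** t for t in range(T)])
states, emissions = lgssm_joint_sample_sketch(
    rng, np.zeros(D), np.eye(D), F, 0.1 * np.eye(D), np.ones((E, D)), 0.5 * np.eye(E), T)
```

In JAX, the Python loop would more naturally become a `jax.lax.scan` over time indices with `jax.random` keys split per step; the loop is used here only to keep the sketch short.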
@slinderman I have reverted `models.py` to its original format and added an `lgssm_joint_sample` function in `inference.py`, as suggested.
Sorry for being slow to respond to this.
This looks great. I agree that `lgssm.inference` doesn't feel like the spiritual home for the `lgssm_joint_sample` function, but it's probably better to have it slightly out of place than to have users confused when their time-varying parameters get silently replaced with static ones.
One small potential concern is that "joint_sample" might not be a clear enough name. No succinct alternatives come to mind: `lgssm_sample_states_and_emissions` is a bit of a mouthful, but perhaps that's not such a concern? At the very least, I think it would help to state explicitly in the docstring that we are sampling from the joint distribution over states and emissions (not that there is really another natural joint we would be talking about, but it might help someone who is new to these models).
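One way to make "joint" explicit in the docstring is to write out the factorization being sampled. This is a sketch assuming dynamax-style notation (states $z_t$, emissions $y_t$, dynamics $F_t, Q_t$, emission parameters $H_t, R_t$, initial mean $m$ and covariance $S$), with inputs and bias terms omitted; for static parameters the $t$ subscripts simply drop:

```math
p(z_{1:T}, y_{1:T}) = \mathcal{N}(z_1 \mid m, S) \prod_{t=2}^{T} \mathcal{N}(z_t \mid F_t z_{t-1}, Q_t) \prod_{t=1}^{T} \mathcal{N}(y_t \mid H_t z_t, R_t)
```

Sampling follows the factorization in order: draw $z_1$, then for each step draw $y_t$ given $z_t$ and $z_{t+1}$ given $z_t$.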
Passes in time indices and uses helper functions to get the correct parameters when making the associative elements in `linear_gaussian_ssm.parallel_inference._make_associative_*_elements`.
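Because the parallel filter builds one associative element per time step, passing an array of time indices lets a single helper serve both static and time-varying parameters in a vectorized way. The sketch below illustrates that idea in NumPy; `params_at_times` is a hypothetical name, and the actual `_make_associative_*_elements` functions in dynamax are JAX code whose internals may differ.

```python
import numpy as np

def params_at_times(x, dim, ts):
    """Gather the value of parameter `x` for every time index in `ts` at once.

    Hypothetical helper: a time-varying `x` (ndim == dim + 1) is indexed with
    the whole array of time indices; a static `x` (ndim == dim) is broadcast,
    so the result always has a leading axis of length len(ts).
    """
    x = np.asarray(x)
    ts = np.asarray(ts)
    if x.ndim == dim + 1:
        return x[ts]
    if x.ndim == dim:
        return np.broadcast_to(x, (len(ts),) + x.shape)
    raise ValueError(f"expected {dim} or {dim + 1} dims, got {x.ndim}")

# Usage: per-step transition matrices for four time steps.
ts = np.arange(4)
F_static_all = params_at_times(np.eye(2), 2, ts)  # shape (4, 2, 2)
F_varying_all = params_at_times(np.stack([np.eye(2) * k for k in range(4)]), 2, ts)
```

In JAX, the same effect is usually obtained by mapping a single-step function over `ts` with `jax.vmap`, which is presumably why time indices are threaded through here.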
Minor addition: added sampling code that supports `LinearGaussianSSM` with time-varying parameters by passing an optional time index `t` to the emission and transition distributions. An alternative approach would be to have the sampling code handle the time-varying parameters itself and pass in only `params_t`. This code was added to support the parallel inference tests with time-varying parameters and is not essential.
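The two designs can be compared side by side in a small sketch. Everything here (`ToyParams`, `ToyLGSSM`, `params_at`) is hypothetical and NumPy-based; the real `LinearGaussianSSM` distribution methods return distribution objects rather than `(mean, covariance)` tuples.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class ToyParams:
    # Each field is a static matrix or a stack of matrices along a leading time axis.
    F: np.ndarray
    Q: np.ndarray
    H: np.ndarray
    R: np.ndarray

def _param_at(x, t):
    # All fields are matrices, so a 3-D array is treated as time-varying.
    x = np.asarray(x)
    if x.ndim == 3:
        if t is None:
            raise ValueError("time-varying parameter needs a time index t")
        return x[t]
    return x

class ToyLGSSM:
    """Design 1: distribution methods accept an optional time index t."""

    def transition_distribution(self, params, state, t=None):
        # (mean, covariance) of p(z_t | z_{t-1} = state) using parameters at step t.
        return _param_at(params.F, t) @ state, _param_at(params.Q, t)

    def emission_distribution(self, params, state, t=None):
        # (mean, covariance) of p(y_t | z_t = state) using parameters at step t.
        return _param_at(params.H, t) @ state, _param_at(params.R, t)

def params_at(params, t):
    # Design 2: the sampler slices every field up front and passes static params_t.
    return ToyParams(F=_param_at(params.F, t), Q=_param_at(params.Q, t),
                     H=_param_at(params.H, t), R=_param_at(params.R, t))

# Usage: both designs give the same transition mean at step 2.
params = ToyParams(F=np.stack([np.eye(2) * k for k in (1.0, 2.0, 3.0)]),
                   Q=0.1 * np.eye(2), H=np.ones((1, 2)), R=np.eye(1))
model = ToyLGSSM()
z = np.array([1.0, -1.0])
mean_a, _ = model.transition_distribution(params, z, t=2)           # design 1
mean_b, _ = model.transition_distribution(params_at(params, 2), z)  # design 2
```

With the optional `t`, the model methods have to know about time variation; with `params_t`, they stay static-only and the slicing lives in the sampler, which keeps the model API closer to what the fitting code already assumes.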