probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Is the `hmm_smoother` function missing a time-step when there are time-dependent transition functions? #310

Open mikewojnowicz opened 1 year ago

mikewojnowicz commented 1 year ago

In the "non-stationary" setting (which seems to be operationally defined here as having time-dependent parameters), the hmm_smoother function applied to a time series of T elements returns T-2 transition probabilities rather than T-1.

For example, if one adds the two assertions below to the unit tests here, both pass.

 assert jnp.shape(post.trans_probs)[0] == num_timesteps - 2
 assert jnp.shape(post2.trans_probs)[0] == num_timesteps - 2

Isn't this a mistake?

  1. I expected T-1 such probabilities based on the math: a sequence of T states has T-1 transitions.
  2. The hmm_expected_states function in the old ssm repo, which performed a similar role, returned an expected_joints value with T-1 entries, not T-2.

DBraun commented 1 month ago

I'm trying to make a toy example of a non-stationary model (if you have one I'd like to see it too). I noticed the same pattern of shapes you noticed. In my case it's related to these two lines:

https://github.com/probml/dynamax/blob/46fb2338f6daa628225b8e1934aef57cad2264b7/dynamax/hidden_markov_model/inference.py#L589

https://github.com/probml/dynamax/blob/46fb2338f6daa628225b8e1934aef57cad2264b7/dynamax/hidden_markov_model/inference.py#L598

where each line takes 1 away from the length of the transition probabilities, which would account for T-2 rather than T-1. Maybe the -1 in len(filtered_probs)-1 isn't necessary?
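
To make the off-by-one concrete, here is a minimal, self-contained JAX sketch. It does not call dynamax. The forward and backward passes, the `pairwise` helper, the random likelihoods, and the sizes T=5, K=3 are all illustrative assumptions. It only mirrors the pattern described above (a slice ending in [:-1] followed by an index range of len(...)-1) and should not be read as the library's actual implementation.

```python
import jax.numpy as jnp
from jax import random, vmap

T, K = 5, 3  # illustrative sizes: T time steps, K hidden states
key_A, key_lik = random.split(random.PRNGKey(0))

pi0 = jnp.ones(K) / K  # uniform initial distribution
# Time-dependent transitions: one row-stochastic matrix per transition, T - 1 in total.
A = random.dirichlet(key_A, jnp.ones(K), shape=(T - 1, K))  # shape (T-1, K, K)
# Hypothetical per-step likelihoods p(y_t | z_t = k), standing in for real emissions.
lik = random.uniform(key_lik, (T, K), minval=0.1, maxval=1.0)

# Forward pass: predicted p(z_t | y_{1:t-1}) and filtered p(z_t | y_{1:t}).
predicted, filtered = [], []
pred = pi0
for t in range(T):
    predicted.append(pred)
    filt = pred * lik[t]
    filt = filt / filt.sum()
    filtered.append(filt)
    if t < T - 1:
        pred = filt @ A[t]
predicted, filtered = jnp.stack(predicted), jnp.stack(filtered)

# Backward pass: smoothed p(z_t | y_{1:T}).
smoothed = [filtered[-1]]
for t in range(T - 2, -1, -1):
    ratio = smoothed[0] / predicted[t + 1]
    smoothed.insert(0, filtered[t] * (A[t] @ ratio))
smoothed = jnp.stack(smoothed)


def pairwise(t):
    """Joint posterior p(z_t = i, z_{t+1} = j | y_{1:T}) for transition t."""
    ratio = smoothed[t + 1] / predicted[t + 1]
    return jnp.einsum("i,ij,j->ij", filtered[t], A[t], ratio)


# One entry per transition: T - 1 matrices.
trans_probs = vmap(pairwise)(jnp.arange(T - 1))

# The pattern under discussion: slice off the last element, then subtract 1 again.
trimmed = filtered[:-1]  # length T - 1
short_trans_probs = vmap(pairwise)(jnp.arange(len(trimmed) - 1))  # length T - 2

print(trans_probs.shape)        # (4, 3, 3)
print(short_trans_probs.shape)  # (3, 3, 3)
```

In this sketch, summing each matrix in trans_probs over its second state axis recovers the smoothed marginal at t, and summing over the first recovers the one at t+1. That gives a quick check that all T-1 entries are meaningful and that the doubly trimmed version simply drops the last transition.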