lindermanlab / ssm

Bayesian learning and inference for state space models
MIT License

HMM input driven observations module parameters retrieval #141

Closed NeuTTH closed 1 year ago

NeuTTH commented 2 years ago

Hey guys! Thank you so much again for creating this package. I have been using the input-driven HMM for some behavioral modeling in my project. When I tried to retrieve the observation GLM parameters for further analysis, I noticed that the weight matrix is off by one in the number of classes. This line is in observations.py, line 647:

self.Wk = npr.randn(K, C - 1, M)

with K = number of states, C = number of distinct classes for each output dimension, and M = number of input dimensions. The weight matrix is expanded to shape (K, C, M) inside calculate_logits (observations.py, line 680). If I retrieve this matrix with one class missing, how do I assign the weights to each class (feature)? Thanks!!

slinderman commented 1 year ago

The reason Wk has one fewer class is that a categorical distribution over C classes has only C-1 degrees of freedom, since the probabilities must sum to one. You can think of the weights for the last class as being all zero.
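Concretely, you can recover a full weight matrix of shape (K, C, M) by appending a block of zeros for the reference (last) class. A minimal NumPy sketch, using hypothetical sizes in place of a fitted model's actual `Wk`:

```python
import numpy as np

K, C, M = 3, 4, 2                  # hypothetical: states, classes, input dims
Wk = np.random.randn(K, C - 1, M)  # stands in for the model's fitted Wk

# Append zero weights for the last (reference) class along the class axis.
W_full = np.concatenate([Wk, np.zeros((K, 1, M))], axis=1)

assert W_full.shape == (K, C, M)
assert np.all(W_full[:, -1, :] == 0)  # reference class carries zero weights
```

Rows `W_full[k, c]` for `c < C-1` are then directly interpretable as the log-odds weights of class `c` against the last class, as described below.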

Let $W \in \mathbb{R}^{K \times (C-1) \times M}$ denote the weights and $w_{k,c} \in \mathbb{R}^M$ denote the weights for state $k$ and class $c$. Let $u_t \in \mathbb{R}^M$ denote the inputs at time $t$. The log likelihood of seeing emission $x_t = c$ when in discrete state $z_t = k$ is,

$$ \log p(x_t = c \mid u_t, z_t = k, W) = w_{k,c}^\top u_t - \log \left(1 + \sum_{c'=1}^{C-1} \exp\left( w_{k,c'}^\top u_t \right) \right) $$

if $c < C$, and

$$ \log p(x_t = C \mid u_t, z_t = k, W) = - \log \left(1 + \sum_{c'=1}^{C-1} \exp\left( w_{k,c'}^\top u_t \right) \right) $$

if $c = C$. Equivalently, you can think of the weights as specifying the log odds ratio,

$$ \log \frac{p(x_t = c \mid u_t, z_t = k, W)}{p(x_t = C \mid u_t, z_t = k, W)} = w_{k,c}^\top u_t $$

This may be slightly harder to interpret than the over-parameterized model where $W \in \mathbb{R}^{K \times C \times M}$. However, in the over-parameterized model the weights are only identified up to an additive shift.
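The formulas above can be checked numerically: pad the reference class with zeros, take a softmax over classes, and confirm that the log-odds against the last class recover the raw weights. A sketch with hypothetical sizes (not the library's own API):

```python
import numpy as np

K, C, M, T = 3, 4, 2, 5            # hypothetical: states, classes, input dims, time steps
Wk = np.random.randn(K, C - 1, M)  # stands in for the model's fitted Wk
u = np.random.randn(T, M)          # inputs at each time step

# Pad the reference class with zero weights, then form logits for every
# (time, state, class) triple and normalize with a log-softmax.
W_full = np.concatenate([Wk, np.zeros((K, 1, M))], axis=1)       # (K, C, M)
logits = np.einsum('kcm,tm->tkc', W_full, u)                     # (T, K, C)
log_p = logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)

# Probabilities sum to one over classes...
assert np.allclose(np.exp(log_p).sum(axis=-1), 1.0)

# ...and log p(x_t = c) - log p(x_t = C) = w_{k,c}^T u_t, as in the log-odds identity.
log_odds = log_p[..., :-1] - log_p[..., -1:]
assert np.allclose(log_odds, np.einsum('kcm,tm->tkc', Wk, u))
```

The last assertion is exactly the log-odds ratio above: the stored (C-1)-class weights are the log-odds of each class against the last one.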

I'll go ahead and close this, but feel free to chime in @zashwood.