I am also wondering how to achieve this, but with rSLDS models instead. Cross-validation is pretty important; can I use held-out data to test my trained models?
I have tried using the `cross_val_scores` function in `ssm.model_selection`; however, either I am misunderstanding the function parameters or there is a logic error in the function, because it always returns 0 for the test scores.
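One way around `cross_val_scores` is to do the splitting by hand. Below is a minimal sketch, assuming each element of `datas`/`inputs` is one session and that whole sessions are held out, so the temporal structure inside each sequence stays intact for the HMM. `kfold_session_splits` and `cross_validate` are hypothetical helpers, not part of `ssm`, and `fit_fn`/`score_fn` are placeholders you would fill in yourself.

```python
import numpy as np

def kfold_session_splits(n_sessions, n_folds, seed=0):
    """Yield (train_idx, test_idx) lists that hold out whole sessions."""
    if n_folds < 2:
        raise ValueError("need at least two folds")
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_sessions)        # shuffle session indices once
    folds = np.array_split(order, n_folds)     # near-equal groups of sessions
    for k in range(n_folds):
        test_idx = np.sort(folds[k])
        train_idx = np.sort(np.concatenate([folds[j] for j in range(n_folds) if j != k]))
        yield train_idx.tolist(), test_idx.tolist()

def cross_validate(datas, inputs, fit_fn, score_fn, n_folds=5, seed=0):
    """Return one held-out score per fold.

    fit_fn(train_datas, train_inputs) -> model          (hypothetical placeholder)
    score_fn(model, test_datas, test_inputs) -> float   (hypothetical placeholder)
    """
    scores = []
    for train_idx, test_idx in kfold_session_splits(len(datas), n_folds, seed):
        model = fit_fn([datas[i] for i in train_idx], [inputs[i] for i in train_idx])
        scores.append(score_fn(model,
                               [datas[i] for i in test_idx],
                               [inputs[i] for i in test_idx]))
    return scores
```

With `ssm`, `fit_fn` would construct and fit a model on the training sessions, and `score_fn` could simply return `model.log_likelihood(test_datas, test_inputs)`.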
Sounds like there are a few questions here. One concerns computing marginal log likelihoods of held-out test data. For that, you can use `hmm.log_likelihood(datas, inputs)`, where `datas` is an array or list of arrays and `inputs` is a matching array or list of arrays. Then compare the held-out log likelihood across models and take the model with the highest score.
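For intuition about what is being compared: the marginal log likelihood sums the probability of the observations over all discrete state paths, which the forward algorithm does in O(T K^2). The plain-NumPy sketch below computes it for a single sequence, assuming a stationary transition matrix (input-driven transitions would make `log_P` time-varying) and per-time-step emission log likelihoods already in hand. `hmm_log_likelihood` is a hypothetical helper for illustration, not an `ssm` function.

```python
import numpy as np

def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def hmm_log_likelihood(log_pi0, log_P, log_likes):
    """Marginal log p(y_1:T) for one sequence via the forward algorithm.

    log_pi0:   (K,)   log initial state probabilities
    log_P:     (K, K) log transition matrix, rows index the previous state
    log_likes: (T, K) log p(y_t | z_t = k), e.g. from the emission model
    """
    log_alpha = log_pi0 + log_likes[0]
    for t in range(1, log_likes.shape[0]):
        # alpha_t(k) = sum_j alpha_{t-1}(j) * P[j, k] * p(y_t | z_t = k)
        log_alpha = _logsumexp(log_alpha[:, None] + log_P, axis=0) + log_likes[t]
    return float(_logsumexp(log_alpha, axis=0))
```

To compare candidate models (e.g. different numbers of states K), sum this quantity over all held-out sequences for each model and keep the largest; if the test sets differ in size, dividing by the number of time steps keeps the numbers comparable.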
The other question is about making predictions with the input-driven model. If you want to predict the emission classes for held-out data under an HMM with `InputDrivenObservations`, you can use the `calculate_logits` function to get a `T x K x C` array of log probabilities at each time step, for each discrete state, and for each emission class. You probably want to average over the discrete states somehow: either with the predicted state probabilities given past observations, or with the posterior distribution given all observations. You can do the latter with the `expected_states` function. It would look something like this:
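```python
import numpy as np

# hmm:   a trained ssm HMM with InputDrivenObservations
# data:  a T x D array of held-out observations
# input: a T x M array of held-out inputs
# expected_states returns (Ez, Ezzp1, normalizer); keep only the T x K posterior Ez
expected_states, _, _ = hmm.expected_states(data, input)     # shape T x K
emission_logits = hmm.observations.calculate_logits(input)   # shape T x K x C (log probabilities)
emission_probs = np.exp(emission_logits)
# average the class probabilities over the posterior on discrete states
predictions = np.einsum('tk,tkc->tc', expected_states, emission_probs)  # shape T x C
```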
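One caveat when turning this into a predictive accuracy: `expected_states` conditions on all observations, including y_t itself, so scoring predictions of y_t with that posterior lets the answer leak into the prediction and can inflate accuracy. The "predicted probabilities given past observations" option avoids this. Below is a minimal NumPy sketch, assuming a stationary transition matrix `P`, an initial distribution `pi0`, the `T x K x C` `emission_probs` array from the snippet above, and integer class labels; `one_step_predictions` and `accuracy` are hypothetical helpers, not `ssm` functions.

```python
import numpy as np

def one_step_predictions(pi0, P, emission_probs, labels):
    """p(y_t = c | y_1:t-1, inputs) for every t, via a normalized forward filter.

    pi0:            (K,)      initial state distribution
    P:              (K, K)    stationary transition matrix (assumption)
    emission_probs: (T, K, C) class probabilities per time step and state
    labels:         (T,)      observed class indices
    """
    T, K, C = emission_probs.shape
    lik = emission_probs[np.arange(T), :, labels]  # (T, K): p(y_t | z_t = k)
    preds = np.empty((T, C))
    prior = np.asarray(pi0, dtype=float).copy()    # p(z_t | y_1:t-1)
    for t in range(T):
        preds[t] = prior @ emission_probs[t]       # mix class probs over states
        post = prior * lik[t]                      # condition on the observed y_t
        post /= post.sum()
        prior = post @ P                           # propagate to time t + 1
    return preds

def accuracy(pred_probs, labels):
    """Fraction of time steps where the most probable class matches the label."""
    return float(np.mean(np.argmax(pred_probs, axis=1) == labels))
```

If I read the `ssm` source correctly, `pi0` and `P` correspond to `hmm.init_state_distn.initial_state_distn` and `hmm.transitions.transition_matrix` for standard (non-input-driven) transitions, and `labels` would be `data[:, 0].astype(int)` for one-dimensional categorical observations. The same `accuracy` helper also works on the posterior-averaged `predictions`.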
Dear all,
I have fitted a GLMHMM (with input-driven observations) to behavioral data. To quantify its performance, I compute the log likelihood on a held-out test set. I am also interested in predicting the observations on the test set with the fitted model, to measure predictive accuracy. I would like to know how I can predict the observations from the inputs using the fitted GLMHMM. Thank you very much for any pointers.
Regards, Sarath