mattjj / pyhsmm

MIT License

Quick question #62

mg10011 closed 8 years ago

mg10011 commented 8 years ago

This is a great little library, BTW.

I ran your example, and besides the plot of the various 2-dimensional distributions, how can I get the actual posterior estimates of the following:

  1. the mean vector and covariance matrix of each state
  2. the Poisson parameter for the duration distribution.

Is there any way to plot the samples of the duration, so as to see the sampled distribution?

Thanks, mg

mattjj commented 8 years ago

Hey, glad you like it!

How you extract parameter estimates depends on which inference algorithm and which observation/duration distributions you're using, but a basic way to read off the parameters of Gaussian observation distributions (like in the examples) is

# assumes model.obs_distns is a list of Gaussian instances
for o in model.obs_distns:
    print(o.mu)     # mean vector of this state
    print(o.sigma)  # covariance matrix of this state
    print()         # blank line between states

You can also get the parameters as a dict via o.params for some distribution classes (including Gaussian), but I don't think that's implemented for all or even most of them.

If you're using EM (or Viterbi EM), those properties are set to the output of the M step on each iteration. If you're using Gibbs sampling, those are sampled values, so you may want to collect such samples across many iterations of the Markov chain (interleaving calls to model.resample_model() and copying [o.params for o in model.obs_distns], since resampling might modify the values in-place). If you're using mean field, you might want to access some other properties, but those parameter properties are usually updated with something like the mode of the variational factor.
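For the Gibbs case, the resample-and-copy loop described above might look like the following minimal sketch. The helper name collect_obs_params and its burn_in argument are made up for illustration and are not part of pyhsmm; the sketch only assumes what is stated above, namely that model.resample_model() runs one sweep and that each element of model.obs_distns exposes params. It uses copy.deepcopy so that later in-place resampling cannot change snapshots that were already stored.

```python
import copy


def collect_obs_params(model, n_iter, burn_in=0):
    """Run Gibbs sweeps and snapshot the observation parameters after each kept sweep.

    Hypothetical helper, not part of pyhsmm. Assumes `model` has a
    resample_model() method and an obs_distns list whose elements
    expose a `params` attribute (as Gaussian does).
    """
    samples = []
    for itr in range(n_iter):
        model.resample_model()  # one sweep of the Markov chain
        if itr >= burn_in:
            # deepcopy, because resampling may modify the values in place
            samples.append([copy.deepcopy(o.params) for o in model.obs_distns])
    return samples
```

The result is a list with one entry per kept sweep, each entry being a list with one params snapshot per state, so something like samples[-1][0] would be the most recent sample for the first state.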

The story is similar with duration distributions. For Poisson instances, do this:

for d in model.dur_distns:
    print(d.lmbda)  # Poisson rate for this state's duration distribution

or try [d.params for d in model.dur_distns].

In general, to find what properties a distribution or model object might have, you'll want to read the code, either here or in pybasicbayes. There's not really any documentation but the code is mostly organized and readable.
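Before diving into the source, a quick way to get a first list of candidate properties is plain Python introspection. This is a generic sketch using the built-in dir(); the helper name public_attributes is made up here, and nothing in it is specific to pyhsmm or pybasicbayes.

```python
def public_attributes(obj):
    """Return the sorted names of attributes and properties on `obj`
    that do not start with an underscore."""
    return sorted(name for name in dir(obj) if not name.startswith("_"))
```

Calling it on one element of model.obs_distns or model.dur_distns gives names to look up in the code; it does not tell you what they mean or whether a given inference algorithm keeps them up to date.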

As for plotting the samples of the duration, you'll have to write your own plotting code, but these commented-out lines might be helpful (I can't remember why I commented them out, but it's probably because I updated the HMM plotting code and didn't update that HSMM plotting code). In particular, an HMM or HSMM model has a states_list member, which is a list that gets appended to every time you call add_data. The elements of the list are States instances and each has a stateseq_norep property and a durations property. So to get a list of arrays of durations for state 5, you could write something like

[s.durations[s.stateseq_norep == 5] for s in model.states_list]
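To turn those durations into something you can plot, one option is to tally them across all sequences. The sketch below is only an illustration: duration_counts is a made-up helper name, and it assumes nothing beyond the stateseq_norep and durations properties described above, pairing them up element by element rather than using array masking.

```python
from collections import Counter


def duration_counts(states_list, state):
    """Count how often each duration occurs for `state` across all sequences.

    Hypothetical helper, not part of pyhsmm. Assumes each element of
    `states_list` has aligned `stateseq_norep` and `durations` sequences.
    """
    counts = Counter()
    for s in states_list:
        for label, dur in zip(s.stateseq_norep, s.durations):
            if label == state:
                counts[int(dur)] += 1
    return counts
```

You would call it as duration_counts(model.states_list, 5) and pass the keys and values to a bar plot (for example matplotlib's plt.bar). With Gibbs sampling the counts reflect only the current state sequence sample, so you may want to add up the counters from several iterations.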

Hope that helps!