mattjj / pyhsmm

MIT License

Example codes for using SVI for HDP-HSMM without change points #48

Open wangsix opened 9 years ago

wangsix commented 9 years ago

Basically, I was wondering whether it's possible to repeat the routine demonstrated in examples/hsmm.py using SVI training instead of sampling. The other examples use the possible-changepoint models, but the data I have does not come with candidate change points. Thanks!

mattjj commented 9 years ago

You might be able to interpolate between examples/svi.py and examples/hsmm.py. Basically, replace the HMM instantiation in examples/svi.py with an HSMM instantiation and try calling meanfield_sgdstep. I think Poisson durations and some version of NegativeBinomial durations implement the MeanFieldSVI interface, which indicates that they can be used with SVI, but other duration models would need their own implementations (I don't have any plans to implement more at the moment).
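As a rough way to check the "implements the MeanFieldSVI interface" condition before training, one could test each duration object for an SGD-step method. The sketch below is only an illustration: the two classes are hypothetical stand-ins (not pyhsmm classes), and it assumes that SVI-capable distributions expose a method named `meanfield_sgdstep`, mirroring the model-level call mentioned above.

```python
# Hypothetical stand-ins for illustration only; these are NOT pyhsmm classes.
class SamplingOnlyDuration(object):
    def resample(self, data=None):
        pass


class SVICapableDuration(object):
    def resample(self, data=None):
        pass

    # Assumed method name for the SVI step on a distribution object.
    def meanfield_sgdstep(self, expected_suff_stats, prob, stepsize):
        pass


def missing_svi_support(distns, method_name="meanfield_sgdstep"):
    """Return the class names of the distributions lacking `method_name`."""
    return [type(d).__name__ for d in distns
            if not callable(getattr(d, method_name, None))]


dur_distns = [SVICapableDuration(), SamplingOnlyDuration()]
print(missing_svi_support(dur_distns))
```

A non-empty result would suggest that at least one duration model only supports sampling, in which case calling `meanfield_sgdstep` on the model is likely to fail partway through.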

powersimmani commented 7 years ago

I tried the change you mentioned, swapping the HMM in examples/svi.py for an HSMM. I also tried changing the other resampling code to SVI or mean field. But it always raises an AttributeError like the one below. Could you explain in a little more detail how to adapt the code for SVI or mean-field inference on the HDP-HSMM?

Here is the error output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-dd411d9e9464> in <module>()
     44 sgdseq = sgd_passes(tau=0,kappa=0.7,datalist=datas)
     45 for t, (data, rho_t) in progprint(enumerate(sgdseq)):
---> 46     hmm.meanfield_sgdstep(data, data.shape[0] / training_size, rho_t)
     47 
     48     if t % 10 == 0:

/home/simmani91/anaconda2/lib/python2.7/site-packages/pyhsmm/models.pyc in meanfield_sgdstep(self, minibatch, prob, stepsize, num_procs, **kwargs)
    584         if num_procs == 0:
    585             for s in mb_states_list:
--> 586                 s.meanfieldupdate()
    587         else:
    588             self._joblib_meanfield_update_states(mb_states_list,num_procs)

/home/simmani91/anaconda2/lib/python2.7/site-packages/pyhsmm/internals/hmm_states.pyc in meanfieldupdate(self)
    451         self.clear_caches()
    452         self.all_expected_stats = self._expected_statistics(
--> 453             self.mf_trans_matrix,self.mf_pi_0,self.mf_aBl)
    454         self._mf_param_snapshot = (
    455             np.log(self.mf_trans_matrix), np.log(self.mf_pi_0),

/home/simmani91/anaconda2/lib/python2.7/site-packages/pyhsmm/internals/hmm_states.pyc in mf_trans_matrix(self)
    433     @property
    434     def mf_trans_matrix(self):
--> 435         return self.model.trans_distn.exp_expected_log_trans_matrix
    436 
    437     @property

AttributeError: 'WeakLimitStickyHDPHMMTransitions' object has no attribute 'exp_expected_log_trans_matrix'

jengelman commented 7 years ago

@mattjj, I ran into the same issue as @powersimmani. Have you found a workaround for this?

mattjj commented 7 years ago

I haven't implemented mean field for the Sticky HDPHMM, so unfortunately it can't do SVI either. If you want to add it, the place to look is in transitions.py, and in particular for the classes that have MeanField and SVI in their names. Unfortunately I don't have any plans to add that at the moment, but pull requests are welcome!
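For anyone attempting that, here is a minimal sketch of the kind of quantity the missing attribute `exp_expected_log_trans_matrix` refers to, assuming the textbook mean-field setup in which each transition row has an independent Dirichlet variational factor. It is not pyhsmm's implementation, the function names are made up for illustration, and the sticky and weak-limit models involve extra structure that this ignores.

```python
import math


def digamma(x):
    """Digamma function for x > 0 (recurrence plus asymptotic series)."""
    if x <= 0:
        raise ValueError("x must be positive")
    result = 0.0
    # Shift x upward with psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # psi(x) ~ ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6)
    result += (math.log(x) - 0.5 / x
               - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252)))
    return result


def exp_expected_log_rows(dirichlet_params):
    """For each Dirichlet row, return exp(E[log pi_k]) for every entry k.

    Uses E[log pi_k] = digamma(alpha_k) - digamma(sum_j alpha_j).
    """
    out = []
    for row in dirichlet_params:
        total = digamma(sum(row))
        out.append([math.exp(digamma(a) - total) for a in row])
    return out


# Hypothetical variational parameters for a 2-state transition matrix.
params = [[1.0, 1.0], [5.0, 0.5]]
for row in exp_expected_log_rows(params):
    print(row, sum(row))
```

The rows of the result sum to less than one, which is expected: this is a sub-normalized matrix used inside the mean-field message passing, not a proper transition matrix.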

jengelman commented 7 years ago

Started working on the WeakLimit samplers, and in the meantime I was able to run the SVI example using the DATruncHDPHSMM sampler. From reading the papers, it seems like implementing the Sticky samplers with DA truncation should be pretty easy, albeit slower than WeakLimit, so I might start there. While DATruncHDPHMM runs at around the same speed as HMM (.05s avg/svi pass), DATruncHDPHSMM is around 100x slower (4.88s avg/svi pass) with worse results. Is this expected?

I also got the following warnings for every model:

/Users/joshengelman/miniconda2/lib/python2.7/site-packages/pyhsmm/internals/hmm_states.py:659: RuntimeWarning: divide by zero encountered in log
  np.log(trans_potential),likelihood_log_potential,alphal,betal,
/Users/joshengelman/miniconda2/lib/python2.7/site-packages/pyhsmm/internals/hmm_states.py:455: RuntimeWarning: divide by zero encountered in log
  np.log(self.mf_trans_matrix), np.log(self.mf_pi_0),
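For context, NumPy emits this warning whenever `np.log` is applied to an array containing exact zeros; the corresponding entries become `-inf`. The sketch below reproduces that on a made-up transition matrix and shows how `np.errstate` can silence it locally. Whether `-inf` entries are harmless in these particular pyhsmm code paths is not something this example establishes.

```python
import warnings

import numpy as np

# Made-up transition matrix with structural zeros (illustration only).
trans_matrix = np.array([[0.9, 0.1, 0.0],
                         [0.0, 0.8, 0.2],
                         [0.0, 0.0, 1.0]])

# Default behaviour: log(0) yields -inf and triggers a RuntimeWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    log_trans = np.log(trans_matrix)

# Same computation with the divide-by-zero warning suppressed locally.
with np.errstate(divide="ignore"):
    quiet_log_trans = np.log(trans_matrix)

print(len(caught), np.isneginf(quiet_log_trans).sum())
```

Counting the zero entries of the matrices passed to `np.log` is a quick way to see whether the warnings come from structural zeros or from probabilities that have underflowed.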