dattalab / pyhsmm-library-models

library models built on top of pyhsmm

Model truncation is broken #26

Closed. alexbw closed this issue 11 years ago.

alexbw commented 11 years ago

I might be calling this wrong...

    def truncate(self, n_target_states, X, X_held_out=None):

        # Assert that we've fit before, that we have a valid hsmm_model
        assert self.hsmm_model, "Must have already instantiated a model with fit() or reinstantiate()"

        # Truncate the model
        self.hsmm_model.truncate_num_states(n_target_states, destructive=True)

        # Update the number of states
        self.n_states = n_target_states

        # Update labels, labels_norep, durations, and the transition matrix
        self.labels_ = self._get_labels()
        self.labels_norep_ = self._get_labels_norep()
        self.durations_ = self._get_durations()
        self.transmat_ = self.hsmm_model.trans_distn.A

        # Get the current likelihood
        self.sample_likelihoods_ = [self.hsmm_model.log_likelihood(X)]

        # Get the current held-out likelihood
        if X_held_out is not None:
            self.heldout_sample_likelihoods_ = [self.hsmm_model.log_likelihood(X_held_out)]

        return self

Running it gives:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-15-b45972a9a6aa> in <module>()
    ----> 1 lhsmmModel.truncate(10, data, data_test)

    /Users/Alex/Code/pymouse/lhsmm.py in truncate(self, n_target_states, X, X_held_out)
        466 
        467         # Truncate the model
    --> 468         self.hsmm_model.truncate_num_states(n_target_states, destructive=True)
        469 
        470         # Update the number of states

    /Users/Alex/Code/pyhsmm_library_models/library_models.pyc in truncate_num_states(self, target_num, destructive)
        326         for s in new.states_list:
        327             s.clear_caches()
    --> 328             s.Viterbi()
        329 
        330         return new

    /Users/Alex/Code/pyhsmm_library_models/pyhsmm/internals/states.pyc in Viterbi(self)
        182     def Viterbi(self):
        183         scores, args = self.maxsum_messages_backwards()
    --> 184         self.maximize_forwards(scores,args)
        185 
        186     @staticmethod

    /Users/Alex/Code/pyhsmm_library_models/pyhsmm/internals/states.pyc in maximize_forwards(self, scores, args)
        919 
        920         stateseq = np.empty(T,dtype=np.int32)
    --> 921         stateseq[0] = (scores[0,start_indices] + np.log(self.hsmm_pi_0) + self.hsmm_aBl[0]).argmax()
        922         initial_hmm_state = start_indices[stateseq[0]]
        923 

    ValueError: operands could not be broadcast together with shapes (10) (200)
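The last frame hints at where the shapes diverge: scores[0,start_indices] has one entry per retained state (10), while hsmm_pi_0 presumably still has one entry per original state (200). Below is a minimal NumPy sketch of that mismatch. The arrays are hypothetical stand-ins, not the real pyhsmm attributes, and the sketch assumes for simplicity that the retained states are the first 10.

```python
import numpy as np

# Hypothetical stand-ins, not the real pyhsmm attributes:
n_kept, n_original = 10, 200
scores_row = np.zeros(n_kept)                 # like scores[0, start_indices] after truncation
pi_0 = np.full(n_original, 1.0 / n_original)  # like a stale, untruncated hsmm_pi_0

broadcast_failed = False
try:
    scores_row + np.log(pi_0)                 # shapes (10,) and (200,) cannot broadcast
except ValueError as err:
    broadcast_failed = True
    print("ValueError:", err)

# Restricting pi_0 to the retained states makes the shapes agree again.
# Assumption: the retained states are indices 0..n_kept-1; in general,
# index with the actual ids of the kept states.
keep = np.arange(n_kept)
pi_kept = pi_0[keep] / pi_0[keep].sum()
combined = scores_row + np.log(pi_kept)
```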

Digging in to see whether I'm failing to set something properly on self.hsmm_model. Note that Viterbi was just called on hsmm_model when it was reinstantiated from labels.
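One plausible reading of the error is that truncation must restrict every per-state parameter to the same set of retained states, not just some of them. Here is a sketch under that assumption. truncate_params, keep, and usage are hypothetical names, and this is not necessarily what the fix in 16d985e does.

```python
import numpy as np

def truncate_params(A, pi_0, keep):
    """Restrict a transition matrix and initial distribution to the states in `keep`.

    Hypothetical helper: A is (N, N) row-stochastic, pi_0 is (N,), and keep is a
    sequence of state indices to retain. Both outputs are renormalized.
    """
    keep = np.asarray(keep)
    A_new = A[np.ix_(keep, keep)]                  # keep only rows/cols of retained states
    A_new = A_new / A_new.sum(axis=1, keepdims=True)
    pi_new = pi_0[keep]                            # same indices, so shapes stay consistent
    pi_new = pi_new / pi_new.sum()
    return A_new, pi_new

# Example: keep the 10 most-used of 200 states, using made-up usage counts.
rng = np.random.default_rng(0)
N = 200
A = rng.random((N, N))
A /= A.sum(axis=1, keepdims=True)
pi_0 = rng.random(N)
pi_0 /= pi_0.sum()
usage = rng.integers(0, 1000, size=N)
keep = np.argsort(usage)[::-1][:10]

A10, pi10 = truncate_params(A, pi_0, keep)
```

Renormalizing assumes each retained row keeps some probability mass among the retained states. For an HSMM transition matrix with a zero diagonal, a state whose transitions all went to dropped states would produce a row of NaNs.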

alexbw commented 11 years ago

Btw, I just figured out you can color-code blocks by language with ```python

alexbw commented 11 years ago

Whoops, this should be in #24. All further comments there.

mattjj commented 11 years ago

Fixed in 16d985e