mattjj / pyslds

MIT License

Shape AssertionError on starting example #29

Closed: kekehia123 closed this issue 6 years ago.

kekehia123 commented 6 years ago

When I run the given example:

------------------------------------------------------------------------------

import numpy.random as npr
from pyslds.models import DefaultSLDS

K = 5          # Number of discrete latent states
D_obs = 1      # Observed data dimension
D_latent = 2   # Latent state dimension
D_input = 0    # Exogenous input dimension
T = 2000       # Number of time steps to simulate

true_model = DefaultSLDS(K, D_obs, D_latent, D_input)
inputs = npr.randn(T, D_input)
y, x, z = true_model.generate(T, inputs=inputs)

# Compute the log likelihood of the data with the true params
true_ll = true_model.log_likelihood()

# Create a separate model and add the observed data
test_model = DefaultSLDS(K, D_obs, D_latent, D_input)
test_model.add_data(y)

# Run the Gibbs sampler
N_samples = 1000
def update(model):
    model.resample_model()
    return model.log_likelihood()

lls = [update(test_model) for _ in range(N_samples)]

-----------------------------------------------------------------------------

I got an error:

AssertionError                            Traceback (most recent call last)
in <module>()
      9     return model.log_likelihood()
     10 
---> 11 lls = [update(test_model) for _ in range(N_samples)]

in <listcomp>(.0)
      9     return model.log_likelihood()
     10 
---> 11 lls = [update(test_model) for _ in range(N_samples)]

in update(model)
      6 N_samples = 1000
      7 def update(model):
----> 8     model.resample_model()
      9     return model.log_likelihood()
     10 

~/anaconda3/lib/python3.5/site-packages/pyhsmm/models.py in resample_model(self, num_procs)
    440     def resample_model(self,num_procs=0):
    441         self.resample_parameters()
--> 442         self.resample_states(num_procs=num_procs)
    443 
    444     @line_profiled

~/anaconda3/lib/python3.5/site-packages/pyhsmm/models.py in resample_states(self, num_procs)
    465         if num_procs == 0:
    466             for s in self.states_list:
--> 467                 s.resample()
    468         else:
    469             self._joblib_resample_states(self.states_list,num_procs)

/data/home/sxk/Documents/eeg_anal/pyslds/pyslds/states.py in resample(self, niter)
    416         for itr in range(niter):
    417             self.resample_discrete_states()
--> 418             self.resample_gaussian_states()
    419 
    420     def _init_gibbs_from_mf(self):

/data/home/sxk/Documents/eeg_anal/pyslds/pyslds/states.py in resample_gaussian_states(self)
    427         self._aBl = None # clear any caching
    428         self._gaussian_normalizer, self.gaussian_states = \
--> 429             info_sample(*self.info_params)
    430 
    431 

~/anaconda3/lib/python3.5/site-packages/pylds/lds_messages_interface.py in wrapped(*args, **kwargs)
     63     @wraps(func)
     64     def wrapped(*args, **kwargs):
---> 65         return func(*check(*args,**kwargs))
     66     return wrapped
     67 

~/anaconda3/lib/python3.5/site-packages/pylds/lds_messages_interface.py in _info_argcheck(J_init, h_init, log_Z_init, J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair, J_node, h_node, log_Z_node)
     93     J_pair_11, J_pair_21, J_pair_22, J_node = \
     94         map(partial(_ensure_ndim, T=T, ndim=3),
---> 95             [J_pair_11, J_pair_21, J_pair_22, J_node])
     96     h_pair_1, h_pair_2 = \
     97         map(partial(_ensure_ndim, T=T, ndim=2),

~/anaconda3/lib/python3.5/site-packages/pylds/lds_messages_interface.py in _ensure_ndim(X, T, ndim)
     22     assert ndim-1 <= X.ndim <= ndim
     23     if X.ndim == ndim:
---> 24         assert X.shape[0] == T
     25         return X
     26     else:

AssertionError:
---------------------------------------------------------------------------

Then I printed T and the shape of X: T is 2000, while X.shape is (1999, 2, 2). I changed "assert X.shape[0] == T" to "assert X.shape[0]+1 == T" in _ensure_ndim(X, T, ndim), but I still got an error in later iterations. It turned out that X.shape[0] is sometimes 1999 and sometimes 2000.

Could anyone help me with this problem? Thanks in advance!
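In the installed pylds shown in the traceback, lines 93-95 of _info_argcheck validate J_pair_11, J_pair_21, J_pair_22 and J_node against one and the same T. In a chain of T time steps there are T node potentials but only T - 1 pairwise potentials (one per pair of consecutive steps), which would be consistent with X.shape[0] being 1999 for some arrays and 2000 for others. The following is a minimal, hypothetical sketch of a shape check that treats the two kinds of terms separately. ensure_ndim and check_info_shapes are illustrative names modelled on the traceback, not the actual pylds code, and taking T from h_node is an assumption.

```python
from functools import partial

import numpy as np


def ensure_ndim(X, T, ndim):
    """Return X with leading dimension T, tiling a single shared block if needed."""
    X = np.asarray(X, dtype=np.float64)
    assert ndim - 1 <= X.ndim <= ndim
    if X.ndim == ndim:
        # One block per time index was supplied: its count must match T
        assert X.shape[0] == T, "expected %d blocks, got %d" % (T, X.shape[0])
        return X
    # A single block shared by all time indices: tile it T times
    return np.tile(X, [T] + [1] * X.ndim)


def check_info_shapes(J_pair_11, J_pair_21, J_pair_22, J_node, h_node):
    """Hypothetical check for chain-structured information-form potentials."""
    T = h_node.shape[0]  # number of time steps (assumed to come from h_node)
    # Pairwise terms couple x_t and x_{t+1}, so there are only T - 1 of them
    J_pair_11, J_pair_21, J_pair_22 = map(
        partial(ensure_ndim, T=T - 1, ndim=3), [J_pair_11, J_pair_21, J_pair_22]
    )
    # Node terms: one per time step
    J_node = ensure_ndim(J_node, T=T, ndim=3)
    return J_pair_11, J_pair_21, J_pair_22, J_node


T, D = 2000, 2
J_pair = np.zeros((T - 1, D, D))  # 1999 pairwise blocks
J_node = np.zeros((T, D, D))      # 2000 node blocks
h_node = np.zeros((T, D))
print([a.shape for a in check_info_shapes(J_pair, J_pair, J_pair, J_node, h_node)])
# [(1999, 2, 2), (1999, 2, 2), (1999, 2, 2), (2000, 2, 2)]
```

Checking both kinds of terms against a single T rejects one of them whichever value T takes, so editing the assert to "X.shape[0]+1 == T" would only shift the failure to the other array. The paths in the traceback also show pyslds running from a source checkout while pylds comes from site-packages; a version mismatch between the two is one possible cause worth ruling out, though this is unverified.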