lindermanlab / ssm

Bayesian learning and inference for state space models
MIT License

Identity Emissions don't work #88

Closed: bantin closed this issue 4 years ago

bantin commented 4 years ago

None of the identity emissions models seem to work, e.g. GaussianIdentityEmissions, StudentsTIdentityEmissions, etc.

The failure looks as follows:


---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-4-ef1137a96c5c> in <module>
      4 lds = ssm.LDS(N, D, dynamics="rotational", emissions="studentst_id")
      5 # %lprun -f ssm.emissions.GaussianOrthogonalEmissions.m_step lds.fit(data_list, num_iters=1)
----> 6 elbos, posterior = lds.fit(data_list, num_iters=20)

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/util.py in wrapper(self, datas, inputs, masks, tags, **kwargs)
    108             tags = [tags]
    109 
--> 110         return f(self, datas, inputs=inputs, masks=masks, tags=tags, **kwargs)
    111 
    112     return wrapper

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/lds.py in fit(self, datas, inputs, masks, tags, method, variational_posterior, variational_posterior_kwargs, initialize, num_init_iters, **kwargs)
    859         variational_posterior_kwargs = variational_posterior_kwargs or {}
    860         posterior = self._make_variational_posterior(variational_posterior, datas, inputs, masks, tags, method, **variational_posterior_kwargs)
--> 861         elbos = _fitting_methods[method](posterior, datas, inputs, masks, tags, learning=True, **kwargs)
    862         return elbos, posterior
    863 

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/lds.py in _fit_laplace_em(self, variational_posterior, datas, inputs, masks, tags, num_iters, num_samples, continuous_optimizer, continuous_tolerance, continuous_maxiter, emission_optimizer, emission_optimizer_maxiter, parameters_update, alpha, learning)
    754                 self._fit_laplace_em_params_update(
    755                     discrete_expectations, continuous_expectations, datas, inputs, masks, tags,
--> 756                     emission_optimizer, emission_optimizer_maxiter, alpha)
    757             # Alternative is SGD on all parameters with samples from q(x)
    758             elif learning and parameters_update=="sgd":

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/lds.py in _fit_laplace_em_params_update(self, discrete_expectations, continuous_expectations, datas, inputs, masks, tags, emission_optimizer, emission_optimizer_maxiter, alpha)
    653 
    654         # update emissions params
--> 655         curr_prms = copy.deepcopy(self.emissions.params)
    656         self.emissions.m_step(discrete_expectations, continuous_expectations,
    657                               datas, inputs, masks, tags,

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/emissions.py in params(self)
    461     @property
    462     def params(self):
--> 463         return super(_StudentsTEmissionsMixin, self).params + (self.inv_etas, self.inv_nus)
    464 
    465     @params.setter

/Users/bantin/Documents/Linderman-Shenoy/ssm/ssm/emissions.py in params(self)
     21     @property
     22     def params(self):
---> 23         raise NotImplementedError
     24 
     25     @params.setter

NotImplementedError: 

This seems to be an issue with inheritance and mixins. All of the identity emissions models inherit from a distribution mixin and from the _IdentityEmissions class. For example, the GaussianIdentityEmissions declaration is:

class GaussianIdentityEmissions(_GaussianEmissionsMixin, _IdentityEmissions):

So when code accesses GaussianIdentityEmissions.params, Python finds the params property on _GaussianEmissionsMixin, which calls super().params. That lookup continues along the method resolution order (MRO) to _IdentityEmissions, which does not override params, so it falls through to the base Emissions class and raises NotImplementedError. The traceback above shows the same path via _StudentsTEmissionsMixin.

So the call chain (I think) is: _GaussianEmissionsMixin.params => _IdentityEmissions.params => Emissions.params.
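A minimal reproduction of that chain, using simplified, hypothetical stand-ins for the ssm classes (only the params plumbing is modeled; the real classes have many more members). It prints the MRO and shows that super().params inside the mixin skips over _IdentityEmissions and lands on Emissions.params.

```python
# Hypothetical, simplified stand-ins for the ssm classes named in this issue.

class Emissions:
    @property
    def params(self):
        # Base class: subclasses are expected to override this.
        raise NotImplementedError


class _IdentityEmissions(Emissions):
    # No params override, so attribute lookups fall through to Emissions.
    pass


class _StudentsTEmissionsMixin:
    inv_etas = 0.0  # placeholder values for illustration
    inv_nus = 1.0

    @property
    def params(self):
        # super() resolves to the next class in type(self).__mro__,
        # here _IdentityEmissions, then Emissions.
        return super(_StudentsTEmissionsMixin, self).params + (self.inv_etas, self.inv_nus)


class StudentsTIdentityEmissions(_StudentsTEmissionsMixin, _IdentityEmissions):
    pass


def mro_names(cls):
    """Class names in method resolution order."""
    return [c.__name__ for c in cls.__mro__]


def params_or_error(emissions):
    """Return emissions.params, or the name of the exception it raises."""
    try:
        return emissions.params
    except NotImplementedError as err:
        return type(err).__name__


print(mro_names(StudentsTIdentityEmissions))
# ['StudentsTIdentityEmissions', '_StudentsTEmissionsMixin', '_IdentityEmissions', 'Emissions', 'object']
print(params_or_error(StudentsTIdentityEmissions()))
# NotImplementedError
```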

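One solution is to add a params property to _IdentityEmissions that returns an empty tuple. It has to be a tuple rather than a list: the mixins append their own parameters with "+ (...)", and a list plus a tuple raises TypeError. Since the base class also defines a params setter, a matching no-op setter is presumably needed as well. There might be a more elegant solution, though.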
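A sketch of that fix, again on hypothetical, simplified stand-ins rather than the actual ssm source. The no-op setter and the way the mixin's setter peels off its own two entries before passing the rest up the MRO are assumptions about how the real classes are wired.

```python
# Hypothetical, simplified stand-ins; only the params plumbing is modeled.

class Emissions:
    @property
    def params(self):
        raise NotImplementedError

    @params.setter
    def params(self, value):
        raise NotImplementedError


class _IdentityEmissions(Emissions):
    # Proposed fix: no learnable parameters of its own, so report an empty
    # tuple. A list would break the mixins' "+ (a, b)" concatenation.
    @property
    def params(self):
        return ()

    @params.setter
    def params(self, value):
        # Assumed no-op: nothing to store.
        pass


class _StudentsTEmissionsMixin:
    def __init__(self):
        self.inv_etas = 0.0  # placeholder values for illustration
        self.inv_nus = 1.0

    @property
    def params(self):
        return super(_StudentsTEmissionsMixin, self).params + (self.inv_etas, self.inv_nus)

    @params.setter
    def params(self, value):
        # Take this mixin's two entries off the end; hand the rest up the MRO.
        self.inv_etas, self.inv_nus = value[-2:]
        super(_StudentsTEmissionsMixin, type(self)).params.fset(self, value[:-2])


class StudentsTIdentityEmissions(_StudentsTEmissionsMixin, _IdentityEmissions):
    pass


emissions = StudentsTIdentityEmissions()
print(emissions.params)  # (0.0, 1.0)
emissions.params = (2.0, 3.0)
print(emissions.params)  # (2.0, 3.0)

# Why an empty list would not work: list + tuple raises TypeError.
try:
    [] + (0.0, 1.0)
except TypeError as err:
    print("TypeError:", err)
```

Returning () keeps every mixin's "super().params + (...)" pattern valid without each mixin needing to know whether it sits on top of an identity or a non-identity emissions class.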