lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

Error: `model` must be convertible to `dict` (saw: DeviceArray) #38

Closed. lx704612715 closed this issue 1 year ago

lx704612715 commented 1 year ago

Dear all,

I am trying to run the "GaussianHMM" example code, but it fails with "TypeError: `model` must be convertible to `dict` (saw: DeviceArray).".

I searched for this error in the JAX community but could not find a solution. Could you please help me out? Thank you very much!

from ssm.hmm import GaussianHMM
import jax.random as jr

# create a true HMM model
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)

# create a test HMM model
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))

# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")

Initializing...

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 12
      9 test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
     11 # fit it to our sampled data
---> 12 log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")

File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
    250         if key in bound_args.arguments and bound_args.arguments[key] is not None:
    251             bound_args.arguments[key] = \
    252                 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)

File ~/tmp/ssm-jax/ssm/hmm/base.py:201, in HMM.fit(self, data, covariates, metadata, method, num_iters, tol, initialization_method, key, verbosity)
    199 if initialization_method is not None:
    200     if verbosity >= Verbosity.LOUD : print("Initializing...")
--> 201     self.initialize(key, data, method=initialization_method)
    202     if verbosity >= Verbosity.LOUD: print("Done.", flush=True)
    204 if method == "em":

File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
    250         if key in bound_args.arguments and bound_args.arguments[key] is not None:
    251             bound_args.arguments[key] = \
    252                 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)

File ~/tmp/ssm-jax/ssm/hmm/base.py:132, in HMM.initialize(self, key, data, covariates, metadata, method)
    129 dummy_posteriors = DummyPosterior(one_hot(assignments, self._num_states))
    131 # Do one m-step with the dummy posteriors
--> 132 self._emissions.m_step(data, dummy_posteriors)

File ~/tmp/ssm-jax/ssm/hmm/emissions.py:161, in ExponentialFamilyEmissions.m_step(self, dataset, posteriors, covariates, metadata)
    145 def m_step(self, dataset, posteriors, covariates=None, metadata=None) -> ExponentialFamilyEmissions:
    146     """Update the emissions distribution using an M-step.
    147 
    148     Operates over a batch of data (posterior must have the same batch dim).
   (...)
    159         emissions (ExponentialFamilyEmissions): updated emissions object
    160     """
--> 161     conditional = self._emissions_distribution_class.compute_conditional(
    162         dataset, weights=posteriors.expected_states, prior=self._prior)
    163     self._distribution = self._emissions_distribution_class.from_params(
    164         conditional.mode())
    165     return self

File ~/tmp/ssm-jax/ssm/distributions/expfam.py:98, in ExponentialFamilyDistribution.compute_conditional(cls, data, weights, prior)
     95     stats = tree_map(np.add, stats, prior.natural_parameters)
     97 # Compute the conditional distribution given the stats
---> 98 return cls.compute_conditional_from_stats(stats)

File ~/tmp/ssm-jax/ssm/distributions/expfam.py:75, in ExponentialFamilyDistribution.compute_conditional_from_stats(cls, stats)
     73 @classmethod
     74 def compute_conditional_from_stats(cls, stats):
---> 75     return get_prior(cls).from_natural_parameters(stats)

File ~/tmp/ssm-jax/ssm/distributions/niw.py:69, in NormalInverseWishart.from_natural_parameters(cls, natural_params)
     67 loc = np.einsum("...i,...->...i", s2, 1 / mean_precision)
     68 scale = s3 - np.einsum("...,...i,...j->...ij", mean_precision, loc, loc)
---> 69 return cls(loc, mean_precision, df, scale)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:474, in JointDistributionNamed.__new__(cls, *args, **kwargs)
    470   model = kwargs.get('model')
    472 if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
    473            for d in tf.nest.flatten(model)):
--> 474   return _JointDistributionNamed(*args, **kwargs)
    475 return super(JointDistributionNamed, cls).__new__(cls)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:323, in _JointDistributionNamed.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
    287 def __init__(self,
    288              model,
    289              batch_ndims=None,
   (...)
    292              experimental_use_kahan_sum=False,
    293              name=None):
    294   """Construct the `JointDistributionNamed` distribution.
    295 
    296   Args:
   (...)
    321       Default value: `None` (i.e., `"JointDistributionNamed"`).
    322   """
--> 323   super(_JointDistributionNamed, self).__init__(
    324       model,
    325       batch_ndims=batch_ndims,
    326       use_vectorized_map=use_vectorized_map,
    327       validate_args=validate_args,
    328       experimental_use_kahan_sum=experimental_use_kahan_sum,
    329       name=name or 'JointDistributionNamed')

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_sequential.py:362, in _JointDistributionSequential.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
    360 self._model_trackable = model
    361 self._model = self._no_dependency(model)
--> 362 self._build(model)
    364 super(_JointDistributionSequential, self).__init__(
    365     dtype=None,  # Ignored; we'll override.
    366     batch_ndims=batch_ndims,
   (...)
    370     experimental_use_kahan_sum=experimental_use_kahan_sum,
    371     name=name)
    373 # If the model consists entirely of prebuilt distributions with no
    374 # dependencies, cache them directly to avoid a sample call down the road.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:334, in _JointDistributionNamed._build(self, model)
    332 """Creates `dist_fn`, `dist_fn_wrapped`, `dist_fn_args`, `dist_fn_name`."""
    333 if not _is_dict_like(model):
--> 334   raise TypeError('`model` must be convertible to `dict` (saw: {}).'.format(
    335       type(model).__name__))
    336 [
    337     self._dist_fn,
    338     self._dist_fn_wrapped,
    339     self._dist_fn_args,
    340     self._dist_fn_name,  # JointDistributionSequential doesn't have this.
    341 ] = _prob_chain_rule_model_flatten(model)

TypeError: `model` must be convertible to `dict` (saw: DeviceArray).
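
For context, my reading of the traceback (an interpretation, not a confirmed diagnosis): NormalInverseWishart.from_natural_parameters calls cls(loc, mean_precision, df, scale) with positional arguments, and in this tensorflow_probability version the first positional argument ends up being treated as the model argument of JointDistributionNamed. Since loc is a DeviceArray rather than a dict, _build raises the TypeError. The same error can be reproduced directly against TFP, independent of ssm-jax:

from tensorflow_probability.substrates import jax as tfp
import jax.numpy as jnp

tfd = tfp.distributions

try:
    # `model` must be dict-like; an array in the first positional slot
    # triggers the same TypeError seen above.
    tfd.JointDistributionNamed(jnp.zeros(3))
except TypeError as e:
    print(e)  # `model` must be convertible to `dict` (saw: ...)
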
lx704612715 commented 1 year ago

It is solved by downgrading tensorflow_probability to 0.17.0.
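
For anyone who lands here later, a quick way to check the installed version before running the example (a generic stdlib snippet, not ssm-jax API; the compatible version is taken from the comment above):

from importlib.metadata import version

# Print the installed TFP version; the example worked after pinning to 0.17.0.
print(version("tensorflow-probability"))

# If a newer version is installed, downgrade with:
#   pip install tensorflow-probability==0.17.0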