I am trying to run the "GaussianHMM" example code. However, I get the error "TypeError: `model` must be convertible to `dict` (saw: DeviceArray)."
I searched for this error in the JAX community but could not find a solution. Could you please help me out? Thank you very much!
from ssm.hmm import GaussianHMM
import jax.random as jr

# Model dimensions shared by the ground-truth HMM and the HMM we fit.
hmm_kwargs = dict(num_states=5, num_emission_dims=10)

# Ground-truth HMM that generates the synthetic data.
hmm = GaussianHMM(**hmm_kwargs, seed=jr.PRNGKey(0))

# Draw latent states and observations from the ground-truth model.
states, data = hmm.sample(
    key=jr.PRNGKey(1),
    num_steps=500,
    num_samples=5,
)

# A second HMM with the same dimensions but a different seed; this is the
# model that gets fit to the sampled data.
test_hmm = GaussianHMM(**hmm_kwargs, seed=jr.PRNGKey(32))

# Fit the test model with EM.
# NOTE(review): in the reported traceback this call fails while `fit` runs
# its initialization step. Inside ssm, NormalInverseWishart.from_natural_parameters
# calls `cls(loc, mean_precision, df, scale)` with positional arguments, and
# tensorflow_probability's JointDistributionNamed.__new__ forwards them to
# _JointDistributionNamed, whose first parameter is `model`. The `loc` array
# is therefore rejected as a non-dict model. This points to an ssm-jax /
# tensorflow-probability version incompatibility rather than a problem in
# this script. Confirm the installed TFP version against the one ssm-jax
# targets.
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
Initializing...
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[6], line 12
9 test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
11 # fit it to our sampled data
---> 12 log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
250 if key in bound_args.arguments and bound_args.arguments[key] is not None:
251 bound_args.arguments[key] = \
252 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)
File ~/tmp/ssm-jax/ssm/hmm/base.py:201, in HMM.fit(self, data, covariates, metadata, method, num_iters, tol, initialization_method, key, verbosity)
199 if initialization_method is not None:
200 if verbosity >= Verbosity.LOUD : print("Initializing...")
--> 201 self.initialize(key, data, method=initialization_method)
202 if verbosity >= Verbosity.LOUD: print("Done.", flush=True)
204 if method == "em":
File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
250 if key in bound_args.arguments and bound_args.arguments[key] is not None:
251 bound_args.arguments[key] = \
252 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)
File ~/tmp/ssm-jax/ssm/hmm/base.py:132, in HMM.initialize(self, key, data, covariates, metadata, method)
129 dummy_posteriors = DummyPosterior(one_hot(assignments, self._num_states))
131 # Do one m-step with the dummy posteriors
--> 132 self._emissions.m_step(data, dummy_posteriors)
File ~/tmp/ssm-jax/ssm/hmm/emissions.py:161, in ExponentialFamilyEmissions.m_step(self, dataset, posteriors, covariates, metadata)
145 def m_step(self, dataset, posteriors, covariates=None, metadata=None) -> ExponentialFamilyEmissions:
146 """Update the emissions distribution using an M-step.
147
148 Operates over a batch of data (posterior must have the same batch dim).
(...)
159 emissions (ExponentialFamilyEmissions): updated emissions object
160 """
--> 161 conditional = self._emissions_distribution_class.compute_conditional(
162 dataset, weights=posteriors.expected_states, prior=self._prior)
163 self._distribution = self._emissions_distribution_class.from_params(
164 conditional.mode())
165 return self
File ~/tmp/ssm-jax/ssm/distributions/expfam.py:98, in ExponentialFamilyDistribution.compute_conditional(cls, data, weights, prior)
95 stats = tree_map(np.add, stats, prior.natural_parameters)
97 # Compute the conditional distribution given the stats
---> 98 return cls.compute_conditional_from_stats(stats)
File ~/tmp/ssm-jax/ssm/distributions/expfam.py:75, in ExponentialFamilyDistribution.compute_conditional_from_stats(cls, stats)
73 @classmethod
74 def compute_conditional_from_stats(cls, stats):
---> 75 return get_prior(cls).from_natural_parameters(stats)
File ~/tmp/ssm-jax/ssm/distributions/niw.py:69, in NormalInverseWishart.from_natural_parameters(cls, natural_params)
67 loc = np.einsum("...i,...->...i", s2, 1 / mean_precision)
68 scale = s3 - np.einsum("...,...i,...j->...ij", mean_precision, loc, loc)
---> 69 return cls(loc, mean_precision, df, scale)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:474, in JointDistributionNamed.__new__(cls, *args, **kwargs)
470 model = kwargs.get('model')
472 if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
473 for d in tf.nest.flatten(model)):
--> 474 return _JointDistributionNamed(*args, **kwargs)
475 return super(JointDistributionNamed, cls).__new__(cls)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:323, in _JointDistributionNamed.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
287 def __init__(self,
288 model,
289 batch_ndims=None,
(...)
292 experimental_use_kahan_sum=False,
293 name=None):
294 """Construct the `JointDistributionNamed` distribution.
295
296 Args:
(...)
321 Default value: `None` (i.e., `"JointDistributionNamed"`).
322 """
--> 323 super(_JointDistributionNamed, self).__init__(
324 model,
325 batch_ndims=batch_ndims,
326 use_vectorized_map=use_vectorized_map,
327 validate_args=validate_args,
328 experimental_use_kahan_sum=experimental_use_kahan_sum,
329 name=name or 'JointDistributionNamed')
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_sequential.py:362, in _JointDistributionSequential.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
360 self._model_trackable = model
361 self._model = self._no_dependency(model)
--> 362 self._build(model)
364 super(_JointDistributionSequential, self).__init__(
365 dtype=None, # Ignored; we'll override.
366 batch_ndims=batch_ndims,
(...)
370 experimental_use_kahan_sum=experimental_use_kahan_sum,
371 name=name)
373 # If the model consists entirely of prebuilt distributions with no
374 # dependencies, cache them directly to avoid a sample call down the road.
File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:334, in _JointDistributionNamed._build(self, model)
332 """Creates `dist_fn`, `dist_fn_wrapped`, `dist_fn_args`, `dist_fn_name`."""
333 if not _is_dict_like(model):
--> 334 raise TypeError('`model` must be convertible to `dict` (saw: {}).'.format(
335 type(model).__name__))
336 [
337 self._dist_fn,
338 self._dist_fn_wrapped,
339 self._dist_fn_args,
340 self._dist_fn_name, # JointDistributionSequential doesn't have this.
341 ] = _prob_chain_rule_model_flatten(model)
TypeError: `model` must be convertible to `dict` (saw: DeviceArray).
Dear all,
I am trying to run the "GaussianHMM" example code. However, I get the error "TypeError: `model` must be convertible to `dict` (saw: DeviceArray)." I searched for this error in the JAX community but could not find a solution. Could you please help me out? Thank you very much!