tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Struggles with Gaussian Mixture Models in jax port of tensorflow_distributions #1730

Open xaviergonzalez opened 1 year ago

xaviergonzalez commented 1 year ago

I am getting the strangest bugs when trying to make a Gaussian mixture model class in the JAX substrate of tfd. Has anyone experienced this before, or does anyone know what the correct course of action is? Basically, I either get an annoying warning message and a sample of the right length, or no warning but a sample of the wrong length.

When I try this code:

```python
import jax.numpy as jnp
import jax.random as jr
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class GMM(tfd.MixtureSameFamily):
  def __init__(self, locs, log_scales, log_weights):
    self.locs = locs
    self.log_scales = log_scales
    self.log_weights = log_weights
    # Mixture weights are given as logits; component scales as log-scales.
    mixture_dist = tfd.Categorical(logits=log_weights)
    component_dist = tfd.Normal(loc=locs, scale=jnp.exp(log_scales))
    super().__init__(mixture_dist, component_dist)


gmm = GMM(jnp.array([0., 1.]), jnp.array([0., 1.]), jnp.array([0., 0.]))
gmm.sample(sample_shape=(), seed=jr.PRNGKey(0))
```

I get this warning:

WARNING:root: Distribution subclass GMM inherits `_parameter_properties` from its parent (MixtureSameFamily) while also redefining `__init__`. The inherited annotations cover the following parameters: dict_keys(['mixture_distribution', 'components_distribution']). It is likely that these do not match the subclass parameters. This may lead to errors when computing batch shapes, slicing into batch dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor (e.g., when it is passed or returned from a `tf.function`), and possibly other cases. The recommended pattern for distribution subclasses is to define a new `_parameter_properties` method with the subclass parameters, and to store the corresponding parameter values as `self._parameters` in `__init__`, after calling the superclass constructor:

```python
class MySubclass(tfd.SomeDistribution):

  def __init__(self, param_a, param_b):
    parameters = dict(locals())
    # ... do subclass initialization ...
    super(MySubclass, self).__init__(**base_class_params)
    # Ensure that the subclass (not base class) parameters are stored.
    self._parameters = parameters

  def _parameter_properties(self, dtype, num_classes=None):
    return dict(
      # Annotations may optionally specify properties, such as `event_ndims`,
      # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
      # the `ParameterProperties` documentation for details.
      param_a=tfp.util.ParameterProperties(),
      param_b=tfp.util.ParameterProperties())
```

The same warning is printed a second time, and the output is `Array(1.0883901, dtype=float32)`, so at least only one scalar gets sampled.
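For comparison, a single scalar per draw is what ancestral sampling from a one-dimensional mixture produces: pick a component index from the categorical weights, then draw from that component's normal. The sketch below does this with plain `jax.random` calls; `sample_gmm` is a hypothetical helper for illustration only, not a TFP function, though for this parameterization `MixtureSameFamily` should produce draws with the same distribution.

```python
import jax
import jax.numpy as jnp


def sample_gmm(key, locs, log_scales, log_weights, sample_shape=()):
    """Hypothetical helper: ancestral sampling from a 1-D Gaussian mixture.

    Each draw picks a component k ~ Categorical(logits=log_weights) and then
    returns locs[k] + exp(log_scales[k]) * eps with eps ~ Normal(0, 1).
    Every draw is a scalar, so the result has shape `sample_shape`.
    """
    key_mix, key_noise = jax.random.split(key)
    # One component index per draw; shape == sample_shape.
    k = jax.random.categorical(key_mix, log_weights, shape=sample_shape)
    # One standard-normal value per draw.
    eps = jax.random.normal(key_noise, shape=sample_shape)
    # Shift and scale by the chosen component's parameters.
    return locs[k] + jnp.exp(log_scales[k]) * eps


locs = jnp.array([0., 1.])
log_scales = jnp.array([0., 1.])
log_weights = jnp.array([0., 0.])
x = sample_gmm(jax.random.PRNGKey(0), locs, log_scales, log_weights)
xs = sample_gmm(jax.random.PRNGKey(1), locs, log_scales, log_weights, sample_shape=(5,))
```

Here `x` is a 0-d array and `xs` has shape `(5,)`: the mixture's component axis is consumed by the draw rather than appearing in the output.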

When I then try to follow the recommendations of the warning with this code:

```python
class GMM(tfd.MixtureSameFamily):
  def __init__(self, locs, log_scales, log_weights):
    parameters = dict(locals())
    self.locs = locs
    self.log_scales = log_scales
    self.log_weights = log_weights
    mixture_dist = tfd.Categorical(logits=log_weights)
    component_dist = tfd.Normal(loc=locs, scale=jnp.exp(log_scales))
    super().__init__(mixture_dist, component_dist)
    self._parameters = parameters

  def _parameter_properties(self, dtype=jnp.float32, num_classes=None):
    return dict(
      locs=tfp.util.ParameterProperties(),
      log_scales=tfp.util.ParameterProperties(),
      log_weights=tfp.util.ParameterProperties())


gmm = GMM(jnp.array([0., 1.]), jnp.array([0., 1.]), jnp.array([0., 0.]))
gmm.sample(sample_shape=(), seed=jr.PRNGKey(0))
```

The warning disappears, but now I get two values, seemingly one from each mixture component, instead of a single scalar, which is not what I want: `Array([ 1.85066 , -2.4407113], dtype=float32)`
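A plausible explanation for the length-2 result, assuming TFP infers a distribution's batch shape from the annotations returned by `_parameter_properties`: `tfp.util.ParameterProperties()` defaults to `event_ndims=0`, so every axis of the shape-`[2]` parameters is read as a batch axis, the batch shape becomes `[2]`, and a sample has shape `sample_shape + batch_shape + event_shape`. The plain-Python sketch below is a simplified, hypothetical model of that inference; `broadcast_shapes` and `inferred_batch_shape` are illustrative helpers, not TFP functions.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """NumPy-style broadcast of several shape tuples (right-aligned)."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} do not broadcast")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def inferred_batch_shape(param_shapes, event_ndims):
    """Hypothetical, simplified batch-shape inference.

    Each parameter drops its rightmost `event_ndims[name]` axes (default 0,
    mirroring the ParameterProperties default); the batch shape is the
    broadcast of what remains.
    """
    per_param = []
    for name, shape in param_shapes.items():
        nd = event_ndims.get(name, 0)
        if nd > len(shape):
            raise ValueError(f"{name}: event_ndims={nd} exceeds rank {len(shape)}")
        per_param.append(tuple(shape[:len(shape) - nd]))
    return broadcast_shapes(*per_param)


shapes = {"locs": (2,), "log_scales": (2,), "log_weights": (2,)}
print(inferred_batch_shape(shapes, {}))                      # (2,)
print(inferred_batch_shape(shapes, {k: 1 for k in shapes}))  # ()
```

Under this model, the default annotations make the two mixture components look like a batch of two distributions, while `event_ndims=1` folds the component axis into the event and leaves an empty batch shape.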

It seems related to these issues:

brianwa84 commented 1 year ago

You might need `tfp.util.ParameterProperties(event_ndims=1)` for all of your parameters. It's awkwardly named, but it basically indicates how many final dimensions of each parameter get consumed to produce a single event.

vibhaw1904 commented 1 year ago

hey @brianwa84 could you please assign me the issue