Open xaviergonzalez opened 1 year ago
You might need tfp.util.ParameterProperties(event_ndims=1) for all of your parameters. It's awkwardly named, but basically indicates how many final dimensions of each parameter get consumed to produce a single event.
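To illustrate what `event_ndims` changes, here is a minimal pure-Python sketch of batch-shape inference. It is not TFP's actual implementation: the helpers `broadcast_shapes` and `sketch_batch_shape` and the parameter names `weights`, `means`, and `scales` are hypothetical, and it assumes that an unannotated parameter defaults to `event_ndims=0`.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Broadcast shapes NumPy-style (right-aligned, size-1 dims stretch)."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def sketch_batch_shape(param_shapes, event_ndims):
    """Drop each parameter's trailing `event_ndims` dims, then broadcast the rest."""
    batch_parts = []
    for name, shape in param_shapes.items():
        n = event_ndims[name]
        batch_parts.append(shape[:len(shape) - n])
    return broadcast_shapes(*batch_parts)


# Hypothetical two-component GMM: every parameter is a length-2 vector.
param_shapes = {"weights": (2,), "means": (2,), "scales": (2,)}

# event_ndims=0: the length-2 axis is kept as a batch axis.
as_batch = sketch_batch_shape(param_shapes, {k: 0 for k in param_shapes})
# event_ndims=1: the length-2 axis is consumed to build one mixture.
as_event = sketch_batch_shape(param_shapes, {k: 1 for k in param_shapes})

print(as_batch)  # (2,)
print(as_event)  # ()
```

Under these assumptions, `event_ndims=0` makes the length-2 axis look like a batch of two distributions, which would be consistent with a length-2 sample, while `event_ndims=1` leaves an empty batch shape and a single scalar sample.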
Hey @brianwa84, could you please assign me the issue?
I am getting the strangest bugs when trying to make a Gaussian mixture model class in the JAX substrate of tfd. Has anyone experienced this before, or does anyone know the correct course of action? Basically, I either get an annoying warning message and a sample of the right length, or no warning and a sample of the wrong length.
When I try this code
I get this warning message:

WARNING:root: Distribution subclass GMM inherits `_parameter_properties` from its parent (MixtureSameFamily) while also redefining `__init__`. The inherited annotations cover the following parameters: dict_keys(['mixture_distribution', 'components_distribution']). It is likely that these do not match the subclass parameters. This may lead to errors when computing batch shapes, slicing into batch dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor (e.g., when it is passed or returned from a `tf.function`), and possibly other cases. The recommended pattern for distribution subclasses is to define a new `_parameter_properties` method with the subclass parameters, and to store the corresponding parameter values as `self._parameters` in `__init__`, after calling the superclass constructor:

The sample is Array(1.0883901, dtype=float32), so at least only one scalar gets sampled.
When I then try to follow the recommendation in the warning with this code
The warning disappears, but now I get a length-2 sample, apparently one value from each of the two mixture components, which is not what I want: Array([ 1.85066 , -2.4407113], dtype=float32)
It seems related to these issues: