tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Feature Request: Gumbel Mixture Models #1598

Open · bryorsnef opened 1 year ago

bryorsnef commented 1 year ago

It is possible to construct reparameterizable mixture distributions by replacing the categorical distribution with a Gumbel-Softmax (relaxed categorical) distribution. The ability to use a relaxed one-hot categorical distribution in Mixture or MixtureSameFamily would be potentially very useful.

Differentiable mixture distributions are implemented in PyTorch here: https://github.com/nextBillyonair/DPM/blob/master/dpm/distributions/gumbel_mixture_model.py
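
To make the request concrete, here is a minimal sketch of such a relaxed mixture assembled by hand from existing TFP pieces (the helper `gumbel_mixture_sample` and all variable names are illustrative, not an existing API):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def gumbel_mixture_sample(mix_logits, components, temperature=0.5, n=1):
  """Draws differentiable samples from a relaxed (Gumbel-Softmax) mixture.

  mix_logits: [k] mixture logits; components: a batched distribution
  with batch shape [k] and event shape [d].
  """
  # Soft mixture weights from a Gumbel-Softmax instead of hard
  # one-hot indicators from a Categorical.
  weights = tfd.RelaxedOneHotCategorical(
      temperature=temperature, logits=mix_logits).sample(n)  # [n, k]
  comp_samples = components.sample(n)                        # [n, k, d]
  # Convex combination of component samples: differentiable w.r.t.
  # mix_logits, temperature, and the component parameters.
  return tf.einsum('nk,nkd->nd', weights, comp_samples)

# Example: a 3-component mixture of 2-D Gaussians.
locs = tf.constant([[-2., 0.], [2., 0.], [0., 2.]])
components = tfd.MultivariateNormalDiag(
    loc=locs, scale_diag=tf.ones_like(locs))
samples = gumbel_mixture_sample(tf.zeros(3), components, n=5)  # [5, 2]
```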

bryorsnef commented 1 year ago

> I believe https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical is what you are looking for (the distribution goes by the names Relaxed One-Hot Categorical, Gumbel-Softmax, and Concrete in the literature).

Yeah, the distribution is already implemented; I was saying it would be useful to also have Mixture and MixtureSameFamily meta-distributions that can take a relaxed one-hot categorical as the cat argument instead of a Categorical.
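
In other words, the request is roughly for something like the following to be accepted (hypothetical usage; today MixtureSameFamily rejects a non-integer mixture_distribution):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical: not supported today. MixtureSameFamily currently
# requires mixture_distribution to be an integer-dtype Categorical.
relaxed_mixture = tfd.MixtureSameFamily(
    mixture_distribution=tfd.RelaxedOneHotCategorical(
        temperature=0.5, logits=tf.zeros(3)),
    components_distribution=tfd.Normal(
        loc=tf.constant([-2., 0., 2.]), scale=tf.ones(3)))
```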

bryorsnef commented 1 year ago

From the looks of https://github.com/tensorflow/probability/blob/v0.17.0/tensorflow_probability/python/distributions/mixture_same_family.py#L266-L270, the changes needed to allow this don't look too complicated. In _sample_n, the one-hot mask built in lines 266-270 could be replaced with samples from the relaxed one-hot categorical mixture-selecting distribution:

```python
mask = tf.one_hot(
    indices=mix_sample,    # [n, B]
    depth=num_components,
    on_value=npdt(1),
    off_value=npdt(0))     # [n, B, k]
```

and this check on the mixture distribution's dtype would need to be removed:

```python
if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
  raise ValueError(
      'mixture_distribution.dtype({}) is not over integers'.format(
          dtype_util.name(self.mixture_distribution.dtype)))
```

Other methods, like _log_prob, do not appear to need any changes.
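
To make the proposal concrete, here is a rough sketch of the substitution (this is not the actual TFP code; it assumes mixture_distribution is allowed to be a RelaxedOneHotCategorical):

```python
# Hypothetical replacement for the masking step in _sample_n, assuming
# self.mixture_distribution is a RelaxedOneHotCategorical. Its samples
# are already (soft) one-hot vectors, so no tf.one_hot call is needed:
mask = self.mixture_distribution.sample(n)  # [n, B, k], soft weights
# Downstream, _sample_n already reduces the component samples against
# this mask, so the result becomes a convex combination of components
# rather than a hard selection of a single one.
```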

brianwa84 commented 1 year ago

The reparameterize=True setting on MixtureSameFamily gives you an unbiased version of this using implicit differentiation. Are there settings where a soft mixture would be superior?
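
For reference, a minimal sketch of that setting (distribution parameters are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mix_logits = tf.Variable(tf.zeros(3))
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=mix_logits),
    components_distribution=tfd.Normal(
        loc=tf.constant([-2., 0., 2.]), scale=tf.ones(3)),
    reparameterize=True)  # implicit reparameterization gradients

with tf.GradientTape() as tape:
  # Gradients flow through the (hard) mixture samples to mix_logits.
  loss = tf.reduce_mean(tf.square(gm.sample(64)))
grads = tape.gradient(loss, mix_logits)
```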

bryorsnef commented 1 year ago

My network gives NaNs whenever I've tried setting that to True (an IWAE with a mixture proposal distribution). I was able to set up a Gumbel mixture with no NaNs. I didn't look deeply into where they were coming from.
