tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Dirichlet distribution sampling issue when jit_compile=True #1789

Open LorenzoRimella opened 7 months ago

LorenzoRimella commented 7 months ago

It seems that some seeds produce NaNs when sampling from a Dirichlet distribution inside a function compiled with `jit_compile=True`. Any idea why? Here is an example script, tested on Google Colab:

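```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Concentration vector; the third entry is 0, so the Dirichlet is degenerate.
dirichlet_lambda = tf.convert_to_tensor([2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype=tf.float32)
# Stateless seeds (shape-[2] int32 tensors); seed_s3 differs from seed_s2 only in its last element.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype=tf.int32)
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype=tf.int32)

@tf.function(jit_compile=True)
def jitwhat(concentration, seed):
    # Draw a (13, 10) batch of samples; each sample is a length-10 probability vector.
    theta_j_k = tfp.distributions.Dirichlet(concentration=concentration).sample((13, 10), seed=seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
# Indices of any NaN entries in the (13, 10, 10) output.
print(np.where(np.isnan(foo.numpy())))
```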

Note that this Dirichlet distribution is "degenerate", since one of its concentration parameters is zero. Usually the sampler simply returns a zero in the corresponding position, but with this specific seed it returns NaN.
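For context on why a zero is the expected output: a common way to sample a Dirichlet is to draw independent Gamma(α_i, 1) variables and divide each by their sum, so a component with α_i = 0 is a point mass at zero and contributes 0 to both the numerator and the normalizer. The NumPy sketch below illustrates that construction under stated assumptions: `dirichlet_via_gamma` is a hypothetical helper, not TFP's implementation, and it does not reproduce or explain the XLA NaN. It handles zero concentrations by masking them explicitly rather than passing a zero shape to the Gamma sampler.

```python
import numpy as np


def dirichlet_via_gamma(concentration, sample_shape, rng):
    """Hypothetical sketch: Dirichlet samples from normalized Gamma draws.

    Components with zero concentration are treated as a point mass at 0,
    so they contribute exactly 0 to both the numerator and the normalizer.
    """
    alpha = np.asarray(concentration, dtype=np.float64)
    positive = alpha > 0
    shape = tuple(sample_shape) + alpha.shape
    # Draw Gamma(alpha_i, 1) only where alpha_i > 0: use a dummy shape of 1.0
    # for zero entries, then overwrite those draws with exact zeros.
    safe_alpha = np.where(positive, alpha, 1.0)
    gammas = rng.gamma(np.broadcast_to(safe_alpha, shape))
    gammas = np.where(positive, gammas, 0.0)
    # Normalize along the event (last) axis; the sum is positive as long as
    # at least one concentration is positive.
    return gammas / gammas.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
alpha = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
theta = dirichlet_via_gamma(alpha, (13, 10), rng)
print(theta.shape)                  # (13, 10, 10)
print(np.isnan(theta).any())        # False
print(np.all(theta[..., 2] == 0))   # True
```

Masking before sampling keeps the zero-concentration components out of the Gamma sampler entirely, which is one way to avoid depending on how a given backend treats a Gamma shape of exactly 0.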

chrism0dwk commented 7 months ago

Verified as a potential bug. Colab here.