LorenzoRimella opened 7 months ago
It seems that some seeds produce NaNs when sampling from a Dirichlet distribution. Any idea why? The example script below was tested on Google Colab.
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Concentration vector; the third entry is zero, so the distribution is degenerate.
dirichlet_lambda = tf.convert_to_tensor([2., 5., 0., 10., 10., 12., 10., 10., 1., 1.], dtype = tf.float32)

# Two stateless seeds that differ only in their last element.
seed_s2 = tf.convert_to_tensor([-1012227931, -757448172], dtype = tf.int32)
seed_s3 = tf.convert_to_tensor([-1012227931, -757448170], dtype = tf.int32)

@tf.function(jit_compile = True)
def jitwhat(concentration, seed):
    # Draw a (13, 10) batch of Dirichlet samples; the output shape is (13, 10, 10).
    theta_j_k = tfp.distributions.Dirichlet(concentration = concentration).sample((13, 10), seed = seed)
    return theta_j_k

foo = jitwhat(dirichlet_lambda, seed_s2)
# Indices of NaN entries (non-empty for seed_s2, as reported above).
np.where(np.isnan(foo))
```
Note that this Dirichlet distribution is "degenerate", since one of its concentration parameters is zero. Usually the sampling method simply returns zero in the corresponding position; with this specific seed, however, it returns NaN.
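For reference, the standard construction of a Dirichlet draw normalizes independent Gamma(α_i, 1) variates by their sum, and a zero concentration corresponds to a component that is identically zero. The sketch below illustrates that construction in plain Python (standard library only). It is not TFP's implementation, and the explicit zero handling is an assumption about what the degenerate case should return. Under it, the zero component is exactly 0.0 in every draw and no NaN can arise, because the normalizing sum is always positive.

```python
import math
import random

def dirichlet_sample(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma(alpha_i, 1) draws.

    Assumption: a zero concentration is treated as a point mass at 0, so that
    component is set to exactly 0.0 instead of being passed to
    rng.gammavariate (which raises ValueError for alpha <= 0).
    At least one concentration must be positive.
    """
    gammas = [rng.gammavariate(a, 1.0) if a > 0 else 0.0 for a in concentration]
    total = sum(gammas)  # strictly positive when some alpha_i > 0
    return [g / total for g in gammas]

# Same concentration vector as in the report; a batch of 13 * 10 samples.
concentration = [2., 5., 0., 10., 10., 12., 10., 10., 1., 1.]
rng = random.Random(0)
samples = [dirichlet_sample(concentration, rng) for _ in range(13 * 10)]

nan_count = sum(math.isnan(x) for s in samples for x in s)
print("NaN entries:", nan_count)
print("zero component always 0.0:", all(s[2] == 0.0 for s in samples))
```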
Verified as a potential bug. Colab here.