tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Sampling from TruncatedNormal can yield NaN #1844

Open georgematheos opened 1 month ago

georgematheos commented 1 month ago

Example:

import jax
from tensorflow_probability.substrates import jax as tfp

tfp.distributions.TruncatedNormal(
    loc=0.5382424, scale=0.05, low=0.80921564, high=0.86921564
).sample(seed=jax.random.PRNGKey(2))

returns NaN.

JAX version: 0.4.33. TFP version: 0.23.0.

georgematheos commented 1 month ago

@derifatives indicated that tfp.distributions.TruncatedNormal.sample wraps jax.random.truncated_normal, here. (We may be misunderstanding when this function is called.)

However, note that jax.random.truncated_normal can be called directly to sample from the truncated normal distribution above without any apparent issue:

import jax

mean, std, minval, maxval = 0.5382424, 0.05, 0.80921564, 0.86921564
# Standardize the bounds, sample a standard truncated normal, then rescale.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample # = 0.80921566

georgematheos commented 1 month ago

However, calling jax.random.truncated_normal directly does not always work either:

import jax

mean, std, minval, maxval = 0.09121108, 0.1, 0.62490195, 0.6849019
# Same standardize, sample, rescale steps as above, with different parameters.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample # = NaN
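
In both failing configurations the interval sits more than five standard deviations above the mean. One possible workaround, sketched below, is an inverse-CDF sampler that mirrors such an interval into the lower tail before sampling. This rests on an unverified assumption: that the NaN comes from float32 precision loss when CDF-like values round to 1 in the upper tail. The helper name truncated_normal_tail_safe is hypothetical and is not part of JAX or TFP; the calls it relies on (jax.scipy.special.ndtr, jax.scipy.special.ndtri, jax.random.uniform) are existing JAX functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_tail_safe(key, mean, std, minval, maxval):
    """Hypothetical inverse-CDF sampler; not the JAX or TFP implementation."""
    mean, std, minval, maxval = (
        jnp.asarray(x) for x in (mean, std, minval, maxval)
    )
    # Standardize the truncation bounds.
    lower = (minval - mean) / std
    upper = (maxval - mean) / std
    # If the interval lies entirely above the mean, reflect it into the lower
    # tail, where CDF values are small positive numbers instead of values
    # that are close to 1.
    flip = lower > 0
    lo = jnp.where(flip, -upper, lower)
    hi = jnp.where(flip, -lower, upper)
    p_lo, p_hi = ndtr(lo), ndtr(hi)
    # Draw a uniform probability between the two CDF values and invert it.
    u = jax.random.uniform(key, jnp.shape(lo), dtype=p_lo.dtype)
    z = ndtri(p_lo + u * (p_hi - p_lo))
    # Guard against rounding (or an infinite quantile) at the interval edges.
    z = jnp.clip(z, lo, hi)
    # Undo the reflection and rescale to the requested mean and std.
    z = jnp.where(flip, -z, z)
    return mean + std * z


sample = truncated_normal_tail_safe(
    jax.random.PRNGKey(2), 0.09121108, 0.1, 0.62490195, 0.6849019
)
```

This does not explain or fix the behavior of jax.random.truncated_normal itself. It only gives a second sampler to compare against when narrowing down where the NaN originates.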

(I am finding these strange-seeming parameter configurations by running a fairly complex probabilistic inference program that samples from TruncatedNormals millions of times, and then filtering the results to find where NaNs were generated.)
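
The filtering step described above could look roughly like the following. This is a minimal sketch, not the actual inference program: the helper find_nan_configs is a hypothetical name, and it assumes the parameters are available as equal-length 1-D arrays.

```python
import jax
import jax.numpy as jnp


def find_nan_configs(key, mean, std, minval, maxval):
    """Return the (mean, std, minval, maxval) rows whose sample came out NaN."""
    # Standardize the bounds, sample, and rescale, as in the snippets above.
    lower = (minval - mean) / std
    upper = (maxval - mean) / std
    z = jax.random.truncated_normal(key, lower, upper, shape=lower.shape)
    samples = z * std + mean
    configs = jnp.stack([mean, std, minval, maxval], axis=-1)
    # Boolean-mask indexing has a data-dependent shape, so run this outside jit.
    return configs[jnp.isnan(samples)]


# The two parameter sets reported in this thread, batched together.
mean = jnp.array([0.5382424, 0.09121108])
std = jnp.array([0.05, 0.1])
minval = jnp.array([0.80921564, 0.62490195])
maxval = jnp.array([0.86921564, 0.6849019])
bad = find_nan_configs(jax.random.PRNGKey(2), mean, std, minval, maxval)
```

Because a batched draw from a single key consumes random bits differently from the scalar calls above, the rows returned here need not match the scalar reproductions exactly. The sketch is meant for scanning many configurations, after which each hit can be re-run on its own.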