georgematheos opened 1 month ago
@derifatives indicated that `tfp.TruncatedNormal.sample` wraps `jax.random.truncated_normal` here. (We may be misunderstanding when this function is called.)
Note, however, that `jax.random.truncated_normal` can be used directly to sample from the truncated normal distribution above, with no obvious issues:
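```python
import jax

mean, std, minval, maxval = 0.5382424, 0.05, 0.80921564, 0.86921564
# Standardize the bounds so we can sample from a standard truncated normal.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
# Map the standard sample back to the original location and scale.
sample = centered_sample * std + mean
sample  # = 0.80921566
```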
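The standardize, sample, rescale steps in the snippet above can be packaged as a small helper. This is a minimal sketch; the name `truncated_normal_general` is hypothetical and not part of JAX or TFP. It simply forwards to `jax.random.truncated_normal`, so it inherits whatever numerical behavior that function has.

```python
import jax
import jax.numpy as jnp


def truncated_normal_general(key, mean, std, minval, maxval, shape=None):
    """Hypothetical helper: sample N(mean, std**2) truncated to [minval, maxval].

    Standardizes the bounds, draws from jax.random.truncated_normal on the
    standardized interval, and rescales, as in the snippet above.
    """
    minval_centered = (minval - mean) / std
    maxval_centered = (maxval - mean) / std
    centered = jax.random.truncated_normal(key, minval_centered, maxval_centered, shape)
    return centered * std + mean


# Usage: 1,000 draws from a standard normal truncated to [-1, 1].
samples = truncated_normal_general(jax.random.PRNGKey(0), 0.0, 1.0, -1.0, 1.0, shape=(1000,))
print(jnp.min(samples), jnp.max(samples))
```

With `shape=None`, `jax.random.truncated_normal` broadcasts the bounds, so the helper also accepts arrays of parameters. Because it performs the same computation, the NaN case below should reproduce through it as well.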
This does not always work, though:
```python
import jax

mean, std, minval, maxval = 0.09121108, 0.1, 0.62490195, 0.6849019
# Same standardize / sample / rescale steps, different parameters.
minval_centered, maxval_centered = (minval - mean) / std, (maxval - mean) / std
centered_sample = jax.random.truncated_normal(jax.random.PRNGKey(2), minval_centered, maxval_centered)
sample = centered_sample * std + mean
sample  # = NaN
```
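One plausible explanation, not confirmed here, is float32 precision in the far tail. Both sets of standardized bounds lie more than five standard deviations above the mean (roughly 5.42 to 6.62 in the first example, 5.34 to 5.94 in the second). As far as we can tell from the JAX source, `jax.random.truncated_normal` draws `u` uniformly between `erf(lower / sqrt(2))` and `erf(upper / sqrt(2))` and maps it back with `erf_inv`; at these bounds we would expect both endpoints to be equal to 1 or within a couple of float32 spacings of it, so the sampling interval effectively collapses. The diagnostic below (the helper `erf_endpoints` is hypothetical) prints the float32 endpoints next to the exact interval width, computed in float64 via `math.erfc`. The "working" sample 0.80921566 landing essentially on `minval` may be another symptom of the same collapse.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def erf_endpoints(mean, std, minval, maxval):
    """Hypothetical diagnostic: standardize the bounds and return the float32
    erf endpoints an erf/erf_inv sampler would draw between, plus the exact
    width of that interval."""
    lower = (minval - mean) / std
    upper = (maxval - mean) / std
    sqrt2 = jnp.float32(math.sqrt(2.0))
    a = float(jax.lax.erf(jnp.float32(lower) / sqrt2))
    b = float(jax.lax.erf(jnp.float32(upper) / sqrt2))
    # erf(u) - erf(l) == erfc(l) - erfc(u); using erfc in float64 avoids
    # cancellation near 1, so this serves as a reference width.
    exact_width = math.erfc(lower / math.sqrt(2.0)) - math.erfc(upper / math.sqrt(2.0))
    return lower, upper, a, b, exact_width


# Spacing between 1.0 and the largest float32 below it (2**-24).
step = float(np.finfo(np.float32).epsneg)
for params in [(0.5382424, 0.05, 0.80921564, 0.86921564),
               (0.09121108, 0.1, 0.62490195, 0.6849019)]:
    lower, upper, a, b, width = erf_endpoints(*params)
    print(f"standardized bounds ({lower:.3f}, {upper:.3f}); "
          f"float32 erf endpoints ({a!r}, {b!r}); "
          f"exact width {width:.2e} = {width / step:.2f} float32 steps")
```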
(I found these strange-seeming parameter configurations by running a fairly complex probabilistic inference program that samples millions of times from `TruncatedNormal` distributions, and then filtering the results for the cases where NaNs were generated.)
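A minimal sketch of that kind of NaN search, vectorized over many parameter sets, follows. It is not the author's inference program: `find_nan_configs` and the way the parameters are generated in the usage lines are hypothetical. It relies on `jax.random.truncated_normal` broadcasting array-valued `lower` and `upper` to the output shape when `shape` is omitted.

```python
import jax
import jax.numpy as jnp


def find_nan_configs(key, mean, std, minval, maxval):
    """Hypothetical helper: draw one sample per parameter set using the
    standardize/sample/rescale pattern and flag the sets that produced NaN."""
    lower = (minval - mean) / std
    upper = (maxval - mean) / std
    # With shape=None the bounds are broadcast, so each parameter set
    # gets its own independent draw.
    centered = jax.random.truncated_normal(key, lower, upper)
    samples = centered * std + mean
    return jnp.isnan(samples)


# Usage: random parameter sets whose interval often sits several std devs above the mean.
n = 100_000
k_mean, k_offset, k_sample = jax.random.split(jax.random.PRNGKey(0), 3)
mean = jax.random.uniform(k_mean, (n,))
std = jnp.full((n,), 0.1)
minval = mean + jax.random.uniform(k_offset, (n,), minval=0.0, maxval=0.7)
maxval = minval + 0.06
nan_mask = find_nan_configs(k_sample, mean, std, minval, maxval)
print(int(nan_mask.sum()), "of", n, "samples were NaN")
# Boolean indexing (outside jit) recovers the offending parameters for inspection.
print(mean[nan_mask][:5], minval[nan_mask][:5], maxval[nan_mask][:5])
```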
Example:
returns NaN.
JAX version: 0.4.33. TFP version: 0.23.0.
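As a possible workaround, one option is to avoid computing CDF values close to 1 at all. The sketch below is our own alternative, not how JAX or TFP implement sampling: when the whole interval lies above the mean, it uses the symmetry of the normal distribution to reflect the problem into the lower tail, where `jax.scipy.special.ndtr` returns small probabilities that float32 represents with good relative precision, and then inverts with `jax.scipy.special.ndtri`. The function name `tail_safe_truncated_normal` is hypothetical. Enabling 64-bit mode with `jax.config.update("jax_enable_x64", True)` may also help, at the cost of running everything in float64.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def tail_safe_truncated_normal(key, lower, upper, shape=(), dtype=jnp.float32):
    """Hypothetical helper: standard normal truncated to [lower, upper].

    Inverse-CDF sampling that reflects an upper-tail interval (lower > 0)
    onto [-upper, -lower], so the CDF values involved are small numbers
    rather than numbers within a few ulps of 1.
    """
    lower = jnp.asarray(lower, dtype)
    upper = jnp.asarray(upper, dtype)
    flip = lower > 0
    # Interval actually sampled by the inverse CDF (reflected if flip is True).
    lo = jnp.where(flip, -upper, lower)
    hi = jnp.where(flip, -lower, upper)
    u = jax.random.uniform(key, shape, dtype, minval=ndtr(lo), maxval=ndtr(hi))
    # Invert the standard normal CDF, then clamp away rounding at the edges.
    z = jnp.clip(ndtri(u), lo, hi)
    # Undo the reflection.
    return jnp.where(flip, -z, z)


# Usage with the parameters from the NaN example above.
mean, std, minval, maxval = 0.09121108, 0.1, 0.62490195, 0.6849019
lower, upper = (minval - mean) / std, (maxval - mean) / std
z = tail_safe_truncated_normal(jax.random.PRNGKey(2), lower, upper)
sample = z * std + mean
print(sample)
```

Intervals lying entirely below the mean are already in the lower tail and need no reflection; intervals that straddle the mean are not expected to hit this precision problem.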