matthewdhoffman opened this issue 2 years ago
jax.random.truncated_normal produces deeply wrong results when you set the lower bound too high or the upper bound too low. For example,
```python
import jax
import jax.numpy as jnp

jax.random.truncated_normal(jax.random.PRNGKey(1), jnp.array([5., 5.25, 5.5]), 1e6, (3,))
```
results in
```
DeviceArray([5.1665773e+00, 9.9999994e+05, 9.9999994e+05], dtype=float32)
```
The transition is suspiciously close to the point at which erf runs out of precision:
```python
import numpy as np
import jax.numpy as jnp
from jax import lax

lax.erf(jnp.array([5., 5.25, 5.5]) / np.sqrt(2))
```
gives
```
DeviceArray([0.9999994, 1. , 1. ], dtype=float32)
```
So this may be a fundamental limitation of the sampling-by-inversion scheme used in truncated_normal.
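To make the precision argument concrete, here is a small sketch comparing the two CDF intervals an inverse-CDF sampler could draw from. It uses `jax.scipy.special.ndtr` (the standard normal CDF) rather than `erf` directly, and a lower bound of 6 chosen for illustration, far enough out that the float32 rounding should not depend on the exact `erf` implementation. It is an illustration of the failure mode, not the code `truncated_normal` actually runs.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.scipy.special import ndtr

# Illustrative bounds (assumed for this example): the lower bound sits six
# standard deviations out; the upper bound matches the report.
a = jnp.float32(6.0)
b = jnp.float32(1e6)

# erf saturates: erf(6 / sqrt(2)) rounds to exactly 1.0 in float32.
erf_at_a = lax.erf(a / jnp.float32(np.sqrt(2.0)))

# Upper-tail interval [Phi(a), Phi(b)]: both endpoints round to 1.0,
# so an inverse-CDF sampler has a zero-width interval to draw from.
upper_width = ndtr(b) - ndtr(a)

# Mirrored interval [Phi(-b), Phi(-a)]: Phi(-6) is about 9.9e-10,
# which is comfortably representable in float32.
lower_width = ndtr(-a) - ndtr(-b)

print(erf_at_a, upper_width, lower_width)
```

The asymmetry comes from float32 spacing: values near 1.0 are about 6e-8 apart, while values near 0 are far denser, so tail probabilities survive only when expressed as small numbers near 0.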
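One common way to keep inversion sampling usable for intervals like these is reflection: when the whole interval [a, b] lies above zero, sample from the mirrored interval [-b, -a], where the CDF values are small and representable, and negate the result. The following is a minimal sketch of that idea, assuming `jax.scipy.special.ndtr` and `ndtri` for the CDF and its inverse. `truncated_normal_reflected` is a hypothetical helper, not JAX's implementation. It also only postpones the problem: once `ndtr(-a)` itself underflows float32, the interval collapses again, and a log-space or rejection-based method would be needed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_reflected(key, lower, upper, shape, dtype=jnp.float32):
    """Hypothetical inverse-CDF truncated normal sampler (not JAX's API).

    Reflects intervals lying entirely above zero into the lower tail,
    where float32 CDF values keep their precision.
    """
    lower = jnp.broadcast_to(jnp.asarray(lower, dtype), shape)
    upper = jnp.broadcast_to(jnp.asarray(upper, dtype), shape)

    # If the whole interval is positive, work on its mirror image [-upper, -lower].
    flip = lower > 0
    lo = jnp.where(flip, -upper, lower)
    hi = jnp.where(flip, -lower, upper)
    p_lo, p_hi = ndtr(lo), ndtr(hi)

    # r is in [0, 1), so u lies in (p_lo, p_hi]. Counting down from p_hi
    # keeps u strictly above p_lo, which is often exactly 0 in the
    # reflected case, where ndtri would return -inf.
    r = jax.random.uniform(key, shape, dtype)
    u = p_hi - r * (p_hi - p_lo)

    x = ndtri(u)                   # invert the standard normal CDF
    x = jnp.where(flip, -x, x)     # undo the reflection
    # Guard against any remaining rounding at the interval ends.
    return jnp.clip(x, lower, upper)


# Usage on the bounds from the report.
key = jax.random.PRNGKey(1)
samples = truncated_normal_reflected(key, jnp.array([5.0, 5.25, 5.5]), 1e6, (1000, 3))
```

Reflection only helps when the interval sits entirely on one side of zero, which is exactly the failure case reported here. Intervals that straddle zero already have CDF endpoints well away from 0 and 1.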
Thanks for the report, I'm taking a look.