jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

truncated_normal can't handle extreme bounds #10951

Open matthewdhoffman opened 2 years ago

matthewdhoffman commented 2 years ago

jax.random.truncated_normal produces badly wrong results when the lower bound is set too high or the upper bound too low, i.e. when the whole interval lies far out in one tail of the standard normal. For example,

import jax
import jax.numpy as jnp
jax.random.truncated_normal(
    jax.random.PRNGKey(1), jnp.array([5., 5.25, 5.5]), 1e6, (3,))

results in

DeviceArray([5.1665773e+00, 9.9999994e+05, 9.9999994e+05], dtype=float32)

The transition is suspiciously close to the point at which erf runs out of precision:

lax.erf(jnp.array([5., 5.25, 5.5]) / np.sqrt(2))

gives

DeviceArray([0.9999994, 1.       , 1.       ], dtype=float32)
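The reported numbers are consistent with an inversion sampler that maps both bounds through erf, draws a uniform value between the mapped bounds, and maps it back with erf_inv. The function below is a hypothetical re-creation of that recipe for illustration only: its name, `inversion_truncated_normal`, is made up, and it is not claimed to be JAX's actual source. It uses only `lax.erf`, `lax.erf_inv`, `jax.random.uniform`, and `jnp.nextafter`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def inversion_truncated_normal(key, lower, upper, shape, dtype=jnp.float32):
    """Hypothetical sketch of erf-based inverse-CDF sampling of N(0, 1) on [lower, upper]."""
    sqrt2 = jnp.sqrt(jnp.asarray(2.0, dtype))
    lower = jnp.broadcast_to(jnp.asarray(lower, dtype), shape)
    upper = jnp.broadcast_to(jnp.asarray(upper, dtype), shape)

    # Map the bounds into erf space: erf(x / sqrt(2)) = 2 * Phi(x) - 1.
    # With the float32 erf shown above, lower bounds of 5.25 and 5.5 already map to 1.0.
    a = lax.erf(lower / sqrt2)
    b = lax.erf(upper / sqrt2)

    # Draw uniformly between the mapped bounds and map back with erf_inv.
    u = jax.random.uniform(key, shape, dtype, minval=a, maxval=b)
    out = sqrt2 * lax.erf_inv(u)

    # Clip into the open interval (lower, upper) to absorb rounding error.
    return jnp.clip(
        out,
        jnp.nextafter(lower, jnp.full_like(lower, jnp.inf)),
        jnp.nextafter(upper, jnp.full_like(upper, -jnp.inf)),
    )


samples = inversion_truncated_normal(
    jax.random.PRNGKey(1), jnp.array([5.0, 5.25, 5.5]), 1e6, (3,))
print(samples)
```

In this sketch, once both mapped bounds round to 1.0, `u` is pinned at 1, `erf_inv(1)` evaluates to +inf, and the clip returns the largest float32 below 1e6, which is 999999.9375 (printed as 9.9999994e+05), the same value as in the report.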

So this may be a fundamental limitation of the sampling-by-inversion scheme used in truncated_normal.
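One way to probe whether the limitation belongs to inversion itself or only to evaluating the CDF near 1 is a small sketch that inverts in the lower tail instead. The hypothetical `truncated_normal_flipped` below (not a JAX API) uses the symmetry of N(0, 1) to reflect any interval lying in x > 0 onto the negative half-line, where a far-tail CDF value such as ndtr(-5.5), about 1.9e-8, is a small but representable float32 number rather than a value that rounds to 1.0. It then inverts with `jax.scipy.special.ndtr` and `ndtri`. It assumes a standard normal with `lower <= upper`, and it still breaks down once the tail probability itself underflows float32, so it only extends the usable range.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_flipped(key, lower, upper, shape, dtype=jnp.float32):
    """Hypothetical sketch: inverse-CDF sampling of N(0, 1) on [lower, upper]
    carried out in the lower tail, where far-tail CDF values are small
    float32 numbers rather than values that round to 1.0."""
    lower = jnp.broadcast_to(jnp.asarray(lower, dtype), shape)
    upper = jnp.broadcast_to(jnp.asarray(upper, dtype), shape)

    # By symmetry of N(0, 1), a sample on [lower, upper] is the negation of a
    # sample on [-upper, -lower]; reflect every interval that lies in x > 0.
    flip = lower > 0
    lo = jnp.where(flip, -upper, lower)
    hi = jnp.where(flip, -lower, upper)

    lo_cdf = ndtr(lo)
    hi_cdf = ndtr(hi)

    # v is in [0, 1), so u is in (lo_cdf, hi_cdf] up to rounding: it stays
    # positive even when lo_cdf underflows to 0, as long as hi_cdf does not.
    v = jax.random.uniform(key, shape, dtype)
    u = hi_cdf - v * (hi_cdf - lo_cdf)

    # Invert the CDF, clip away rounding error, and undo the reflection.
    x = jnp.clip(ndtri(u), lo, hi)
    return jnp.where(flip, -x, x)


samples = truncated_normal_flipped(
    jax.random.PRNGKey(1), jnp.array([5.0, 5.25, 5.5]), 1e6, (10000, 3))
print(samples.min(axis=0), samples.mean(axis=0))
```

On the bounds from this issue the sketch returns samples just above each lower bound instead of values near 1e6; for reference, the mean of N(0, 1) truncated below at 5.5 is phi(5.5) / (1 - Phi(5.5)), roughly 5.67.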

jakevdp commented 2 years ago

Thanks for the report, I'm taking a look