jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Edge case: Normal CDF incorrect when called with extremely small integer value if location is specified as integer array #24059

Open jleugeri opened 1 month ago

jleugeri commented 1 month ago

Description

Computing the CDF of a normal distribution gives the wrong result when the location is specified as an integer array and the CDF is evaluated at the smallest representable int32 value (-2147483648).

MWE: The following should all return 0.0, since the CDF $F(x)$ of a Gaussian approaches 0 as $x \to -\infty$:

print(dists.Normal(jnp.array([1]),0.01).cdf(jnp.astype(-jnp.inf, int)))
print(dists.Normal(jnp.array([1], dtype=int),0.01).cdf(-2147483648))
print(dists.Normal(jnp.array([1], dtype=int),1.0).cdf(-2147483648))
print(jnp.exp(dists.Normal(jnp.array([1], dtype=int),1.0).log_cdf(-2147483648))) # <-
print(dists.Normal(jnp.array([1], dtype=float),0.01).cdf(-2147483648))
print(dists.Normal(jnp.array([1]),0.01).cdf(-2147483647))
print(dists.Normal(  1,0.01).cdf(-2147483648))
print(dists.Normal(1.0,0.01).cdf(-2147483648))
print(dists.Normal(jnp.array([1]),0.01).cdf(-jnp.inf))

But instead the first three calls return 1:

[1.]
[1.]
[1.]
[0.] # <-
[0.]
[0.]
0.0
0.0
[0.]

Note that log_cdf seems to be OK (see the line marked `# <-`).

A workaround for me is to explicitly cast the location vector to float. This is a highly unlikely edge case, but it did come up for me in a real-world problem (computing mutual information).
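One possible explanation, which is only a guess and not confirmed anywhere in this thread, is that `x - loc` is evaluated in int32 and wraps around before being converted to float. The sketch below illustrates that hypothesis with plain NumPy arrays standing in for the JAX arrays; `normal_cdf` is a hypothetical helper written for this illustration, not the `dists.Normal` implementation.

```python
import math

import numpy as np

INT32_MIN = np.iinfo(np.int32).min  # -2147483648


def normal_cdf(z):
    """Standard normal CDF for a scalar standardized value z (scale = 1)."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


x = np.array([INT32_MIN], dtype=np.int32)
loc_int = np.array([1], dtype=np.int32)

# Hypothesized failure mode: the subtraction stays in int32 and wraps
# around, so a hugely negative value becomes a hugely positive one.
diff_int = x - loc_int

# Workaround from the report: cast the location to float first, so the
# subtraction is carried out in floating point and cannot wrap.
diff_float = x - loc_int.astype(np.float64)

print(diff_int, normal_cdf(float(diff_int[0])))
print(diff_float, normal_cdf(float(diff_float[0])))
```

If this guess is right, it would also be consistent with -2147483647 giving the expected result, since -2147483647 - 1 still fits in int32.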

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.1
python: 3.12.5 (main, Aug 22 2024, 08:14:36) [Clang 15.0.0 (clang-1500.1.0.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Beauty.local', release='24.0.0', version='Darwin Kernel Version 24.0.0: Mon Aug 12 20:52:12 PDT 2024; root:xnu-11215.1.10~2/RELEASE_ARM64_T6020', machine='arm64')

jakevdp commented 4 weeks ago

Hi - thanks for the report! Can you clarify what dists.Normal is here? JAX doesn't have any API with that name.