jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[DOC] Incorrect interval in `jax.nn.log_softmax` #23992

Open Qazalbash opened 1 month ago

Qazalbash commented 1 month ago

Note that the mathematically correct interval should be

:math:`(-\infty,0)`
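
A short derivation of that open interval, offered as a sketch. It assumes a finite input vector with at least two entries; the symbols $x$, $n$, and $i$ are introduced here for illustration and do not come from the docstring.

```latex
\[
\operatorname{log\_softmax}(x)_i
  = x_i - \log \sum_{j=1}^{n} e^{x_j},
\qquad
\sum_{j=1}^{n} e^{x_j} > e^{x_i}
\;\Longrightarrow\;
\operatorname{log\_softmax}(x)_i < 0 .
\]
```

Every finite input also gives a finite output, so under these assumptions neither endpoint of $(-\infty, 0)$ is attained in exact arithmetic.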

https://github.com/jax-ml/jax/blob/6790b90f91b0fbb8d818de8878f99eb3c4a871c2/jax/_src/nn/functions.py#L505

https://github.com/jax-ml/jax/blob/6790b90f91b0fbb8d818de8878f99eb3c4a871c2/jax/_src/scipy/special.py#L2632

justinjfu commented 4 weeks ago

I think the documentation should follow the code and not necessarily the mathematically correct definition, but either way the docs are wrong. The function can output the floating-point representation of -inf, and it can also output exactly 0 due to rounding. So the interval should be either [-inf, 0] following the code or (-inf, 0) following the math.

>>> import jax
>>> import jax.numpy as jnp
>>> jax.nn.log_softmax(jnp.array([-jnp.inf, 0, 1])) # outputs -inf
Array([       -inf, -1.3132616 , -0.31326163], dtype=float32)

>>> jax.nn.log_softmax(jnp.array([-jnp.inf, -1000, 1])) # outputs 0
Array([  -inf, -1001.,     0.], dtype=float32)
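
To illustrate where the exact 0 comes from, here is a minimal pure-Python sketch of a max-shifted log-softmax. It is an assumption-laden stand-in, not JAX's actual source: the helper name `log_softmax` is hypothetical, and it uses Python floats (float64) rather than float32. For these inputs the underflow happens in either precision.

```python
import math

def log_softmax(xs):
    """Max-shifted log-softmax over a list of Python floats (illustrative only)."""
    x_max = max(xs)
    # Shift so the largest entry becomes 0; -inf entries stay -inf.
    shifted = [x - x_max for x in xs]
    # exp(-1001) underflows to 0.0, so the normalizer can round to log(1) = 0.
    log_norm = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_norm for s in shifted]

# Second example from the thread: the last entry rounds to exactly 0.
rounded = log_softmax([-math.inf, -1000.0, 1.0])

# First example from the thread: a -inf input yields a -inf output.
with_neg_inf = log_softmax([-math.inf, 0.0, 1.0])

print(rounded)
print(with_neg_inf)
```

In both cases every output lies in the closed interval [-inf, 0], which matches the "following code" reading above.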

@jakevdp or @mattjj thoughts? I can submit a quick PR to update this if it's a good idea.