Qazalbash opened 1 month ago
I think the documentation should follow the code rather than the strict mathematical definition, but either way the docs are wrong: this function can output the floating-point representation of -inf, and it can also output exactly 0 due to rounding. So the interval should be either [-inf, 0] to match the code or (-inf, 0) to match the math.
>>> import jax
>>> import jax.numpy as jnp
>>> jax.nn.log_softmax(jnp.array([-jnp.inf, 0, 1])) # first element is -inf
Array([ -inf, -1.3132616 , -0.31326163], dtype=float32)
>>> jax.nn.log_softmax(jnp.array([-jnp.inf, -1000, 1])) # last element rounds to exactly 0
Array([ -inf, -1001., 0.], dtype=float32)
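To see why the last entry becomes exactly 0, here is a minimal pure-Python sketch of the usual max-shifted log-sum-exp formulation. It is an illustration under stated assumptions, not JAX's implementation: the helper name `log_softmax_sketch` is hypothetical, it works in float64 rather than float32, and it assumes at least one finite input. Mathematically the last output is 1 - logsumexp(x) = -log(1 + e^(-1001)), a tiny negative number, but e^(-1001) is far below the smallest positive double, so it underflows to 0 and the result is exactly 0.

```python
import math

def log_softmax_sketch(xs):
    """Hypothetical helper: max-shifted log-softmax over Python floats (float64).

    Assumes at least one entry is finite; -inf entries map to -inf.
    """
    m = max(xs)
    # Shift by the max so exp() cannot overflow; -inf - m stays -inf.
    shifted = [x - m for x in xs]
    # exp(-inf) is 0.0, and exp(-1001.0) underflows to 0.0 as well,
    # so the sum inside the log is exactly 1.0 for the second input below.
    lse = math.log(math.fsum(math.exp(s) for s in shifted))
    return [s - lse for s in shifted]

print(log_softmax_sketch([-math.inf, 0.0, 1.0]))
print(log_softmax_sketch([-math.inf, -1000.0, 1.0]))  # [-inf, -1001.0, 0.0]
print(math.exp(-1001.0))                              # 0.0 (underflow)
```

Because the sum inside the log is rounded before the log is taken, entries far enough below the max contribute nothing, and the max entry comes out as exactly 0. In float64 a gap of 40 is already enough: `log_softmax_sketch([-40.0, 0.0])` ends in `0.0`.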
@jakevdp or @mattjj thoughts? I can submit a quick PR to update this if it's a good idea.
Note that the mathematically correct interval should be
https://github.com/jax-ml/jax/blob/6790b90f91b0fbb8d818de8878f99eb3c4a871c2/jax/_src/nn/functions.py#L505
https://github.com/jax-ml/jax/blob/6790b90f91b0fbb8d818de8878f99eb3c4a871c2/jax/_src/scipy/special.py#L2632