google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Fix softmax_cross_entropy to handle -inf logits correctly when corresponding label is 0. #898
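For context, here is a minimal JAX sketch of the failure mode the title describes. It assumes the loss is computed as -sum(labels * log_softmax(logits)). A -inf logit gives a log-probability of -inf, and 0 * -inf is nan under IEEE 754 arithmetic, so a zero label does not cancel that term. The names naive_softmax_cross_entropy and masked_softmax_cross_entropy are hypothetical and used only for illustration. The masked version shows one possible fix, which zeroes the terms whose label is 0 with jnp.where. It is not necessarily the change made in this PR.

```python
import jax
import jax.numpy as jnp


def naive_softmax_cross_entropy(logits, labels):
    # Hypothetical helper: -sum(labels * log_softmax(logits)) over the last axis.
    # A -inf logit yields log_prob = -inf, and 0 * -inf = nan poisons the sum.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(labels * log_probs, axis=-1)


def masked_softmax_cross_entropy(logits, labels):
    # Hypothetical helper: same loss, but terms with a zero label are forced to 0
    # instead of evaluating 0 * log_prob (which is nan when log_prob = -inf).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    terms = jnp.where(labels == 0, 0.0, labels * log_probs)
    return -jnp.sum(terms, axis=-1)


# The second class is "impossible" (logit -inf) and carries no label mass.
logits = jnp.array([0.0, -jnp.inf])
labels = jnp.array([1.0, 0.0])

print(naive_softmax_cross_entropy(logits, labels))
print(masked_softmax_cross_entropy(logits, labels))
```

With these inputs the naive version returns nan, while the masked version returns 0, because all of the label mass sits on the finite logit. If a nonzero label falls on a -inf logit, the masked loss is +inf. That is the mathematically correct value, so the mask only changes the 0 * -inf case.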

Closed: carlosgmartin closed this 3 months ago

carlosgmartin commented 3 months ago

#896

fabianp commented 3 months ago

Would using logsumexp (https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.logsumexp.html) perhaps be a solution here?

Also, please add tests :-)

carlosgmartin commented 3 months ago

@fabianp Added tests.

log_softmax already uses logsumexp internally; see https://github.com/google-deepmind/optax/issues/896#issuecomment-2032688959
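The following is a small check of this point, assuming JAX's jax.scipy.special.logsumexp and jax.nn.log_softmax. On its own, logsumexp copes with -inf inputs, because exp(-inf) contributes 0 to the sum. Computing the log-softmax from it still leaves -inf at the masked position, though, so the 0 * -inf product in the cross-entropy remains.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 0.0, -jnp.inf])
labels = jnp.array([1.0, 0.0, 0.0])

# logsumexp is finite here: the -inf entry just contributes exp(-inf) = 0.
lse = logsumexp(logits)

# Log-softmax expressed via logsumexp (mathematically equivalent to
# jax.nn.log_softmax); the -inf logit stays -inf after subtracting lse.
log_probs = logits - lse

# The zero label on the -inf entry still produces 0 * -inf = nan in the loss.
loss = -jnp.sum(labels * log_probs)

print(log_probs)
print(loss)
```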