carlosgmartin closed this 3 months ago.
Would using logsumexp (https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.logsumexp.html) perhaps be a solution here?
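For context on why `logsumexp` helps: it computes `log(sum(exp(x)))` with a max-shift, so large inputs do not overflow the way a naive composition does. The sketch below only illustrates that general property with `jax.scipy.special.logsumexp`; it assumes nothing about the specific failure in the linked issue.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1000.0, 1000.0])

# Naive composition: exp(1000) overflows to inf, so the log is inf too.
naive = jnp.log(jnp.sum(jnp.exp(x)))

# logsumexp shifts by max(x) internally: 1000 + log(1 + 1) = 1000 + log(2).
stable = logsumexp(x)

print(naive, stable)
```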
Also, please add tests :-)
@fabianp Added tests.
`log_softmax` uses `logsumexp` internally: https://github.com/google-deepmind/optax/issues/896#issuecomment-2032688959.
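As a rough illustration of the identity behind this, `log_softmax(x) = x - logsumexp(x)`. The helper `log_softmax_via_logsumexp` below is hypothetical (not part of JAX or optax); the sketch just checks that it agrees with `jax.nn.log_softmax` and stays finite for large-magnitude logits.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_softmax_via_logsumexp(x, axis=-1):
    # Hypothetical helper: log softmax(x)_i = x_i - log(sum_j exp(x_j)),
    # with the normalizer computed stably by logsumexp.
    return x - logsumexp(x, axis=axis, keepdims=True)


logits = jnp.array([[1000.0, 0.0, -1000.0],
                    [1.0, 2.0, 3.0]])

ours = log_softmax_via_logsumexp(logits)
ref = jax.nn.log_softmax(logits, axis=-1)
print(ours)
print(ref)
```

Both results should match elementwise; the first row comes out as roughly `[0, -1000, -2000]` rather than `nan` or `-inf`.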