google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.65k stars, 181 forks

Correct handling of -inf in softmax_cross_entropy. Fix #898. #916

Closed: copybara-service[bot] closed this 5 months ago

copybara-service[bot] commented 5 months ago

Correct handling of -inf in softmax_cross_entropy. Fix #898. Added tests checking that gradients are evaluated correctly.
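One way the -inf problem can show up, sketched in JAX: a -inf logit (for example, a masked-out class) has log-probability -inf, so if its label is 0 the term 0 * -inf evaluates to NaN under IEEE 754 and poisons the whole loss. The sketch below is illustrative only and is not necessarily the code merged in #916. The function names `naive_softmax_cross_entropy` and `masked_softmax_cross_entropy` are hypothetical, and the sketch assumes labels are exactly 0 wherever logits are -inf.

```python
import jax
import jax.numpy as jnp


def naive_softmax_cross_entropy(logits, labels):
  """Cross entropy computed as -sum(labels * log_softmax(logits))."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  # A -inf logit yields a -inf log-probability; with a zero label the
  # product 0 * -inf is NaN, which makes the summed loss NaN.
  return -jnp.sum(labels * log_probs, axis=-1)


def masked_softmax_cross_entropy(logits, labels):
  """Hypothetical -inf-safe variant: zero out log-probs where labels are 0."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  # Replace log_probs by 0 wherever the label is exactly 0, so the
  # forward pass never evaluates 0 * -inf.
  safe_log_probs = jnp.where(labels == 0, 0.0, log_probs)
  return -jnp.sum(labels * safe_log_probs, axis=-1)


logits = jnp.array([0.0, 1.0, -jnp.inf])  # last class is masked out
labels = jnp.array([0.25, 0.75, 0.0])     # and carries no probability mass

print(naive_softmax_cross_entropy(logits, labels))   # nan
print(masked_softmax_cross_entropy(logits, labels))  # finite

# Gradient w.r.t. logits; mathematically this is softmax(logits) - labels
# (for labels summing to 1), which is 0 at the -inf entry.
grads = jax.grad(masked_softmax_cross_entropy)(logits, labels)
print(grads)
```

Masking `log_probs` before the multiply, rather than applying `jnp.where` to the product, is a deliberate choice: with the product masked, differentiating with respect to `labels` would still compute -inf * 0 in the backward pass. This sketch does not handle a row in which every logit is -inf, where `log_softmax` itself produces NaN.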