google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Avoid multiplication of boolean arrays #840

Closed copybara-service[bot] closed 6 months ago

copybara-service[bot] commented 6 months ago


Instead of multiplying boolean arrays, use jnp.where. This avoids type promotion of boolean arrays to integer arrays.
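For illustration only, here is a minimal sketch of the general pattern. It is not the code changed in this PR: the array names (mask_a, mask_b, values) and the specific operations are assumptions chosen to show how arithmetic on a boolean array can change its dtype while jnp.where keeps selection separate from arithmetic.

```python
import jax.numpy as jnp

# Hypothetical inputs, not taken from the optax source.
mask_a = jnp.array([True, False, True])
mask_b = jnp.array([True, True, False])
values = jnp.array([1.5, 2.5, 3.5], dtype=jnp.float32)

# Arithmetic on a boolean array promotes it: scaling by a Python int
# yields an integer array rather than a boolean one.
promoted = mask_a * 1

# Combining masks with jnp.where keeps the result boolean.
combined = jnp.where(mask_a, mask_b, False)

# Selecting values with jnp.where keeps the dtype of `values`
# without involving the mask in any arithmetic.
selected = jnp.where(mask_a, values, jnp.zeros_like(values))

print(promoted.dtype, combined.dtype, selected.dtype)
```

Using jnp.where also makes the intent (select, not scale) explicit, which may be easier to read than a multiplication by a mask.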

Fixes issue https://github.com/google-deepmind/optax/issues/828

Also makes cosmetic changes to the docstring.