google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add `axis` and `where` arguments to loss functions #902

Open · carlosgmartin opened this issue 3 months ago

carlosgmartin commented 3 months ago

Feature request: Add the following arguments:

- `axis`
- `where`

to the following loss functions:

I can submit a PR for this.
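To make the request concrete, here is a minimal sketch of what `axis` and `where` could mean for a softmax cross-entropy loss. It assumes the arguments follow the semantics of `axis` and `where` in `jnp.sum`/`jnp.max`: `axis` selects the class axis, and `where` is a boolean mask (broadcastable to `logits`) that removes classes from both the softmax normalization and the sum. The names `masked_log_softmax` and `softmax_cross_entropy_sketch` are hypothetical and are not part of Optax's API. The eventual PR may choose different semantics.

```python
import jax
import jax.numpy as jnp


def masked_log_softmax(logits, axis=-1, where=None):
  """Log-softmax along `axis`, normalizing only over entries where `where` is True.

  Hypothetical helper for illustration. Masked-out entries are returned as 0
  (not -inf) so that multiplying by labels cannot produce 0 * -inf = nan.
  """
  logits = jnp.asarray(logits)
  if where is None:
    return jax.nn.log_softmax(logits, axis=axis)
  where = jnp.broadcast_to(where, logits.shape)
  # A `where` mask on a max reduction needs an explicit `initial` value.
  x_max = jnp.max(logits, axis=axis, where=where, initial=-jnp.inf, keepdims=True)
  # Shift for numerical stability; masked entries become -inf, so exp() gives 0.
  shifted = jnp.where(where, logits - jax.lax.stop_gradient(x_max), -jnp.inf)
  log_norm = jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, where=where, keepdims=True))
  return jnp.where(where, shifted - log_norm, 0.0)


def softmax_cross_entropy_sketch(logits, labels, axis=-1, where=None):
  """Hypothetical cross-entropy with `axis` and `where` (not Optax's API)."""
  log_probs = masked_log_softmax(logits, axis=axis, where=where)
  return -jnp.sum(labels * log_probs, axis=axis)


logits = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]])
labels = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

# `axis`: classes on the last axis vs. classes on axis 0 give the same losses.
loss_last = softmax_cross_entropy_sketch(logits, labels)
loss_first = softmax_cross_entropy_sketch(logits.T, labels.T, axis=0)

# `where`: drop the third class from every example's normalization.
mask = jnp.array([True, True, False])
loss_masked = softmax_cross_entropy_sketch(logits, labels, where=mask)
```

Applying the mask inside the normalization is what distinguishes `where` from simply multiplying the per-class terms by a mask afterwards: a masked-out class no longer contributes probability mass to the remaining classes.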

vroulet commented 3 months ago

Hello @carlosgmartin,

Yes, that would be great. Thanks for catching this!

carlosgmartin commented 3 months ago

Great. Once #898 is merged, I'll put together a PR.